SageMaker Batch currently doesn't support Model entity with container definitions which use ModelDataSource attribute #4777

windson opened this issue Jul 8, 2024 · 0 comments


Describe the feature you'd like

Support the ModelDataSource attribute in SageMaker Batch Transform, so that models with uncompressed (S3-prefix) artifacts, such as LLMs, can be used in batch transform operations.

How would this feature be used? Please describe.
A clear and concise description of the use case for this feature. Please provide an example, if possible.

import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # SageMaker session for interacting with AWS APIs
region = sess.boto_region_name  # region of the current SageMaker Studio environment
account_id = sess.account_id()

from huggingface_hub import snapshot_download
from pathlib import Path
import os

# This downloads the model into the current directory, wherever the notebook is running
local_model_path = Path(".")
local_model_path.mkdir(exist_ok=True)
model_name = 'mistralai/Mistral-7B-v0.1'
# Only download PyTorch checkpoint files
allow_patterns = ["*.json", "*.txt", "*.model", "*.safetensors", "*.bin", "*.chk", "*.pth"]

# Use snapshot_download since the model is stored in the repository using LFS
model_download_path = snapshot_download(
    repo_id=model_name,
    cache_dir=local_model_path,
    allow_patterns=allow_patterns,
    token='<HF TOKEN>'
)

# Upload the uncompressed model artifacts first, so the S3 URI is available
# when rendering serving.properties below
base_model_s3_uri = sess.upload_data(path=model_download_path, key_prefix="batch-transform-mistral/model")
print(f"Model uploaded to ---> {base_model_s3_uri}")
%%writefile {model_download_path}/serving.properties
engine=Python
option.tensor_parallel_degree=max
option.model_id={{model_id}}
option.max_rolling_batch_size=16
option.rolling_batch=vllm
import jinja2
from pathlib import Path

jinja_env = jinja2.Environment()
props_path = Path(model_download_path) / "serving.properties"
template = jinja_env.from_string(props_path.read_text())
props_path.write_text(template.render(model_id=base_model_s3_uri))
# Upload the rendered serving.properties next to the model artifacts
sess.upload_data(path=str(props_path), key_prefix="batch-transform-mistral/model")

#https://github.com/aws/sagemaker-python-sdk/blob/master/tests/unit/test_djl_inference.py#L31-L33
image_uri = image_uris.retrieve(
    framework="djl-lmi",
    region=region,
    version="0.28.0"
)

model_data = {
    "S3DataSource": {
        "S3Uri": f'{base_model_s3_uri}/',
        'S3DataType': 'S3Prefix',
        'CompressionType': 'None'
    }
}

# create your SageMaker Model
model = sagemaker.Model(image_uri=image_uri, model_data=model_data, role=role)
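For context, passing a dict as `model_data` makes the SDK emit a `ModelDataSource` container definition instead of the classic `ModelDataUrl` string. A minimal sketch of the two request shapes (field names from the SageMaker `CreateModel` API; the image URI and S3 paths are illustrative placeholders):

```python
# Illustrative only: the two container-definition shapes CreateModel accepts.

# A plain string model_data becomes ModelDataUrl (accepted by Batch Transform):
container_with_url = {
    "Image": "<image-uri>",
    "ModelDataUrl": "s3://my-bucket/model.tar.gz",
}

# A dict model_data becomes ModelDataSource, which CreateTransformJob
# currently rejects (see the error below):
container_with_source = {
    "Image": "<image-uri>",
    "ModelDataSource": {
        "S3DataSource": {
            "S3Uri": "s3://my-bucket/batch-transform-mistral/model/",
            "S3DataType": "S3Prefix",
            "CompressionType": "None",
        }
    },
}
```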

from sagemaker.utils import name_from_base

endpoint_name = name_from_base("lmi-batch-transform-mistral-gated")

# instance type you will deploy your model to
instance_type = "ml.g5.12xlarge"

# Creating the batch transformer object. If you have a large dataset you can
# divide it into smaller chunks and use more instances for faster inference.
# s3_output_data_path, s3_input_data_path, and hyper_params_dict are assumed
# to be defined earlier in the notebook.
batch_transformer = model.transformer(
    instance_count=1,
    instance_type=instance_type,
    output_path=s3_output_data_path,
    assemble_with="Line",
    accept="text/csv",
    max_payload=1,
)
batch_transformer.env = hyper_params_dict

# Making the predictions on the input data
batch_transformer.transform(
    s3_input_data_path, content_type="application/jsonlines", split_type="Line"
)

batch_transformer.wait()

This throws the error:

---------------------------------------------------------------------------
ClientError                               Traceback (most recent call last)
Cell In[36], line 14
     11 batch_transformer.env = hyper_params_dict
     13 # Making the predictions on the input data
---> 14 batch_transformer.transform(
     15     s3_input_data_path, content_type="application/jsonlines", split_type="Line"
     16 )
     18 batch_transformer.wait()

File /opt/conda/lib/python3.10/site-packages/sagemaker/workflow/pipeline_context.py:346, in runnable_by_pipeline.<locals>.wrapper(*args, **kwargs)
    342         return context
    344     return _StepArguments(retrieve_caller_name(self_instance), run_func, *args, **kwargs)
--> 346 return run_func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/sagemaker/transformer.py:302, in Transformer.transform(self, data, data_type, content_type, compression_type, split_type, job_name, input_filter, output_filter, join_source, experiment_config, model_client_config, batch_data_capture_config, wait, logs)
    292 experiment_config = check_and_get_run_experiment_config(experiment_config)
    294 batch_data_capture_config = resolve_class_attribute_from_config(
    295     None,
    296     batch_data_capture_config,
   (...)
    299     sagemaker_session=self.sagemaker_session,
    300 )
--> 302 self.latest_transform_job = _TransformJob.start_new(
    303     self,
    304     data,
    305     data_type,
    306     content_type,
    307     compression_type,
    308     split_type,
    309     input_filter,
    310     output_filter,
    311     join_source,
    312     experiment_config,
    313     model_client_config,
    314     batch_data_capture_config,
    315 )
    317 if wait:
    318     self.latest_transform_job.wait(logs=logs)

File /opt/conda/lib/python3.10/site-packages/sagemaker/transformer.py:636, in _TransformJob.start_new(cls, transformer, data, data_type, content_type, compression_type, split_type, input_filter, output_filter, join_source, experiment_config, model_client_config, batch_data_capture_config)
    619 """Placeholder docstring"""
    621 transform_args = cls._get_transform_args(
    622     transformer,
    623     data,
   (...)
    633     batch_data_capture_config,
    634 )
--> 636 transformer.sagemaker_session.transform(**transform_args)
    638 return cls(transformer.sagemaker_session, transformer._current_job_name)

File /opt/conda/lib/python3.10/site-packages/sagemaker/session.py:3805, in Session.transform(self, job_name, model_name, strategy, max_concurrent_transforms, max_payload, input_config, output_config, resource_config, experiment_config, env, tags, data_processing, model_client_config, batch_data_capture_config)
   3802     logger.debug("Transform request: %s", json.dumps(request, indent=4))
   3803     self.sagemaker_client.create_transform_job(**request)
-> 3805 self._intercept_create_request(transform_request, submit, self.transform.__name__)

File /opt/conda/lib/python3.10/site-packages/sagemaker/session.py:6497, in Session._intercept_create_request(self, request, create, func_name)
   6480 def _intercept_create_request(
   6481     self,
   6482     request: typing.Dict,
   (...)
   6485     # pylint: disable=unused-argument
   6486 ):
   6487     """This function intercepts the create job request.
   6488 
   6489     PipelineSession inherits this Session class and will override
   (...)
   6495         func_name (str): the name of the function needed intercepting
   6496     """
-> 6497     return create(request)

File /opt/conda/lib/python3.10/site-packages/sagemaker/session.py:3803, in Session.transform.<locals>.submit(request)
   3801 logger.info("Creating transform job with name: %s", job_name)
   3802 logger.debug("Transform request: %s", json.dumps(request, indent=4))
-> 3803 self.sagemaker_client.create_transform_job(**request)

File /opt/conda/lib/python3.10/site-packages/botocore/client.py:565, in ClientCreator._create_api_method.<locals>._api_call(self, *args, **kwargs)
    561     raise TypeError(
    562         f"{py_operation_name}() only accepts keyword arguments."
    563     )
    564 # The "self" in this scope is referring to the BaseClient.
--> 565 return self._make_api_call(operation_name, kwargs)

File /opt/conda/lib/python3.10/site-packages/botocore/client.py:1021, in BaseClient._make_api_call(self, operation_name, api_params)
   1017     error_code = error_info.get("QueryErrorCode") or error_info.get(
   1018         "Code"
   1019     )
   1020     error_class = self.exceptions.from_code(error_code)
-> 1021     raise error_class(parsed_response, operation_name)
   1022 else:
   1023     return parsed_response

ClientError: An error occurred (ValidationException) when calling the CreateTransformJob operation: SageMaker Batch currently doesn't support Model entity with container definitions which use ModelDataSource attribute

Describe alternatives you've considered
A clear and concise description of any alternative solutions or features you've considered.
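One possible workaround, as an untested sketch rather than an official recommendation: repackage the uncompressed artifacts into a `model.tar.gz` and pass the tarball's S3 URI as a plain string `model_data`, which produces a `ModelDataUrl` container definition that Batch Transform accepts. The `pack_model` helper below is hypothetical; the trade-off is that large LLM weights must be compressed and later extracted, which is exactly what `ModelDataSource` is meant to avoid:

```python
import tarfile
from pathlib import Path

def pack_model(model_dir: str, out_path: str = "model.tar.gz") -> str:
    """Pack an uncompressed model directory into a model.tar.gz whose
    members sit at the archive root (the layout SageMaker expects)."""
    with tarfile.open(out_path, "w:gz") as tar:
        for item in Path(model_dir).iterdir():
            tar.add(item, arcname=item.name)
    return out_path

# Then upload the tarball and pass the URI as a string instead of a dict:
# tar_s3_uri = sess.upload_data(path=pack_model(model_download_path),
#                               key_prefix="batch-transform-mistral/model-tar")
# model = sagemaker.Model(image_uri=image_uri, model_data=tar_s3_uri, role=role)
```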

Additional context
Add any other context or screenshots about the feature request here.
