TorchServe inference.py incompatible with default PyTorch inference Transformer for UTF-8 Content Types #4869

Open
dillon-odonovan opened this issue Sep 12, 2024 · 0 comments

Describe the bug
When running a SageMaker container with the default PyTorch inference Transformer and a UTF-8 Content-Type (application/json, text/csv), the TorchServe inference.py implementation throws an error during deserialization within input_fn. This is because input_fn expects input_data to be a bytes-like object, but the Transformer has already decoded it to a str. The NumpyDeserializer does support deserializing UTF-8 content types, but for input processing that code path is effectively unreachable (it can still be reached for output) without overriding the default Inference Handler / Handler Service / Transformer (a custom transformer cannot be specified if input_fn is specified).
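
To make the type mismatch concrete, here is a minimal sketch (the payload value is hypothetical, and the BytesIO call only approximates what the generated inference.py does):

import io

from sagemaker.deserializers import NumpyDeserializer

# By the time input_fn runs, the Transformer has already decoded a UTF-8
# payload (hypothetical application/json body) from bytes to str:
decoded_payload = '[[5.1, 3.5, 1.4, 0.2]]'

# Roughly what the TorchServe inference.py input_fn does - wrap the raw
# input in BytesIO before deserializing. This fails for str input:
try:
    io.BytesIO(decoded_payload)
except TypeError as e:
    print(e)  # a bytes-like object is required, not 'str'

# Yet NumpyDeserializer itself handles UTF-8 content types when handed an
# actual byte stream:
arr = NumpyDeserializer().deserialize(
    io.BytesIO(decoded_payload.encode('utf-8')), 'application/json'
)
print(arr)  # [[5.1 3.5 1.4 0.2]]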

The TorchServe inference.py script was implemented in #4662

With Python clients, or when using the Predictor class from the SageMaker SDK, this is easily worked around. From other languages, such as Java, it is much more difficult: a JSON representation of the inference input cannot be provided, so custom serialization matching the NPY format would be necessary.
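
For reference, a minimal sketch of the Python-side workaround (the endpoint name is hypothetical): forcing application/x-npy ensures the payload reaches input_fn as bytes.

from sagemaker.deserializers import NumpyDeserializer
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer

# NumpySerializer sends application/x-npy, which input_fn reads as raw
# bytes, sidestepping the premature str decode entirely.
predictor = Predictor(
    endpoint_name='my-endpoint',  # hypothetical
    serializer=NumpySerializer(),
    deserializer=NumpyDeserializer(),
)
predictions = predictor.predict(X_test)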

This is just one example use case - the issue may apply to other inputs beyond NumPy arrays / scikit-learn algorithms. Ownership of the fix could lie in the SageMaker Python SDK, in the sagemaker-pytorch-inference repository, or elsewhere. A change to any of these components risks impacting production behavior that clients may rely on.

To reproduce

import io
import json
import os

import boto3
import mlflow
import numpy as np
from mlflow.models import infer_signature
from sagemaker.serve import Mode, ModelBuilder, SchemaBuilder
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

X, y = datasets.load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

params = {
    "solver": "lbfgs",
    "max_iter": 1000,
    "multi_class": "auto",
    "random_state": 8888
}

lr = LogisticRegression(**params)
lr.fit(X_train, y_train)
y_pred = lr.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)

mlflow.set_tracking_uri(os.environ['MLFLOW_URI'])

with mlflow.start_run() as run:
    mlflow.log_params(params)
    mlflow.log_metric('accuracy', accuracy)
    mlflow.set_tag('Training Info', 'Basic LR model for iris data')
    signature = infer_signature(X_train, lr.predict(X_train))
    model_info = mlflow.sklearn.log_model(
        sk_model=lr,
        artifact_path='sklearn-model',
        signature=signature,
        input_example=X_train,
        registered_model_name='tracking-quickstart'
    )
model_uri = f'runs:/{run.info.run_id}/sklearn-model'
schema_builder = SchemaBuilder(sample_input=X_train, sample_output=y_pred)
model_builder = ModelBuilder(
    mode=Mode.SAGEMAKER_ENDPOINT,
    schema_builder=schema_builder,
    role_arn=os.environ['ROLE_ARN'],
    model_metadata={
        "MLFLOW_MODEL_PATH": model_uri,
        "MLFLOW_TRACKING_ARN": os.environ['MLFLOW_TRACKING_SERVER_ARN']
    }
)
model = model_builder.build()
predictor = model.deploy(initial_instance_count=1, instance_type="ml.t2.medium")
predictor.predict(X_test) # works as expected

sagemaker_runtime_client = boto3.client('sagemaker-runtime')

# works as expected:
buffer = io.BytesIO()
np.save(buffer, X_test)
invoke_response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    Body=buffer.getvalue(),
    ContentType='application/x-npy'
)
predictions = np.load(io.BytesIO(invoke_response['Body'].read()))

# does not work as expected:
json_body = json.dumps(X_test.tolist()).encode('utf-8')
invoke_response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    Body=json_body,
    ContentType='application/json'
)

ModelError: An error occurred (ModelError) when calling the InvokeEndpoint operation: Received server error (500) from primary with message "<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">
<title>500 Internal Server Error</title>
<h1>Internal Server Error</h1>
<p>The server encountered an internal error and was unable to complete your request. Either the server is overloaded or there is an error in the application.</p>
". See https://us-west-2.console.aws.amazon.com/cloudwatch/home?region=us-west-2#logEventViewer:group=###REDACTED### in account ###REDACTED### for more information.

Along similar lines (I can open a separate issue if applicable) - requesting the response as JSON via the Accept header does not appear to work either. This is perhaps expected, but it only becomes evident when attempting to deserialize the returned stream:

import codecs

invoke_response_json_resp = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    Body=buffer.getvalue(),
    ContentType='application/x-npy',
    Accept='application/json'
)

reader = codecs.getreader('utf-8')
json_response = reader(invoke_response_json_resp['Body'])
json.load(json_response)

UnicodeDecodeError: 'utf-8' codec can't decode byte 0x93 in position 0: invalid start byte
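
The 0x93 is telling: it is the first byte of the NPY magic prefix (\x93NUMPY), which is not valid UTF-8 - i.e. the response body is NPY despite the Accept header. A quick sanity check (sketch):

import io

import numpy as np

# Every NPY payload starts with the magic prefix b'\x93NUMPY'; its first
# byte (0x93) is exactly where the UTF-8 decode above fails.
buffer = io.BytesIO()
np.save(buffer, np.array([1, 2, 3]))
print(buffer.getvalue()[:6])  # b'\x93NUMPY'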

Whereas the below works:

invoke_response_json_resp = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=predictor.endpoint_name,
    Body=buffer.getvalue(),
    ContentType='application/x-npy',
    Accept='application/json'
)
np.load(io.BytesIO(invoke_response_json_resp['Body'].read())) # evidently the response stream is not JSON

In general, the error messaging during serialization/deserialization is unhelpful and misleading, as it suggests the (de)serialization failed for pickled data, which is not always the case.
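
For illustration only, here is a sketch of an input_fn that would both tolerate str input and report failures more accurately (the names and structure are my assumptions, not the actual generated script):

import io

from sagemaker.deserializers import NumpyDeserializer

def input_fn(input_data, content_type):
    # Accept both str (already decoded by the Transformer for UTF-8
    # content types) and bytes, normalizing to bytes before wrapping.
    if isinstance(input_data, str):
        input_data = input_data.encode('utf-8')
    try:
        return NumpyDeserializer().deserialize(io.BytesIO(input_data), content_type)
    except Exception as e:
        # Name the content type instead of implying pickled data failed.
        raise ValueError(f'Failed to deserialize request with Content-Type {content_type}') from e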

Expected behavior
I expect to be able to invoke the SageMaker endpoint with a JSON-serialized NumPy array and receive an NPY response.

Screenshots or logs

2024-09-11T23:52:57.493Z    IP - - [11/Sep/2024:23:52:55 +0000] "POST /invocations HTTP/1.1" 200 368 "-" "AHC/2.0"
2024-09-11T23:55:30.402Z    2024-09-11 23:55:30,305 ERROR - inference - Exception on /invocations [POST]
2024-09-11T23:55:30.402Z    Traceback (most recent call last): 
    File "/opt/ml/code/inference.py", line 74, in input_fn 
        io.BytesIO(input_data), content_type[0]
2024-09-11T23:55:30.402Z    TypeError: a bytes-like object is required, not 'str'
2024-09-11T23:55:30.402Z    The above exception was the direct cause of the following exception:
2024-09-11T23:55:30.402Z    Traceback (most recent call last): 
    File "/miniconda3/lib/python3.8/site-packages/sagemaker_containers/_functions.py", line 93, in wrapper 
        return fn(*args, **kwargs) 
    File "/opt/ml/code/inference.py", line 77, in input_fn 
        raise Exception("Encountered error in deserialize_request.") from e
2024-09-11T23:55:30.402Z    Exception: Encountered error in deserialize_request.
2024-09-11T23:55:30.402Z    IP - - [11/Sep/2024:23:55:30 +0000] "POST /invocations HTTP/1.1" 500 290 "-" "AHC/2.0"
2024-09-11T23:55:30.402Z    During handling of the above exception, another exception occurred:
2024-09-11T23:55:30.402Z    Traceback (most recent call last): 
    File "/miniconda3/lib/python3.8/site-packages/flask/app.py", line 2446, in wsgi_app 
        response = self.full_dispatch_request() 
    File "/miniconda3/lib/python3.8/site-packages/flask/app.py", line 1951, in full_dispatch_request 
        rv = self.handle_user_exception(e) 
    File "/miniconda3/lib/python3.8/site-packages/flask/app.py", line 1820, in handle_user_exception 
        reraise(exc_type, exc_value, tb) 
    File "/miniconda3/lib/python3.8/site-packages/flask/_compat.py", line 39, in reraise 
        raise value 
    File "/miniconda3/lib/python3.8/site-packages/flask/app.py", line 1949, in full_dispatch_request 
        rv = self.dispatch_request() 
    File "/miniconda3/lib/python3.8/site-packages/flask/app.py", line 1935, in dispatch_request 
        return self.view_functions[rule.endpoint](**req.view_args) 
    File "/miniconda3/lib/python3.8/site-packages/sagemaker_containers/_transformer.py", line 199, in transform 
        result = self._transform_fn( 
    File "/miniconda3/lib/python3.8/site-packages/sagemaker_containers/_transformer.py", line 227, in _default_transform_fn 
        data = self._input_fn(content, content_type) 
    File "/miniconda3/lib/python3.8/site-packages/sagemaker_containers/_functions.py", line 95, in wrapper 
        six.reraise(error_class, error_class(e), sys.exc_info()[2]) 
    File "/miniconda3/lib/python3.8/site-packages/six.py", line 702, in reraise 
        raise value.with_traceback(tb) 
    File "/miniconda3/lib/python3.8/site-packages/sagemaker_containers/_functions.py", line 93, in wrapper 
        return fn(*args, **kwargs) 
    File "/opt/ml/code/inference.py", line 77, in input_fn 
        raise Exception("Encountered error in deserialize_request.") from e
2024-09-11T23:55:32.657Z    sagemaker_containers._errors.ClientError: Encountered error in deserialize_request.

System information
A description of your system. Please provide:

  • SageMaker Python SDK version: 2.231.0
  • Framework name (eg. PyTorch) or algorithm (eg. KMeans): PyTorch / SKLearn LogisticRegression
  • Framework version: 1.2.1
  • Python version: 3.8.17
  • CPU or GPU: CPU
  • Custom Docker image (Y/N): N

Additional context
Relevant guides / documentation used to generate example:
