Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for Exception raised while parsing Chat Completions streaming response, in some rare cases #39741

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion sdk/ai/azure-ai-inference/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@

* Added support for chat completion messages with `developer` role.
* Updated package document with an example of how to set custom HTTP request headers,
and an example of providing chat completion "messages" as an array of Python `dicts`.
and an example of providing chat completion "messages" as an array of Python `dict` objects.
* Add a descriptive exception error message when `load_client` function or
`get_model_info` method fails to run on an endpoint that does not support the `/info` route.

### Bugs Fixed

* Fix for Exception raised while parsing Chat Completions streaming response, in some rare cases, for
multibyte UTF-8 languages like Chinese ([GitHub Issue 39565](https://github.com/Azure/azure-sdk-for-python/issues/39565)).

## 1.0.0b8 (2025-01-29)

### Features Added
Expand Down
44 changes: 24 additions & 20 deletions sdk/ai/azure-ai-inference/azure/ai/inference/models/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,14 +360,14 @@ class BaseStreamingChatCompletions:

# The prefix of each line in the SSE stream that contains a JSON string
# to deserialize into a StreamingChatCompletionsUpdate object
_SSE_DATA_EVENT_PREFIX = "data: "
_SSE_DATA_EVENT_PREFIX = b"data: "

# The line indicating the end of the SSE stream
_SSE_DATA_EVENT_DONE = "data: [DONE]"
_SSE_DATA_EVENT_DONE = b"data: [DONE]"

def __init__(self):
self._queue: "queue.Queue[_models.StreamingChatCompletionsUpdate]" = queue.Queue()
self._incomplete_json = ""
self._incomplete_line = b""
self._done = False # Will be set to True when reading 'data: [DONE]' line

# See https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream
Expand All @@ -379,26 +379,34 @@ def _deserialize_and_add_to_queue(self, element: bytes) -> bool:
# Clear the queue of StreamingChatCompletionsUpdate before processing the next block
self._queue.queue.clear()

# Convert `bytes` to string and split the string by newline, while keeping the new line char.
# the last may be a partial "line" that does not contain a newline char at the end.
line_list: List[str] = re.split(r"(?<=\n)", element.decode("utf-8"))
# Split the single input bytes object at new line characters, and get a list of bytes objects, each
# representing a single "line". The bytes object at the end of the list may be a partial "line" that
# does not contain a new line character at the end.
# Note 1: DO NOT try to use something like this here:
# line_list: List[str] = re.split(r"(?<=\n)", element.decode("utf-8"))
# to do full UTF8 decoding of the whole input bytes object, as the last line in the list may be partial, and
# as such may contain a partial UTF8 Chinese character (for example). `decode("utf-8")` will raise an
# exception for such a case. See GitHub issue https://github.com/Azure/azure-sdk-for-python/issues/39565
# Note 2: Consider future re-write and simplifications of this code by using:
# `codecs.getincrementaldecoder("utf-8")`
line_list: List[bytes] = re.split(re.compile(b"(?<=\n)"), element)
dargilco marked this conversation as resolved.
Show resolved Hide resolved
for index, line in enumerate(line_list):

if self._ENABLE_CLASS_LOGS:
logger.debug("[Original line] %s", repr(line))

if index == 0:
line = self._incomplete_json + line
self._incomplete_json = ""
line = self._incomplete_line + line
self._incomplete_line = b""

if index == len(line_list) - 1 and not line.endswith("\n"):
self._incomplete_json = line
if index == len(line_list) - 1 and not line.endswith(b"\n"):
self._incomplete_line = line
return False

if self._ENABLE_CLASS_LOGS:
logger.debug("[Modified line] %s", repr(line))

if line == "\n": # Empty line, indicating flush output to client
if line == b"\n": # Empty line, indicating flush output to client
continue

if not line.startswith(self._SSE_DATA_EVENT_PREFIX):
Expand All @@ -411,23 +419,19 @@ def _deserialize_and_add_to_queue(self, element: bytes) -> bool:

# If you reached here, the line should contain `data: {...}\n`
# where the curly braces contain a valid JSON object.
# It is now safe to do UTF8 decoding of the line.
line_str = line.decode("utf-8")

# Deserialize it into a StreamingChatCompletionsUpdate object
# and add it to the queue.
# pylint: disable=W0212 # Access to a protected member _deserialize of a client class
update = _models.StreamingChatCompletionsUpdate._deserialize(
json.loads(line[len(self._SSE_DATA_EVENT_PREFIX) : -1]), []
json.loads(line_str[len(self._SSE_DATA_EVENT_PREFIX) : -1]), []
)

# We skip any update that has a None or empty choices list
# We skip any update that has a None or empty choices list, and does not have token usage info.
# (this is what OpenAI Python SDK does)
if update.choices or update.usage:

# We update all empty content strings to None
# (this is what OpenAI Python SDK does)
# for choice in update.choices:
# if not choice.delta.content:
# choice.delta.content = None

self._queue.put(update)

if self._ENABLE_CLASS_LOGS:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ async def sample_chat_completions_streaming_async():

# Iterate on the response to get chat completion updates, as they arrive from the service
async for update in response:
print(update.choices[0].delta.content or "", end="", flush=True)
if update.choices and update.choices[0].delta:
print(update.choices[0].delta.content or "", end="", flush=True)
if update.usage:
print(f"\n\nUsage: {update.usage}")


async def main():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ async def sample_chat_completions_streaming_azure_openai_async():

# Iterate on the response to get chat completion updates, as they arrive from the service
async for update in response:
if len(update.choices) > 0:
if update.choices and update.choices[0].delta:
print(update.choices[0].delta.content or "", end="", flush=True)
if update.usage:
print(f"\n\nUsage: {update.usage}")

await client.close()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ def sample_chat_completions_streaming():
)

for update in response:
print(update.choices[0].delta.content or "", end="", flush=True)
if update.choices and update.choices[0].delta:
print(update.choices[0].delta.content or "", end="", flush=True)
if update.usage:
print(f"\n\nUsage: {update.usage}")

client.close()
# [END chat_completions_streaming]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def sample_chat_completions_streaming_with_entra_id_auth():
)

for update in response:
print(update.choices[0].delta.content or "", end="", flush=True)
if update.choices and update.choices[0].delta:
print(update.choices[0].delta.content or "", end="", flush=True)
if update.usage:
print(f"\n\nUsage: {update.usage}")

client.close()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ def get_flight_info(origin_city: str, destination_city: str):

print("Model response = ", end="")
for update in response:
print(update.choices[0].delta.content or "", end="", flush=True)
if update.choices and update.choices[0].delta:
print(update.choices[0].delta.content or "", end="", flush=True)
if update.usage:
print(f"\n\nUsage: {update.usage}")

client.close()

Expand Down
14 changes: 6 additions & 8 deletions sdk/ai/azure-ai-inference/tests/test_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def test_image_embedding_input_load(self, **kwargs):
]
EXPECTED_STREAMING_RESPONSE = "There are 5,280 feet in a mile. This is a standard measurement used in the United States and a few other countries. If you need help converting other units of measurement, feel free to ask!"

# To test the case where a chinese chars are broken between two lines.
# To test the case where a Chinese character (3 bytes in UTF8 encoding) is broken between two SSE "lines".
# See GitHub issue https://github.com/Azure/azure-sdk-for-python/issues/39565
# Recorded from real chat completions streaming response
# - Using sample code `samples\sample_chat_completions_streaming.py`,
Expand All @@ -325,14 +325,11 @@ def test_image_embedding_input_load(self, **kwargs):
# b'data: {"choices":[{"delta":{"content":"\xe5\xa4\xa9"},"finish_reason":null,"index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk"}\n\n',
# b'data: {"choices":[{"delta":{"content":"\xe7\xa9\xba"},"finish_reason":null,"index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk"}\n\ndata: {"choices":[{"delta":{"content":"\xe6\x98\xaf"},"finish_reason":null,"index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk"}\n\ndata: {"choices":[{"delta":{"content":"\xe8\x93\x9d"},"finish_reason":null,"index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk"}\n\ndata: {"choices":[{"delta":{"content":"\xe8\x89\xb2"},"finish_reason":null,"index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk"}\n\ndata: {"choices":[{"delta":{"content":"\xe7\x9a\x84"},"finish_reason":null,"index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk"}\n\ndata: {"choices":[{"delta":{"content":""},"finish_reason":"stop","index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk","usage":{"completion_tokens":7,"prompt_tokens":41,"total_tokens":48}}\n\ndata: [DONE]\n\n'
# ]
# This is the manually modifed response, after splitting the 2nd line into two lines, in the middle of the 3-byte Chinese charcater. This is an
# attempt to repro what was reported in the GitHub issue, since I don't have a model and that actually generated such a response. The existing code
# works fine in this case, and I don't get the `decode("utf-8")` exception.
# If I add `\n` at the end of the 2nd line below (per inline comment), I do get the exception, but I don't think that's a valid SSE syntex. If this is the behaviour seen by the customer,
# I belive the service code for streaming needs to be fixed.
# Below is a manually modified response of the above, after splitting the 2nd line into two lines, in the middle of the 3-byte Chinese character.
# This represents the case presented in the above GitHub issue:
STREAMING_RESPONSE_BYTES_SPLIT_CHINISE_CHAR: list[bytes] = [
b'data: {"choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk"}\n\n',
b'data: {"choices":[{"delta":{"content":"\xe5\xa4' # If you add \n at the end, parsing will fail, but I don't think that's a valid SSE syntax.
b'data: {"choices":[{"delta":{"content":"\xe5\xa4',
b'\xa9"},"finish_reason":null,"index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk"}\n\n',
b'data: {"choices":[{"delta":{"content":"\xe7\xa9\xba"},"finish_reason":null,"index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk"}\n\ndata: {"choices":[{"delta":{"content":"\xe6\x98\xaf"},"finish_reason":null,"index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk"}\n\ndata: {"choices":[{"delta":{"content":"\xe8\x93\x9d"},"finish_reason":null,"index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk"}\n\ndata: {"choices":[{"delta":{"content":"\xe8\x89\xb2"},"finish_reason":null,"index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk"}\n\ndata: {"choices":[{"delta":{"content":"\xe7\x9a\x84"},"finish_reason":null,"index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk"}\n\ndata: {"choices":[{"delta":{"content":""},"finish_reason":"stop","index":0,"logprobs":null}],"created":1739388793,"id":"97f3900d75c94c028984e58bef2a2584","model":"mistral-large","object":"chat.completion.chunk","usage":{"completion_tokens":7,"prompt_tokens":41,"total_tokens":48}}\n\ndata: [DONE]\n\n',
]
Expand Down Expand Up @@ -367,7 +364,8 @@ async def test_streaming_response_parsing_async(self, **kwargs):
assert actual_response == TestUnitTests.EXPECTED_STREAMING_RESPONSE

# Regression test for the implementation of StreamingChatCompletions class,
# which does the SSE parsing for streaming response
# which does the SSE parsing for streaming response, with input SSE "lines"
# that have a Chinese character (3 bytes in UTF8 encoding) split between two "lines".
def test_streaming_response_parsing_split_chinese_char(self, **kwargs):

http_response = HttpResponseForUnitTests(TestUnitTests.STREAMING_RESPONSE_BYTES_SPLIT_CHINISE_CHAR)
Expand Down