draft implementation of streaming
lilyydu committed Oct 23, 2024
1 parent e5ece1e commit 8ac9bc2
Showing 14 changed files with 878 additions and 4 deletions.
81 changes: 80 additions & 1 deletion python/packages/ai/teams/ai/clients/llm_client.py
@@ -12,7 +12,9 @@
from botbuilder.core import TurnContext

from ...state import Memory, MemoryBase
from ..models import PromptCompletionModel, PromptResponse
from ...streaming.prompt_chunk import PromptChunk
from ...streaming.streaming_response import StreamingResponse
from ..models import PromptCompletionModel, PromptResponse, ResponseReceivedHandler, StreamingEventTypes
from ..prompts import (
ConversationHistorySection,
Message,
@@ -68,13 +70,25 @@ class LLMClientOptions:
Optional. When set the model will log requests
"""

start_streaming_message: Optional[str] = ""
"""
Optional message to send to the client at the start of a streaming response.
"""

end_stream_handler: Optional[ResponseReceivedHandler] = None
"""
Optional handler to run when a stream is about to conclude.
"""


class LLMClient:
"""
LLMClient class that's used to complete prompts.
"""

_options: LLMClientOptions
_start_streaming_message: Optional[str] = ""
_end_stream_handler: Optional[ResponseReceivedHandler] = None

@property
def options(self) -> LLMClientOptions:
Expand Down Expand Up @@ -112,6 +126,54 @@ async def complete_prompt(

remaining_attempts = remaining_attempts or self._options.max_repair_attempts

# Define event handlers
is_streaming = False
streamer: Optional[StreamingResponse] = None

def before_completion(
ctx: TurnContext,
memory: MemoryBase,
functions: PromptFunctions,
tokenizer: Tokenizer,
template: PromptTemplate,
streaming: bool,
) -> None:
# Bind to the enclosing complete_prompt locals; without nonlocal the
# assignments below would only create new locals and streaming would never start.
nonlocal is_streaming, streamer

# Ignore events for other contexts
if context != ctx:
return

# Check for a streaming response
if streaming:
is_streaming = True

# Create streamer and send initial message
streamer = StreamingResponse(context)
memory.set("temp.streamer", streamer)
if self.options.start_streaming_message:
streamer.queue_informative_update(self.options.start_streaming_message)

def chunk_received(
ctx: TurnContext,
memory: MemoryBase,
chunk: PromptChunk,
) -> None:
if (context != ctx) or streamer is None:
return

# Send chunk to client
text = chunk.delta.content if (chunk.delta and chunk.delta.content) else ""
if len(text) > 0:
streamer.queue_text_chunk(text)

# Subscribe to model events
if self._options.model.events is not None:
self._options.model.events.subscribe(StreamingEventTypes.BEFORE_COMPLETION, before_completion)
self._options.model.events.subscribe(StreamingEventTypes.CHUNK_RECEIVED, chunk_received)

if self._options.end_stream_handler is not None:
handler: ResponseReceivedHandler = self._options.end_stream_handler
self._options.model.events.subscribe(StreamingEventTypes.RESPONSE_RECEIVED, handler)

try:
if remaining_attempts <= 0:
return PromptResponse(
@@ -187,9 +249,26 @@ async def complete_prompt(

self._add_message_to_history(memory, self._options.history_variable, res.input)
self._add_message_to_history(memory, self._options.history_variable, res.message)

if is_streaming and res.status == "success":
# Delete message from response to avoid sending it twice
res.message = None

# End the stream if streaming
if streamer is not None:
await streamer.end_stream()
return res
except Exception as err: # pylint: disable=broad-except
return PromptResponse(status="error", error=str(err))
finally:
# Unsubscribe from model events
if self._options.model.events is not None:
self._options.model.events.unsubscribe(StreamingEventTypes.BEFORE_COMPLETION, before_completion)
self._options.model.events.unsubscribe(StreamingEventTypes.CHUNK_RECEIVED, chunk_received)

if self._options.end_stream_handler is not None:
handler: ResponseReceivedHandler = self._options.end_stream_handler
self._options.model.events.unsubscribe(StreamingEventTypes.RESPONSE_RECEIVED, handler)

def _add_message_to_history(
self, memory: MemoryBase, variable: str, messages: Union[Message[Any], List[Message[Any]]]
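For context, a minimal sketch of how a caller might wire the two new options into an LLMClient. The end-stream handler signature is inferred from the emit_response_received call in openai_model.py below; the teams.ai.clients import path, the LLMClientOptions model field, and the OpenAIModelOptions api_key field are assumptions based on the surrounding code, not confirmed by this diff.

from botbuilder.core import TurnContext

from teams.ai.clients import LLMClient, LLMClientOptions  # assumed import path
from teams.ai.models import OpenAIModel, OpenAIModelOptions, PromptResponse
from teams.state import MemoryBase
from teams.streaming.streaming_response import StreamingResponse


def on_stream_end(
    context: TurnContext,
    memory: MemoryBase,
    response: PromptResponse[str],
    streamer: StreamingResponse,
) -> None:
    # Called just before the stream concludes; all text chunks have already
    # been queued on the streamer by this point.
    pass


model = OpenAIModel(
    OpenAIModelOptions(api_key="<key>", default_model="gpt-4o", stream=True)  # api_key field assumed
)
client = LLMClient(
    LLMClientOptions(
        model=model,
        start_streaming_message="Working on it...",
        end_stream_handler=on_stream_end,
    )
)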
12 changes: 12 additions & 0 deletions python/packages/ai/teams/ai/models/__init__.py
@@ -6,6 +6,13 @@
from .chat_completion_action import ChatCompletionAction
from .openai_model import AzureOpenAIModelOptions, OpenAIModel, OpenAIModelOptions
from .prompt_completion_model import PromptCompletionModel
from .prompt_completion_model_emitter import PromptCompletionModelEmitter
from ...streaming.streaming_events import (
BeforeCompletionHandler,
ChunkReceivedHandler,
ResponseReceivedHandler,
StreamingEventTypes,
)
from .prompt_response import PromptResponse, PromptResponseStatus

__all__ = [
@@ -16,4 +23,9 @@
"PromptCompletionModel",
"PromptResponse",
"PromptResponseStatus",
"PromptCompletionModelEmitter",
"BeforeCompletionHandler",
"ChunkReceivedHandler",
"ResponseReceivedHandler",
"StreamingEventTypes",
]
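With these re-exports in place, downstream code can import the streaming types directly from teams.ai.models rather than reaching into teams.streaming; a small sketch follows (the handler parameters are inferred from emit_response_received in openai_model.py, and the subscribe call mirrors the usage in llm_client.py):

from teams.ai.models import (
    PromptCompletionModelEmitter,
    ResponseReceivedHandler,
    StreamingEventTypes,
)


def on_response_received(context, memory, response, streamer) -> None:
    # Parameters inferred from emit_response_received; not confirmed by this diff.
    pass


emitter = PromptCompletionModelEmitter()
handler: ResponseReceivedHandler = on_response_received
emitter.subscribe(StreamingEventTypes.RESPONSE_RECEIVED, handler)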
85 changes: 83 additions & 2 deletions python/packages/ai/teams/ai/models/openai_model.py
@@ -5,23 +5,27 @@

from __future__ import annotations

import asyncio
import json
from dataclasses import dataclass
from logging import Logger
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional, Union, cast

import openai
from botbuilder.core import TurnContext
from openai import NOT_GIVEN
from openai import NOT_GIVEN, AsyncStream
from openai.types import chat, shared_params
from openai.types.chat.chat_completion_message_tool_call_param import Function

from teams.streaming.prompt_chunk import PromptChunk

from ...state import MemoryBase
from ..prompts.message import ActionCall, ActionFunction, Message, MessageContext
from ..prompts.prompt_functions import PromptFunctions
from ..prompts.prompt_template import PromptTemplate
from ..tokenizers import Tokenizer
from .prompt_completion_model import PromptCompletionModel
from .prompt_completion_model_emitter import PromptCompletionModelEmitter
from .prompt_response import PromptResponse


@@ -46,6 +50,9 @@ class OpenAIModelOptions:
logger: Optional[Logger] = None
"Optional. When set the model will log requests"

stream: bool = False
"Optional. Whether the model's responses should be streamed back."


@dataclass
class AzureOpenAIModelOptions:
@@ -76,6 +83,9 @@ class AzureOpenAIModelOptions:
logger: Optional[Logger] = None
"Optional. When set the model will log requests"

stream: bool = False
"Optional. Whether the model's responses should be streamed back."


class OpenAIModel(PromptCompletionModel):
"""
@@ -116,6 +126,7 @@ def __init__(self, options: Union[OpenAIModelOptions, AzureOpenAIModelOptions])
organization=options.organization,
default_headers={"User-Agent": self.user_agent},
)
self.events = PromptCompletionModelEmitter()

async def complete_prompt(
self,
@@ -163,6 +174,17 @@ async def complete_prompt(
else self._options.default_model
)

if self._options.stream and self.events is not None:
# Signal start of completion
self.events.emit_before_completion(
context=context,
memory=memory,
functions=functions,
tokenizer=tokenizer,
template=template,
streaming=True,
)

res = await template.prompt.render_as_messages(
context=context,
memory=memory,
@@ -259,8 +281,67 @@ async def complete_prompt(
tool_choice=tool_choice if len(tools) > 0 else NOT_GIVEN,
parallel_tool_calls=parallel_tool_calls if len(tools) > 0 else NOT_GIVEN,
extra_body=extra_body,
stream=self._options.stream,
)

if self._options.stream:
# Log start of streaming
if self._options.logger is not None:
self._options.logger.debug("STREAM STARTED:")

# Enumerate the stream chunks
message_content = ""
message: Message[str] = Message(role="assistant", content="")
completion = cast(AsyncStream[chat.ChatCompletionChunk], completion)

async for chunk in completion:
delta = chunk.choices[0].delta

if delta.role:
message.role = delta.role

if delta.content:
message_content += delta.content

# TODO: Handle tool calls

if self._options.logger is not None:
self._options.logger.debug("CHUNK", delta)

curr_delta_message = PromptChunk(
delta=Message[str](role=delta.role or message.role, content=delta.content)
)

if self.events is not None:
self.events.emit_chunk_received(context, memory, curr_delta_message)

message.content = message_content

# Log stream completion
if self._options.logger is not None:
self._options.logger.debug("STREAM COMPLETED:")

res_input: Optional[Union[Message, List[Message]]] = None
last_message = len(res.output) - 1

# Skip the first message, which is the prompt
if last_message > 0 and res.output[last_message].role != "assistant":
res_input = res.output[last_message]

response = PromptResponse[str](input=res_input, message=message)

streamer = memory.get("temp.streamer")
if self.events is not None and streamer is not None:
self.events.emit_response_received(context, memory, response, streamer)

# Let any pending events flush before returning
await asyncio.sleep(0)
return response

completion = cast(chat.ChatCompletion, completion)
if self._options.logger is not None:
self._options.logger.debug("COMPLETION:\n%s", completion.model_dump_json())

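Taken together, the model now emits BEFORE_COMPLETION, CHUNK_RECEIVED, and RESPONSE_RECEIVED events whenever stream=True. A hedged sketch of observing chunks directly on the model's emitter, outside of LLMClient — the handler parameters mirror chunk_received in llm_client.py above, and the api_key field name is an assumption:

from botbuilder.core import TurnContext

from teams.ai.models import OpenAIModel, OpenAIModelOptions, StreamingEventTypes
from teams.state import MemoryBase
from teams.streaming.prompt_chunk import PromptChunk


def log_chunk(context: TurnContext, memory: MemoryBase, chunk: PromptChunk) -> None:
    # Each chunk wraps an incremental Message in chunk.delta; content may be None.
    if chunk.delta and chunk.delta.content:
        print(chunk.delta.content, end="", flush=True)


model = OpenAIModel(
    OpenAIModelOptions(api_key="<key>", default_model="gpt-4o", stream=True)  # api_key field assumed
)
if model.events is not None:
    model.events.subscribe(StreamingEventTypes.CHUNK_RECEIVED, log_chunk)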
4 changes: 4 additions & 0 deletions python/packages/ai/teams/ai/models/prompt_completion_model.py
@@ -6,11 +6,13 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional

from botbuilder.core import TurnContext

from ...state import MemoryBase
from ...user_agent import _UserAgent
from ..models.prompt_completion_model_emitter import PromptCompletionModelEmitter
from ..prompts.prompt_functions import PromptFunctions
from ..prompts.prompt_template import PromptTemplate
from ..tokenizers import Tokenizer
@@ -22,6 +24,8 @@ class PromptCompletionModel(ABC, _UserAgent):
An AI model that can be used to complete prompts.
"""

events: Optional[PromptCompletionModelEmitter] = None

@abstractmethod
async def complete_prompt(
self,
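Since events defaults to None, custom models opt in to streaming by creating their own emitter, which is exactly what LLMClient checks before subscribing. A sketch of a minimal custom model under that assumption — the complete_prompt parameter list, the prompts-package re-exports, and the PromptResponse defaults are inferred from the elided base-class code, not confirmed by this diff:

from botbuilder.core import TurnContext

from teams.ai.models import (
    PromptCompletionModel,
    PromptCompletionModelEmitter,
    PromptResponse,
)
from teams.ai.prompts import Message, PromptFunctions, PromptTemplate  # assumed re-exports
from teams.ai.tokenizers import Tokenizer
from teams.state import MemoryBase


class EchoModel(PromptCompletionModel):
    """Toy model that advertises streaming support by assigning an emitter."""

    def __init__(self) -> None:
        self.events = PromptCompletionModelEmitter()

    async def complete_prompt(
        self,
        context: TurnContext,
        memory: MemoryBase,
        functions: PromptFunctions,
        tokenizer: Tokenizer,
        template: PromptTemplate,
    ) -> PromptResponse[str]:
        # Parameters assumed to match the abstract signature elided above.
        return PromptResponse(message=Message(role="assistant", content="echo"))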

