From fef1ca495df052b89f329f139c6a6cfd164026b7 Mon Sep 17 00:00:00 2001 From: Suzin You <7042047+suzinyou@users.noreply.github.com> Date: Mon, 26 Aug 2024 23:04:52 +0900 Subject: [PATCH 1/8] implement multiturn --- .secrets.baseline | 13 +- core_backend/add_dummy_data_to_db.py | 10 +- core_backend/app/contents/schemas.py | 14 +- core_backend/app/llm_call/llm_rag.py | 13 +- core_backend/app/llm_call/process_output.py | 37 +---- core_backend/app/llm_call/utils.py | 11 +- core_backend/app/question_answer/models.py | 65 +++++++- core_backend/app/question_answer/routers.py | 148 ++++++++++++++++-- core_backend/app/question_answer/schemas.py | 45 +++--- core_backend/app/question_answer/utils.py | 6 + ...bbcc2584fde_change_session_id_to_string.py | 41 +++++ .../tests/api/test_dashboard_overview.py | 3 - 12 files changed, 297 insertions(+), 109 deletions(-) create mode 100644 core_backend/migrations/versions/2024_08_26_1bbcc2584fde_change_session_id_to_string.py diff --git a/.secrets.baseline b/.secrets.baseline index 0d8db7530..5c702eb6f 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -231,15 +231,6 @@ "line_number": 63 } ], - "core_backend/app/question_answer/schemas.py": [ - { - "type": "Secret Keyword", - "filename": "core_backend/app/question_answer/schemas.py", - "hashed_secret": "5b8b7a620e54e681c584f5b5c89152773c10c253", - "is_verified": false, - "line_number": 67 - } - ], "core_backend/migrations/versions/2023_09_16_c5a948963236_create_query_table.py": [ { "type": "Hex High Entropy String", @@ -430,7 +421,7 @@ "filename": "core_backend/tests/api/test_dashboard_overview.py", "hashed_secret": "6367c48dd193d56ea7b0baad25b19455e529f5ee", "is_verified": false, - "line_number": 291 + "line_number": 290 } ], "core_backend/tests/api/test_dashboard_performance.py": [ @@ -590,5 +581,5 @@ } ] }, - "generated_at": "2024-08-23T09:41:17Z" + "generated_at": "2024-08-26T14:03:01Z" } diff --git a/core_backend/add_dummy_data_to_db.py b/core_backend/add_dummy_data_to_db.py index 1116a5810..bc8f4f075 100644 --- a/core_backend/add_dummy_data_to_db.py +++ b/core_backend/add_dummy_data_to_db.py @@ -185,7 +185,7 @@ def create_query_record(dt: datetime, session: Session) -> QueryDB: query_db = QueryDB( user_id=_USER_ID, - session_id=1, + session_id="1", feedback_secret_key="abc123", # pragma: allowlist secret query_text="test query", query_generate_llm_response=False, @@ -198,7 +198,7 @@ def create_query_record(dt: datetime, session: Session) -> QueryDB: def create_response_feedback_record( - dt: datetime, query_id: int, session_id: int, is_negative: bool, session: Session + dt: datetime, query_id: int, session_id: str, is_negative: bool, session: Session ) -> None: """Create a feedback record for a given datetime. @@ -209,7 +209,7 @@ def create_response_feedback_record( query_id The ID of the query record. session_id - The ID of the session record. + The ID of the session record -- uuid is_negative Specifies whether the feedback is negative. session @@ -235,7 +235,7 @@ def create_response_feedback_record( def create_content_feedback_record( dt: datetime, query_id: int, - session_id: int, + session_id: str, is_negative: bool, session: Session, ) -> None: @@ -248,7 +248,7 @@ def create_content_feedback_record( query_id The ID of the query record. session_id - The ID of the session record. + The ID of the session record. (uuid) is_negative Specifies whether the content feedback is negative. 
session diff --git a/core_backend/app/contents/schemas.py b/core_backend/app/contents/schemas.py index 9ea35a0cb..bc888cb53 100644 --- a/core_backend/app/contents/schemas.py +++ b/core_backend/app/contents/schemas.py @@ -1,7 +1,7 @@ from datetime import datetime -from typing import Annotated, List +from typing import List -from pydantic import BaseModel, ConfigDict, Field, StringConstraints +from pydantic import BaseModel, ConfigDict, Field class ContentCreate(BaseModel): @@ -9,13 +9,15 @@ class ContentCreate(BaseModel): Pydantic model for content creation request """ - content_title: Annotated[str, StringConstraints(max_length=150)] = Field( + content_title: str = Field( + max_length=150, examples=["Example Content Title"], ) - content_text: Annotated[str, StringConstraints(max_length=2000)] = Field( - examples=["This is an example content."] + content_text: str = Field( + max_length=2000, + examples=["This is an example content."], ) - content_tags: list = Field(default=[], examples=[[1, 4]]) + content_tags: list = Field(default=[]) content_metadata: dict = Field(default={}, examples=[{"key": "optional_value"}]) is_archived: bool = False diff --git a/core_backend/app/llm_call/llm_rag.py b/core_backend/app/llm_call/llm_rag.py index 349761aad..3f2d839d6 100644 --- a/core_backend/app/llm_call/llm_rag.py +++ b/core_backend/app/llm_call/llm_rag.py @@ -2,8 +2,6 @@ Augmented Generation (RAG). """ -from typing import Optional - from pydantic import ValidationError from ..config import LITELLM_MODEL_GENERATION @@ -15,23 +13,27 @@ async def get_llm_rag_answer( - question: str, + question: str | list[dict[str, str]], context: str, original_language: IdentifiedLanguage, - metadata: Optional[dict] = None, + metadata: dict | None = None, + chat_history: list[dict[str, str]] | None = None, ) -> RAG: """Get an answer from the LLM model using RAG. Parameters ---------- question - The question to ask the LLM model. + The question to ask the LLM model, or list of chat history messages in the form + of {"content": str, "role": str}. context The context to provide to the LLM model. response_language The language of the response. metadata Additional metadata to provide to the LLM model. + chat_history + The previous chat history to provide to the LLM model if it exists. Returns ------- @@ -45,6 +47,7 @@ async def get_llm_rag_answer( result = await _ask_llm_async( user_message=question, system_message=prompt, + chat_history=chat_history, litellm_model=LITELLM_MODEL_GENERATION, metadata=metadata, json=True, diff --git a/core_backend/app/llm_call/process_output.py b/core_backend/app/llm_call/process_output.py index 75cc812bb..95b316a5f 100644 --- a/core_backend/app/llm_call/process_output.py +++ b/core_backend/app/llm_call/process_output.py @@ -45,39 +45,7 @@ class AlignScoreData(TypedDict): claim: str -def generate_llm_response__after(func: Callable) -> Callable: - """ - Decorator to generate the LLM response. - - Only runs if the generate_llm_response flag is set to True. - Requires "search_results" and "original_language" in the response. 
- """ - - @wraps(func) - async def wrapper( - query_refined: QueryRefined, - response: QueryResponse | QueryResponseError, - *args: Any, - **kwargs: Any, - ) -> QueryResponse | QueryResponseError: - """ - Generate the LLM response - """ - response = await func(query_refined, response, *args, **kwargs) - - if not query_refined.generate_llm_response: - return response - - metadata = create_langfuse_metadata( - query_id=response.query_id, user_id=query_refined.user_id - ) - response = await _generate_llm_response(query_refined, response, metadata) - return response - - return wrapper - - -async def _generate_llm_response( +async def generate_llm_query_response( query_refined: QueryRefined, response: QueryResponse, metadata: Optional[dict] = None, @@ -99,12 +67,13 @@ async def _generate_llm_response( return response context = get_context_string_from_search_results(response.search_results) + rag_response = await get_llm_rag_answer( - # use the original query text question=query_refined.query_text_original, context=context, original_language=query_refined.original_language, metadata=metadata, + chat_history=response.chat_history, ) if rag_response.answer != RAG_FAILURE_MESSAGE: diff --git a/core_backend/app/llm_call/utils.py b/core_backend/app/llm_call/utils.py index 2f8a360c9..b423cc2ca 100644 --- a/core_backend/app/llm_call/utils.py +++ b/core_backend/app/llm_call/utils.py @@ -9,10 +9,11 @@ async def _ask_llm_async( - user_message: str, + user_message: str | list[dict[str, str]], system_message: str, - litellm_model: Optional[str] = LITELLM_MODEL_DEFAULT, - litellm_endpoint: Optional[str] = LITELLM_ENDPOINT, + chat_history: list[dict[str, str]] | None = None, + litellm_model: str | None = LITELLM_MODEL_DEFAULT, + litellm_endpoint: str | None = LITELLM_ENDPOINT, metadata: Optional[dict] = None, json: bool = False, ) -> str: @@ -36,6 +37,10 @@ async def _ask_llm_async( "role": "user", }, ] + + if chat_history is not None: + messages = messages[:1] + chat_history + messages[1:] + logger.info(f"LLM input: 'model': {litellm_model}, 'endpoint': {litellm_endpoint}") llm_response_raw = await acompletion( diff --git a/core_backend/app/question_answer/models.py b/core_backend/app/question_answer/models.py index 1c62fc1ab..2320c2e86 100644 --- a/core_backend/app/question_answer/models.py +++ b/core_backend/app/question_answer/models.py @@ -52,7 +52,7 @@ class QueryDB(Base): user_id: Mapped[int] = mapped_column( Integer, ForeignKey("user.user_id"), nullable=False ) - session_id: Mapped[int] = mapped_column(Integer, nullable=True) + session_id: Mapped[str] = mapped_column(String, nullable=True) feedback_secret_key: Mapped[str] = mapped_column(String, nullable=False) query_text: Mapped[str] = mapped_column(String, nullable=False) query_generate_llm_response: Mapped[bool] = mapped_column(Boolean, nullable=False) @@ -166,7 +166,7 @@ class QueryResponseDB(Base): user_id: Mapped[int] = mapped_column( Integer, ForeignKey("user.user_id"), nullable=False ) - session_id: Mapped[int] = mapped_column(Integer, nullable=True) + session_id: Mapped[str] = mapped_column(String, nullable=True) search_results: Mapped[JSONDict] = mapped_column(JSON, nullable=False) tts_filepath: Mapped[str] = mapped_column(String, nullable=True) llm_response: Mapped[str] = mapped_column(String, nullable=True) @@ -261,6 +261,59 @@ async def save_query_response_to_db( return user_query_responses_db +async def check_user_and_session_id( + user_id: int, session_id: str, asession: AsyncSession +) -> bool: + """Check if the user_id and session_id 
are valid.""" + stmt = select(QueryDB).where( + QueryDB.user_id == user_id, QueryDB.session_id == session_id + ) + query_result = await asession.execute(stmt) + return query_result.scalar() is not None + + +async def get_session_history( + user_id: int, session_id: str | None, asession: AsyncSession +) -> list[dict[str, str]]: + """Get the session history for a given session ID.""" + if session_id is None: + raise ValueError("Session ID cannot be None.") + # Query QueryDB for the given user_id and session_id, ordered by query_datetime_utc + query_stmt = ( + select(QueryDB) + .where(QueryDB.user_id == user_id, QueryDB.session_id == session_id) + .order_by(QueryDB.query_datetime_utc) + ) + + query_results = (await asession.execute(query_stmt)).scalars().all() + + # Query QueryResponseDB for the given user_id and session_id + response_stmt = select(QueryResponseDB).where( + QueryResponseDB.user_id == user_id, QueryResponseDB.session_id == session_id + ) + + response_results = (await asession.execute(response_stmt)).scalars().all() + + # Create a dictionary to map query_id to llm_response + response_dict = {response.query_id: response for response in response_results} + + # Construct the list of dictionaries + messages = [] + for query in query_results: + if query.query_id in response_dict: + response = response_dict[query.query_id] + messages.append( + {"content": response.debug_info["original_query"], "role": "user"} + ) + if response.llm_response is not None: + messages.append({"content": response.llm_response, "role": "assistant"}) + + else: # this shouldn't happen, but for completeness + messages.append({"content": query.query_text, "role": "user"}) + + return messages + + class QueryResponseContentDB(Base): """ ORM for storing what content was returned for a given query. 
@@ -275,7 +328,7 @@ class QueryResponseContentDB(Base):
     user_id: Mapped[int] = mapped_column(
         Integer, ForeignKey("user.user_id"), nullable=False
     )
-    session_id: Mapped[int] = mapped_column(Integer, nullable=True)
+    session_id: Mapped[str] = mapped_column(String, nullable=True)
     query_id: Mapped[int] = mapped_column(
         Integer, ForeignKey("query.query_id"), nullable=False
     )
@@ -312,7 +365,7 @@ def __repr__(self) -> str:
 
 async def save_content_for_query_to_db(
     user_id: int,
-    session_id: int | None,
+    session_id: str | None,
     query_id: int,
     contents: dict[int, QuerySearchResult] | None,
     asession: AsyncSession,
@@ -355,7 +408,7 @@ class ResponseFeedbackDB(Base):
     user_id: Mapped[int] = mapped_column(
         Integer, ForeignKey("user.user_id"), nullable=False
     )
-    session_id: Mapped[int] = mapped_column(Integer, nullable=True)
+    session_id: Mapped[str] = mapped_column(String, nullable=True)
     feedback_text: Mapped[str] = mapped_column(String, nullable=True)
     feedback_datetime_utc: Mapped[datetime] = mapped_column(
         DateTime(timezone=True), nullable=False
@@ -435,7 +488,7 @@ class ContentFeedbackDB(Base):
     user_id: Mapped[int] = mapped_column(
         Integer, ForeignKey("user.user_id"), nullable=False
     )
-    session_id: Mapped[int] = mapped_column(Integer, nullable=True)
+    session_id: Mapped[str] = mapped_column(String, nullable=True)
     feedback_text: Mapped[str] = mapped_column(String, nullable=True)
     feedback_datetime_utc: Mapped[datetime] = mapped_column(
         DateTime(timezone=True), nullable=False
diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py
index 660243ebd..a9037a8dc 100644
--- a/core_backend/app/question_answer/routers.py
+++ b/core_backend/app/question_answer/routers.py
@@ -27,13 +27,14 @@
 )
 from ..llm_call.process_output import (
     check_align_score__after,
-    generate_llm_response__after,
+    generate_llm_query_response,
     generate_tts__after,
 )
 from ..users.models import UserDB
 from ..utils import (
     create_langfuse_metadata,
     generate_random_filename,
+    generate_secret_key,
     get_file_extension_from_mime_type,
     get_http_client,
     setup_logger,
@@ -43,6 +44,7 @@
 from .models import (
     QueryDB,
     check_secret_key_match,
+    get_session_history,
     save_content_feedback_to_db,
     save_content_for_query_to_db,
     save_query_response_to_db,
@@ -59,6 +61,7 @@
     ResponseFeedbackBase,
 )
 from .speech_components.external_voice_components import transcribe_audio
+from .utils import format_session_history_as_query
 
 logger = setup_logger()
 
@@ -145,13 +148,14 @@ async def voice_search(
         generate_tts=True,
     )
 
-    response = await search_base(
+    response = await get_search_response(
         query_refined=user_query_refined_template,
         response=response_template,
         user_id=user_db.user_id,
         n_similar=int(N_TOP_CONTENT),
         asession=asession,
         exclude_archived=True,
+        is_chat=False,
     )
     await save_query_response_to_db(user_query_db, response, asession)
     await increment_query_count(
@@ -230,21 +234,105 @@ async def search(
         asession=asession,
         generate_tts=False,
     )
-    response = await search_base(
+    response = await get_search_response(
         query_refined=user_query_refined_template,
         response=response_template,
         user_id=user_db.user_id,
         n_similar=int(N_TOP_CONTENT),
         asession=asession,
         exclude_archived=True,
+        is_chat=False,
+    )
+
+    if user_query.generate_llm_response:
+        response = await get_generation_response(
+            query_refined=user_query_refined_template,
+            response=response,
+        )
+
+    await save_query_response_to_db(user_query_db, response, asession)
+    await increment_query_count(
+        user_id=user_db.user_id,
+        contents=response.search_results,
+        asession=asession,
+    )
+    await save_content_for_query_to_db(
+        user_id=user_db.user_id,
+        session_id=user_query.session_id,
+        query_id=response.query_id,
+        contents=response.search_results,
+        asession=asession,
+    )
+
+    if type(response) is QueryResponse:
+        return response
+    elif type(response) is QueryResponseError:
+        return JSONResponse(
+            status_code=status.HTTP_400_BAD_REQUEST, content=response.model_dump()
+        )
+    else:
+        return JSONResponse(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            content={"message": "Internal server error"},
+        )
+
+
+@router.post(
+    "/chat",
+    response_model=QueryResponse,
+    responses={
+        status.HTTP_400_BAD_REQUEST: {
+            "model": QueryResponseError,
+            "description": "Guardrail failure",
+        }
+    },
+)
+async def chat(
+    user_query: QueryBase,
+    asession: AsyncSession = Depends(get_async_session),
+    user_db: UserDB = Depends(authenticate_key),
+) -> QueryResponse | JSONResponse:
+    """Chat endpoint: answers the latest message in a multi-turn session.
+
+    Retrieves content most similar to the conversation so far and generates an
+    LLM response conditioned on the session's chat history. If no session ID
+    is provided, a new session is started.
+    """
+    is_new_chat = user_query.session_id is None
+
+    if is_new_chat:
+        session_id = generate_secret_key()
+        user_query.session_id = session_id
+
+    (
+        user_query_db,
+        user_query_refined_template,
+        response_template,
+    ) = await get_user_query_and_response(
+        user_id=user_db.user_id,
+        user_query=user_query,
+        asession=asession,
+        generate_tts=False,
+    )
+
+    response = await get_search_response(
+        query_refined=user_query_refined_template,
+        response=response_template,
+        user_id=user_db.user_id,
+        n_similar=int(N_TOP_CONTENT),
+        asession=asession,
+        exclude_archived=True,
+        is_chat=True,
+    )
+
+    response = await get_generation_response(
+        query_refined=user_query_refined_template,
+        response=response,
+    )
+
+    await save_query_response_to_db(user_query_db, response, asession)
+
+    await increment_query_count(
+        user_id=user_db.user_id,
+        contents=response.search_results,
+        asession=asession,
+    )
+
+    await save_content_for_query_to_db(
+        user_id=user_db.user_id,
+        session_id=user_query.session_id,
+        query_id=response.query_id,
+        contents=response.search_results,
+        asession=asession,
+    )
@@ -270,16 +358,14 @@
 @classify_safety__before
 @translate_question__before
 @paraphrase_question__before
-@generate_tts__after
-@generate_llm_response__after
-@check_align_score__after
-async def search_base(
+async def get_search_response(
     query_refined: QueryRefined,
     response: QueryResponse,
     user_id: int,
     n_similar: int,
     asession: AsyncSession,
     exclude_archived: bool = True,
+    is_chat: bool = False,
 ) -> QueryResponse | QueryResponseError:
     """Get similar content and construct the LLM answer for the user query.
 
@@ -308,13 +394,25 @@
     An appropriate query response object.
 
     """
-
-    # always do the embeddings search even if some guardrails have failed
+    # No checks for errors:
+    # always do the embeddings search even if some guardrails have failed
    metadata = create_langfuse_metadata(query_id=response.query_id, user_id=user_id)
 
+    if is_chat:
+        # May return empty list if no chat history.
+        # Ideally want to skip if it's a new chat.
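+        # The history is flattened into a single retrieval query below and is
+        # also attached to the response, so the generation step can condition
+        # the LLM on prior turns via `chat_history`.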
+ messages = await get_session_history( + user_id=query_refined.user_id, + session_id=query_refined.session_id, + asession=asession, + ) + question = format_session_history_as_query(messages) + response.chat_history = messages + else: + question = query_refined.query_text + search_results = await get_similar_content_async( user_id=user_id, - # use latest version of the text - question=query_refined.query_text, + question=question, n_similar=n_similar, asession=asession, metadata=metadata, @@ -325,6 +423,34 @@ async def search_base( return response +@generate_tts__after +@check_align_score__after +async def get_generation_response( + query_refined: QueryRefined, + response: QueryResponse, +) -> QueryResponse | QueryResponseError: + """ + Generate a response using an LLM given a query with search results. + + Only runs if the generate_llm_response flag is set to True. + Requires "search_results" and "original_language" in the response. + + If it's a chat query, the chat history is passed to the LLM model. + """ + if not query_refined.generate_llm_response: + return response + + metadata = create_langfuse_metadata( + query_id=response.query_id, user_id=query_refined.user_id + ) + + response = await generate_llm_query_response( + query_refined=query_refined, response=response, metadata=metadata + ) + + return response + + async def get_user_query_and_response( user_id: int, user_query: QueryBase, diff --git a/core_backend/app/question_answer/schemas.py b/core_backend/app/question_answer/schemas.py index 128b1792f..97d125bef 100644 --- a/core_backend/app/question_answer/schemas.py +++ b/core_backend/app/question_answer/schemas.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Dict, Optional +from typing import Dict from pydantic import BaseModel, ConfigDict, Field @@ -13,9 +13,9 @@ class QueryBase(BaseModel): """ query_text: str = Field(..., examples=["What is AAQ?"]) - session_id: Optional[int] = None generate_llm_response: bool = Field(False) query_metadata: dict = Field({}, examples=[{"some_key": "some_value"}]) + session_id: str | None = Field(default=None) model_config = ConfigDict(from_attributes=True) @@ -54,7 +54,7 @@ class QueryResponse(BaseModel): """ query_id: int = Field(..., examples=[1]) - session_id: Optional[int] = None + session_id: str | None = Field(None, examples=["session-id-12345-abcde"]) feedback_secret_key: str = Field(..., examples=["secret-key-12345-abcde"]) llm_response: str | None = Field(None, examples=["Example LLM response"]) @@ -62,30 +62,25 @@ class QueryResponse(BaseModel): None, examples=[ { - "query_id": 1, - "session_id": 1, - "feedback_secret_key": "secret-key-12345-abcde", - "llm_response": "Example LLM response " - "(null if generate_llm_response is false)", - "search_results": { - "0": { - "title": "Example content title", - "text": "Example content text", - "id": 23, - "distance": 0.1, - }, - "1": { - "title": "Another example content title", - "text": "Another example content text", - "id": 12, - "distance": 0.2, - }, + "0": { + "title": "Example content title", + "text": "Example content text", + "id": 23, + "distance": 0.1, + }, + "1": { + "title": "Another example content title", + "text": "Another example content text", + "id": 12, + "distance": 0.2, }, - "debug_info": {"example": "debug-info"}, } ], ) debug_info: dict = Field({}, examples=[{"example": "debug-info"}]) + chat_history: list[dict[str, str]] | None = Field( + None, examples=[{"role": "user", "content": "Hello"}] + ) model_config = ConfigDict(from_attributes=True) @@ -110,7 
+105,7 @@ class QueryResponseError(QueryResponse): """ error_type: ErrorType = Field(..., examples=["example_error"]) - error_message: Optional[str] = Field(None, examples=["Example error message"]) + error_message: str | None = Field(None, examples=["Example error message"]) model_config = ConfigDict(from_attributes=True) @@ -122,11 +117,11 @@ class ResponseFeedbackBase(BaseModel): """ query_id: int = Field(..., examples=[1]) - session_id: Optional[int] = None + session_id: str | None = None feedback_sentiment: FeedbackSentiment = Field( FeedbackSentiment.UNKNOWN, examples=["positive"] ) - feedback_text: Optional[str] = Field(None, examples=["This is helpful"]) + feedback_text: str | None = Field(None, examples=["This is helpful"]) feedback_secret_key: str = Field(..., examples=["secret-key-12345-abcde"]) model_config = ConfigDict(from_attributes=True) diff --git a/core_backend/app/question_answer/utils.py b/core_backend/app/question_answer/utils.py index a1ab8a666..3d9b7e323 100644 --- a/core_backend/app/question_answer/utils.py +++ b/core_backend/app/question_answer/utils.py @@ -16,3 +16,9 @@ def get_context_string_from_search_results( context_list.append(f"{key}. {result.title}\n{result.text}") context_string = "\n\n".join(context_list) return context_string + + +def format_session_history_as_query(messages: list[dict[str, str]]) -> str: + """Format the session history as a query.""" + history = "\n".join([message["content"] for message in messages]) + return history diff --git a/core_backend/migrations/versions/2024_08_26_1bbcc2584fde_change_session_id_to_string.py b/core_backend/migrations/versions/2024_08_26_1bbcc2584fde_change_session_id_to_string.py new file mode 100644 index 000000000..e839e56a4 --- /dev/null +++ b/core_backend/migrations/versions/2024_08_26_1bbcc2584fde_change_session_id_to_string.py @@ -0,0 +1,41 @@ +"""change session_id to string + +Revision ID: 1bbcc2584fde +Revises: c571cf9aae63 +Create Date: 2024-08-26 19:40:08.259316 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "1bbcc2584fde" # pragma: allowlist secret +down_revision: Union[str, None] = "c571cf9aae63" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +tables_with_session_id = [ + "query", + "query_response", + "query_response_feedback", + "query_response_content", + "content_feedback", +] + + +def upgrade() -> None: + for table in tables_with_session_id: + op.alter_column( + table, "session_id", existing_type=sa.Integer(), type_=sa.String() + ) + + +def downgrade() -> None: + for table in tables_with_session_id: + # TODO: if the session_id can't be casted, set it to null + op.alter_column( + table, "session_id", existing_type=sa.String(), type_=sa.Integer() + ) diff --git a/core_backend/tests/api/test_dashboard_overview.py b/core_backend/tests/api/test_dashboard_overview.py index 54e85d5de..45773ed3d 100644 --- a/core_backend/tests/api/test_dashboard_overview.py +++ b/core_backend/tests/api/test_dashboard_overview.py @@ -287,7 +287,6 @@ async def queries( for i in range(count): query = QueryDB( user_id=1, - session_id=1, feedback_secret_key="abc123", query_text=f"test_{day}_{i}", query_generate_llm_response=False, @@ -321,7 +320,6 @@ async def queries_hour(self, asession: AsyncSession) -> AsyncGenerator[None, Non for i in range(count): query = QueryDB( user_id=1, - session_id=1, feedback_secret_key="abc123", query_text=f"test_{hour}_{i}", query_generate_llm_response=False, @@ -512,7 +510,6 @@ async def create_query_and_feedback( for i in range(n_positive + n_negative + n_neutral): query = QueryDB( user_id=user_id, - session_id=1, feedback_secret_key="abc123", query_text="test message", query_generate_llm_response=False, From b7dae791fbc9d36e92e327a07007632232cfc474 Mon Sep 17 00:00:00 2001 From: Suzin You <7042047+suzinyou@users.noreply.github.com> Date: Mon, 26 Aug 2024 23:27:08 +0900 Subject: [PATCH 2/8] refactor search_base func --- .secrets.baseline | 13 +---- core_backend/app/contents/schemas.py | 14 +++--- core_backend/app/llm_call/llm_rag.py | 5 +- core_backend/app/llm_call/process_output.py | 34 +------------ core_backend/app/llm_call/utils.py | 8 ++- core_backend/app/question_answer/routers.py | 50 +++++++++++++++---- core_backend/app/question_answer/schemas.py | 43 +++++++--------- .../tests/api/test_dashboard_overview.py | 3 -- 8 files changed, 72 insertions(+), 98 deletions(-) diff --git a/.secrets.baseline b/.secrets.baseline index 0d8db7530..5c702eb6f 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -231,15 +231,6 @@ "line_number": 63 } ], - "core_backend/app/question_answer/schemas.py": [ - { - "type": "Secret Keyword", - "filename": "core_backend/app/question_answer/schemas.py", - "hashed_secret": "5b8b7a620e54e681c584f5b5c89152773c10c253", - "is_verified": false, - "line_number": 67 - } - ], "core_backend/migrations/versions/2023_09_16_c5a948963236_create_query_table.py": [ { "type": "Hex High Entropy String", @@ -430,7 +421,7 @@ "filename": "core_backend/tests/api/test_dashboard_overview.py", "hashed_secret": "6367c48dd193d56ea7b0baad25b19455e529f5ee", "is_verified": false, - "line_number": 291 + "line_number": 290 } ], "core_backend/tests/api/test_dashboard_performance.py": [ @@ -590,5 +581,5 @@ } ] }, - "generated_at": "2024-08-23T09:41:17Z" + "generated_at": "2024-08-26T14:03:01Z" } diff --git a/core_backend/app/contents/schemas.py b/core_backend/app/contents/schemas.py index 9ea35a0cb..bc888cb53 100644 --- a/core_backend/app/contents/schemas.py +++ b/core_backend/app/contents/schemas.py @@ 
-1,7 +1,7 @@ from datetime import datetime -from typing import Annotated, List +from typing import List -from pydantic import BaseModel, ConfigDict, Field, StringConstraints +from pydantic import BaseModel, ConfigDict, Field class ContentCreate(BaseModel): @@ -9,13 +9,15 @@ class ContentCreate(BaseModel): Pydantic model for content creation request """ - content_title: Annotated[str, StringConstraints(max_length=150)] = Field( + content_title: str = Field( + max_length=150, examples=["Example Content Title"], ) - content_text: Annotated[str, StringConstraints(max_length=2000)] = Field( - examples=["This is an example content."] + content_text: str = Field( + max_length=2000, + examples=["This is an example content."], ) - content_tags: list = Field(default=[], examples=[[1, 4]]) + content_tags: list = Field(default=[]) content_metadata: dict = Field(default={}, examples=[{"key": "optional_value"}]) is_archived: bool = False diff --git a/core_backend/app/llm_call/llm_rag.py b/core_backend/app/llm_call/llm_rag.py index 349761aad..0f86fa9cb 100644 --- a/core_backend/app/llm_call/llm_rag.py +++ b/core_backend/app/llm_call/llm_rag.py @@ -2,8 +2,6 @@ Augmented Generation (RAG). """ -from typing import Optional - from pydantic import ValidationError from ..config import LITELLM_MODEL_GENERATION @@ -18,7 +16,7 @@ async def get_llm_rag_answer( question: str, context: str, original_language: IdentifiedLanguage, - metadata: Optional[dict] = None, + metadata: dict | None = None, ) -> RAG: """Get an answer from the LLM model using RAG. @@ -32,7 +30,6 @@ async def get_llm_rag_answer( The language of the response. metadata Additional metadata to provide to the LLM model. - Returns ------- RAG diff --git a/core_backend/app/llm_call/process_output.py b/core_backend/app/llm_call/process_output.py index 5e7d32a1d..bb2a1d3a9 100644 --- a/core_backend/app/llm_call/process_output.py +++ b/core_backend/app/llm_call/process_output.py @@ -45,39 +45,7 @@ class AlignScoreData(TypedDict): claim: str -def generate_llm_response__after(func: Callable) -> Callable: - """ - Decorator to generate the LLM response. - - Only runs if the generate_llm_response flag is set to True. - Requires "search_results" and "original_language" in the response. 
- """ - - @wraps(func) - async def wrapper( - query_refined: QueryRefined, - response: QueryResponse | QueryResponseError, - *args: Any, - **kwargs: Any, - ) -> QueryResponse | QueryResponseError: - """ - Generate the LLM response - """ - response = await func(query_refined, response, *args, **kwargs) - - if not query_refined.generate_llm_response: - return response - - metadata = create_langfuse_metadata( - query_id=response.query_id, user_id=query_refined.user_id - ) - response = await _generate_llm_response(query_refined, response, metadata) - return response - - return wrapper - - -async def _generate_llm_response( +async def generate_llm_query_response( query_refined: QueryRefined, response: QueryResponse, metadata: Optional[dict] = None, diff --git a/core_backend/app/llm_call/utils.py b/core_backend/app/llm_call/utils.py index 2f8a360c9..5a5ac1331 100644 --- a/core_backend/app/llm_call/utils.py +++ b/core_backend/app/llm_call/utils.py @@ -1,5 +1,3 @@ -from typing import Optional - from litellm import acompletion from ..config import LITELLM_API_KEY, LITELLM_ENDPOINT, LITELLM_MODEL_DEFAULT @@ -11,9 +9,9 @@ async def _ask_llm_async( user_message: str, system_message: str, - litellm_model: Optional[str] = LITELLM_MODEL_DEFAULT, - litellm_endpoint: Optional[str] = LITELLM_ENDPOINT, - metadata: Optional[dict] = None, + litellm_model: str | None = LITELLM_MODEL_DEFAULT, + litellm_endpoint: str | None = LITELLM_ENDPOINT, + metadata: dict | None = None, json: bool = False, ) -> str: """ diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py index 660243ebd..68b919a97 100644 --- a/core_backend/app/question_answer/routers.py +++ b/core_backend/app/question_answer/routers.py @@ -27,7 +27,7 @@ ) from ..llm_call.process_output import ( check_align_score__after, - generate_llm_response__after, + generate_llm_query_response, generate_tts__after, ) from ..users.models import UserDB @@ -145,7 +145,7 @@ async def voice_search( generate_tts=True, ) - response = await search_base( + response = await get_search_response( query_refined=user_query_refined_template, response=response_template, user_id=user_db.user_id, @@ -230,7 +230,7 @@ async def search( asession=asession, generate_tts=False, ) - response = await search_base( + response = await get_search_response( query_refined=user_query_refined_template, response=response_template, user_id=user_db.user_id, @@ -239,6 +239,13 @@ async def search( exclude_archived=True, ) + if user_query.generate_llm_response: + response = await get_generation_response( + query_refined=user_query_refined_template, + response=response, + ) + logger.info(f"Search response: {response}") + logger.debug(f"Search response type: {type(response)}") await save_query_response_to_db(user_query_db, response, asession) await increment_query_count( user_id=user_db.user_id, @@ -270,10 +277,7 @@ async def search( @classify_safety__before @translate_question__before @paraphrase_question__before -@generate_tts__after -@generate_llm_response__after -@check_align_score__after -async def search_base( +async def get_search_response( query_refined: QueryRefined, response: QueryResponse, user_id: int, @@ -308,12 +312,12 @@ async def search_base( An appropriate query response object. 
""" - - # always do the embeddings search even if some guardrails have failed + # No checks for errors: + # always do the embeddings search even if some guardrails have failed metadata = create_langfuse_metadata(query_id=response.query_id, user_id=user_id) + search_results = await get_similar_content_async( user_id=user_id, - # use latest version of the text question=query_refined.query_text, n_similar=n_similar, asession=asession, @@ -325,6 +329,32 @@ async def search_base( return response +@generate_tts__after +@check_align_score__after +async def get_generation_response( + query_refined: QueryRefined, + response: QueryResponse, +) -> QueryResponse | QueryResponseError: + """ + Generate a response using an LLM given a query with search results. + + Only runs if the generate_llm_response flag is set to True. + Requires "search_results" and "original_language" in the response. + """ + if not query_refined.generate_llm_response: + return response + + metadata = create_langfuse_metadata( + query_id=response.query_id, user_id=query_refined.user_id + ) + + response = await generate_llm_query_response( + query_refined=query_refined, response=response, metadata=metadata + ) + + return response + + async def get_user_query_and_response( user_id: int, user_query: QueryBase, diff --git a/core_backend/app/question_answer/schemas.py b/core_backend/app/question_answer/schemas.py index 128b1792f..ae54ede08 100644 --- a/core_backend/app/question_answer/schemas.py +++ b/core_backend/app/question_answer/schemas.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Dict, Optional +from typing import Dict from pydantic import BaseModel, ConfigDict, Field @@ -13,9 +13,9 @@ class QueryBase(BaseModel): """ query_text: str = Field(..., examples=["What is AAQ?"]) - session_id: Optional[int] = None generate_llm_response: bool = Field(False) query_metadata: dict = Field({}, examples=[{"some_key": "some_value"}]) + session_id: int | None = Field(default=None, exclude=True) model_config = ConfigDict(from_attributes=True) @@ -54,34 +54,25 @@ class QueryResponse(BaseModel): """ query_id: int = Field(..., examples=[1]) - session_id: Optional[int] = None + session_id: int | None = Field(None, exclude=True) feedback_secret_key: str = Field(..., examples=["secret-key-12345-abcde"]) llm_response: str | None = Field(None, examples=["Example LLM response"]) search_results: Dict[int, QuerySearchResult] | None = Field( - None, examples=[ { - "query_id": 1, - "session_id": 1, - "feedback_secret_key": "secret-key-12345-abcde", - "llm_response": "Example LLM response " - "(null if generate_llm_response is false)", - "search_results": { - "0": { - "title": "Example content title", - "text": "Example content text", - "id": 23, - "distance": 0.1, - }, - "1": { - "title": "Another example content title", - "text": "Another example content text", - "id": 12, - "distance": 0.2, - }, + "0": { + "title": "Example content title", + "text": "Example content text", + "id": 23, + "distance": 0.1, + }, + "1": { + "title": "Another example content title", + "text": "Another example content text", + "id": 12, + "distance": 0.2, }, - "debug_info": {"example": "debug-info"}, } ], ) @@ -110,7 +101,7 @@ class QueryResponseError(QueryResponse): """ error_type: ErrorType = Field(..., examples=["example_error"]) - error_message: Optional[str] = Field(None, examples=["Example error message"]) + error_message: str | None = Field(None, examples=["Example error message"]) model_config = ConfigDict(from_attributes=True) @@ -122,11 +113,11 @@ class 
ResponseFeedbackBase(BaseModel): """ query_id: int = Field(..., examples=[1]) - session_id: Optional[int] = None + session_id: int | None = None feedback_sentiment: FeedbackSentiment = Field( FeedbackSentiment.UNKNOWN, examples=["positive"] ) - feedback_text: Optional[str] = Field(None, examples=["This is helpful"]) + feedback_text: str | None = Field(None, examples=["This is helpful"]) feedback_secret_key: str = Field(..., examples=["secret-key-12345-abcde"]) model_config = ConfigDict(from_attributes=True) diff --git a/core_backend/tests/api/test_dashboard_overview.py b/core_backend/tests/api/test_dashboard_overview.py index 54e85d5de..45773ed3d 100644 --- a/core_backend/tests/api/test_dashboard_overview.py +++ b/core_backend/tests/api/test_dashboard_overview.py @@ -287,7 +287,6 @@ async def queries( for i in range(count): query = QueryDB( user_id=1, - session_id=1, feedback_secret_key="abc123", query_text=f"test_{day}_{i}", query_generate_llm_response=False, @@ -321,7 +320,6 @@ async def queries_hour(self, asession: AsyncSession) -> AsyncGenerator[None, Non for i in range(count): query = QueryDB( user_id=1, - session_id=1, feedback_secret_key="abc123", query_text=f"test_{hour}_{i}", query_generate_llm_response=False, @@ -512,7 +510,6 @@ async def create_query_and_feedback( for i in range(n_positive + n_negative + n_neutral): query = QueryDB( user_id=user_id, - session_id=1, feedback_secret_key="abc123", query_text="test message", query_generate_llm_response=False, From 0f1f546940337d833d0d1f3d8beb2ace78dc33ed Mon Sep 17 00:00:00 2001 From: Suzin You <7042047+suzinyou@users.noreply.github.com> Date: Mon, 26 Aug 2024 23:32:12 +0900 Subject: [PATCH 3/8] remove unused logging --- core_backend/app/question_answer/routers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py index 68b919a97..303edcd2f 100644 --- a/core_backend/app/question_answer/routers.py +++ b/core_backend/app/question_answer/routers.py @@ -244,8 +244,7 @@ async def search( query_refined=user_query_refined_template, response=response, ) - logger.info(f"Search response: {response}") - logger.debug(f"Search response type: {type(response)}") + await save_query_response_to_db(user_query_db, response, asession) await increment_query_count( user_id=user_db.user_id, From 7e8f6a144493646d625ef9201563a81f16cda3e5 Mon Sep 17 00:00:00 2001 From: Suzin You <7042047+suzinyou@users.noreply.github.com> Date: Mon, 26 Aug 2024 23:39:15 +0900 Subject: [PATCH 4/8] cleanup json schema displays --- core_backend/app/contents/schemas.py | 2 +- core_backend/app/question_answer/schemas.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/core_backend/app/contents/schemas.py b/core_backend/app/contents/schemas.py index bc888cb53..164b41592 100644 --- a/core_backend/app/contents/schemas.py +++ b/core_backend/app/contents/schemas.py @@ -18,7 +18,7 @@ class ContentCreate(BaseModel): examples=["This is an example content."], ) content_tags: list = Field(default=[]) - content_metadata: dict = Field(default={}, examples=[{"key": "optional_value"}]) + content_metadata: dict = Field(default={}) is_archived: bool = False model_config = ConfigDict( diff --git a/core_backend/app/question_answer/schemas.py b/core_backend/app/question_answer/schemas.py index ae54ede08..35dd8b193 100644 --- a/core_backend/app/question_answer/schemas.py +++ b/core_backend/app/question_answer/schemas.py @@ -2,6 +2,7 @@ from typing import Dict from 
pydantic import BaseModel, ConfigDict, Field +from pydantic.json_schema import SkipJsonSchema from ..llm_call.llm_prompts import IdentifiedLanguage from ..schemas import FeedbackSentiment, QuerySearchResult @@ -15,7 +16,7 @@ class QueryBase(BaseModel): query_text: str = Field(..., examples=["What is AAQ?"]) generate_llm_response: bool = Field(False) query_metadata: dict = Field({}, examples=[{"some_key": "some_value"}]) - session_id: int | None = Field(default=None, exclude=True) + session_id: SkipJsonSchema[int | None] = Field(default=None, exclude=True) model_config = ConfigDict(from_attributes=True) @@ -113,7 +114,7 @@ class ResponseFeedbackBase(BaseModel): """ query_id: int = Field(..., examples=[1]) - session_id: int | None = None + session_id: SkipJsonSchema[int | None] = None feedback_sentiment: FeedbackSentiment = Field( FeedbackSentiment.UNKNOWN, examples=["positive"] ) From 9e0b2581b2bdfc76eb10f51b8bd48596320fcf75 Mon Sep 17 00:00:00 2001 From: Suzin You <7042047+suzinyou@users.noreply.github.com> Date: Mon, 26 Aug 2024 23:51:28 +0900 Subject: [PATCH 5/8] fix mypy issue --- core_backend/app/llm_call/process_input.py | 4 ++++ core_backend/app/llm_call/process_output.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/core_backend/app/llm_call/process_input.py b/core_backend/app/llm_call/process_input.py index 56b793ccc..ce65bc602 100644 --- a/core_backend/app/llm_call/process_input.py +++ b/core_backend/app/llm_call/process_input.py @@ -121,6 +121,7 @@ def _process_identified_language_response( error_response = QueryResponseError( query_id=response.query_id, + session_id=response.session_id, feedback_secret_key=response.feedback_secret_key, llm_response=response.llm_response, search_results=response.search_results, @@ -206,6 +207,7 @@ async def _translate_question( else: error_response = QueryResponseError( query_id=response.query_id, + session_id=response.session_id, feedback_secret_key=response.feedback_secret_key, llm_response=response.llm_response, search_results=response.search_results, @@ -275,6 +277,7 @@ async def _classify_safety( else: error_response = QueryResponseError( query_id=response.query_id, + session_id=response.session_id, feedback_secret_key=response.feedback_secret_key, llm_response=response.llm_response, search_results=response.search_results, @@ -352,6 +355,7 @@ async def _paraphrase_question( else: error_response = QueryResponseError( query_id=response.query_id, + session_id=response.session_id, feedback_secret_key=response.feedback_secret_key, llm_response=response.llm_response, search_results=response.search_results, diff --git a/core_backend/app/llm_call/process_output.py b/core_backend/app/llm_call/process_output.py index bb2a1d3a9..3e58df9c7 100644 --- a/core_backend/app/llm_call/process_output.py +++ b/core_backend/app/llm_call/process_output.py @@ -82,6 +82,7 @@ async def generate_llm_query_response( else: response = QueryResponseError( query_id=response.query_id, + session_id=response.session_id, feedback_secret_key=response.feedback_secret_key, llm_response=None, search_results=response.search_results, @@ -187,6 +188,7 @@ async def _check_align_score( ) response = QueryResponseError( query_id=response.query_id, + session_id=response.session_id, feedback_secret_key=response.feedback_secret_key, llm_response=None, search_results=response.search_results, @@ -329,6 +331,7 @@ async def _generate_tts_response( logger.error(f"Error generating TTS for query_id {response.query_id}: {e}") return QueryResponseError( query_id=response.query_id, + 
session_id=response.session_id, feedback_secret_key=response.feedback_secret_key, llm_response=response.llm_response, search_results=response.search_results, From 07dbdf7886033188f71cd3cd2ab10441e56297a4 Mon Sep 17 00:00:00 2001 From: Suzin You <7042047+suzinyou@users.noreply.github.com> Date: Mon, 26 Aug 2024 23:56:39 +0900 Subject: [PATCH 6/8] fix mypy issue --- core_backend/app/llm_call/process_output.py | 1 + 1 file changed, 1 insertion(+) diff --git a/core_backend/app/llm_call/process_output.py b/core_backend/app/llm_call/process_output.py index 3e58df9c7..85695f1f7 100644 --- a/core_backend/app/llm_call/process_output.py +++ b/core_backend/app/llm_call/process_output.py @@ -281,6 +281,7 @@ async def wrapper( ) response = QueryAudioResponse( query_id=response.query_id, + session_id=response.session_id, feedback_secret_key=response.feedback_secret_key, llm_response=response.llm_response, search_results=response.search_results, From 49b6e1236abf6bdf95e8f7a00aca27aabf102f2c Mon Sep 17 00:00:00 2001 From: Suzin You <7042047+suzinyou@users.noreply.github.com> Date: Tue, 27 Aug 2024 00:09:09 +0900 Subject: [PATCH 7/8] refactor voice search endpoint --- core_backend/app/question_answer/routers.py | 168 +++++++++--------- .../speech_components/utils.py | 16 +- 2 files changed, 96 insertions(+), 88 deletions(-) diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py index 303edcd2f..3238c9e2a 100644 --- a/core_backend/app/question_answer/routers.py +++ b/core_backend/app/question_answer/routers.py @@ -6,7 +6,7 @@ from io import BytesIO from typing import Tuple -from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status +from fastapi import APIRouter, Depends, File, UploadFile, status from fastapi.responses import JSONResponse from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession @@ -35,7 +35,6 @@ create_langfuse_metadata, generate_random_filename, get_file_extension_from_mime_type, - get_http_client, setup_logger, upload_file_to_gcs, ) @@ -59,6 +58,7 @@ ResponseFeedbackBase, ) from .speech_components.external_voice_components import transcribe_audio +from .speech_components.utils import post_to_speech logger = setup_logger() @@ -77,62 +77,27 @@ @router.post( - "/voice-search", - response_model=QueryAudioResponse, + "/search", + response_model=QueryResponse, responses={ status.HTTP_400_BAD_REQUEST: { "model": QueryResponseError, - "description": "Bad Request", - }, - status.HTTP_500_INTERNAL_SERVER_ERROR: { - "model": QueryResponseError, - "description": "Internal Server Error", - }, + "description": "Guardrail failure", + } }, ) -async def voice_search( - file: UploadFile = File(...), +async def search( + user_query: QueryBase, asession: AsyncSession = Depends(get_async_session), user_db: UserDB = Depends(authenticate_key), -) -> QueryAudioResponse | JSONResponse: - """ - Endpoint to transcribe audio from the provided file, - generate an LLM response, by default generate_tts is - set to true and return a public signed URL of an audio - file containing the spoken version of the generated response +) -> QueryResponse | JSONResponse: """ - file_stream = BytesIO(await file.read()) - - file_path = f"temp/{file.filename}" - with open(file_path, "wb") as f: - file_stream.seek(0) - f.write(file_stream.read()) - - file_stream.seek(0) - - content_type = file.content_type - - file_extension = get_file_extension_from_mime_type(content_type) - unique_filename = generate_random_filename(file_extension) - 
- destination_blob_name = f"stt-voice-notes/{unique_filename}" - - await upload_file_to_gcs( - GCS_SPEECH_BUCKET, file_stream, destination_blob_name, content_type - ) - - if CUSTOM_SPEECH_ENDPOINT is not None: - transcription = await post_to_speech(file_path, CUSTOM_SPEECH_ENDPOINT) - transcription_result = transcription["text"] - - else: - transcription_result = await transcribe_audio(file_path) + Search endpoint finds the most similar content to the user query and optionally + generates a single-turn LLM response. - user_query = QueryBase( - generate_llm_response=True, - query_text=transcription_result, - query_metadata={}, - ) + If any guardrails fail, the embeddings search is still done and an error 400 is + returned that includes the search results as well as the details of the failure. + """ ( user_query_db, @@ -142,9 +107,8 @@ async def voice_search( user_id=user_db.user_id, user_query=user_query, asession=asession, - generate_tts=True, + generate_tts=False, ) - response = await get_search_response( query_refined=user_query_refined_template, response=response_template, @@ -153,6 +117,13 @@ async def voice_search( asession=asession, exclude_archived=True, ) + + if user_query.generate_llm_response: + response = await get_generation_response( + query_refined=user_query_refined_template, + response=response, + ) + await save_query_response_to_db(user_query_db, response, asession) await increment_query_count( user_id=user_db.user_id, @@ -161,17 +132,13 @@ async def voice_search( ) await save_content_for_query_to_db( user_id=user_db.user_id, - query_id=response.query_id, session_id=user_query.session_id, + query_id=response.query_id, contents=response.search_results, asession=asession, ) - if os.path.exists(file_path): - os.remove(file_path) - file_stream.close() - - if type(response) is QueryAudioResponse: + if type(response) is QueryResponse: return response elif type(response) is QueryResponseError: return JSONResponse( @@ -180,45 +147,67 @@ async def voice_search( else: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"error": "Internal server error"}, + content={"message": "Internal server error"}, ) -async def post_to_speech(file_path: str, endpoint_url: str) -> dict: - """ - Post request the file to the speech endpoint to get the transcription - """ - async with get_http_client() as client: - async with client.post(endpoint_url, json={"file_path": file_path}) as response: - if response.status != 200: - error_content = await response.json() - logger.error(f"Error from CUSTOM_SPEECH_ENDPOINT: {error_content}") - raise HTTPException(status_code=response.status, detail=error_content) - return await response.json() - - @router.post( - "/search", - response_model=QueryResponse, + "/voice-search", + response_model=QueryAudioResponse, responses={ status.HTTP_400_BAD_REQUEST: { "model": QueryResponseError, - "description": "Guardrail failure", - } + "description": "Bad Request", + }, + status.HTTP_500_INTERNAL_SERVER_ERROR: { + "model": QueryResponseError, + "description": "Internal Server Error", + }, }, ) -async def search( - user_query: QueryBase, +async def voice_search( + file: UploadFile = File(...), asession: AsyncSession = Depends(get_async_session), user_db: UserDB = Depends(authenticate_key), -) -> QueryResponse | JSONResponse: +) -> QueryAudioResponse | JSONResponse: """ - Search endpoint finds the most similar content to the user query and optionally - generates a single-turn LLM response. 
- - If any guardrails fail, the embeddings search is still done and an error 400 is - returned that includes the search results as well as the details of the failure. + Endpoint to transcribe audio from the provided file, + generate an LLM response, by default generate_tts is + set to true and return a public signed URL of an audio + file containing the spoken version of the generated response """ + file_stream = BytesIO(await file.read()) + + file_path = f"temp/{file.filename}" + with open(file_path, "wb") as f: + file_stream.seek(0) + f.write(file_stream.read()) + + file_stream.seek(0) + + content_type = file.content_type + + file_extension = get_file_extension_from_mime_type(content_type) + unique_filename = generate_random_filename(file_extension) + + destination_blob_name = f"stt-voice-notes/{unique_filename}" + + await upload_file_to_gcs( + GCS_SPEECH_BUCKET, file_stream, destination_blob_name, content_type + ) + + if CUSTOM_SPEECH_ENDPOINT is not None: + transcription = await post_to_speech(file_path, CUSTOM_SPEECH_ENDPOINT) + transcription_result = transcription["text"] + + else: + transcription_result = await transcribe_audio(file_path) + + user_query = QueryBase( + generate_llm_response=True, + query_text=transcription_result, + query_metadata={}, + ) ( user_query_db, @@ -228,8 +217,9 @@ async def search( user_id=user_db.user_id, user_query=user_query, asession=asession, - generate_tts=False, + generate_tts=True, ) + response = await get_search_response( query_refined=user_query_refined_template, response=response_template, @@ -253,13 +243,17 @@ async def search( ) await save_content_for_query_to_db( user_id=user_db.user_id, - session_id=user_query.session_id, query_id=response.query_id, + session_id=user_query.session_id, contents=response.search_results, asession=asession, ) - if type(response) is QueryResponse: + if os.path.exists(file_path): + os.remove(file_path) + file_stream.close() + + if type(response) is QueryAudioResponse: return response elif type(response) is QueryResponseError: return JSONResponse( @@ -268,7 +262,7 @@ async def search( else: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={"message": "Internal server error"}, + content={"error": "Internal server error"}, ) diff --git a/core_backend/app/question_answer/speech_components/utils.py b/core_backend/app/question_answer/speech_components/utils.py index 7fdcb1372..2947f22f9 100644 --- a/core_backend/app/question_answer/speech_components/utils.py +++ b/core_backend/app/question_answer/speech_components/utils.py @@ -1,9 +1,10 @@ import os +from fastapi import HTTPException from pydub import AudioSegment from ...llm_call.llm_prompts import IdentifiedLanguage -from ...utils import setup_logger +from ...utils import get_http_client, setup_logger logger = setup_logger("Voice utils") @@ -80,3 +81,16 @@ def set_wav_specifications(wav_filename: str) -> str: logger.info(f"Updated file created: {updated_wav_filename}") return updated_wav_filename + + +async def post_to_speech(file_path: str, endpoint_url: str) -> dict: + """ + Post request the file to the speech endpoint to get the transcription + """ + async with get_http_client() as client: + async with client.post(endpoint_url, json={"file_path": file_path}) as response: + if response.status != 200: + error_content = await response.json() + logger.error(f"Error from CUSTOM_SPEECH_ENDPOINT: {error_content}") + raise HTTPException(status_code=response.status, detail=error_content) + return await response.json() From 
294b796313e44534817e25d9c48eeeac324a2401 Mon Sep 17 00:00:00 2001 From: Suzin You <7042047+suzinyou@users.noreply.github.com> Date: Tue, 27 Aug 2024 00:18:13 +0900 Subject: [PATCH 8/8] add comment back in --- core_backend/app/question_answer/routers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py index 3238c9e2a..0c61bce34 100644 --- a/core_backend/app/question_answer/routers.py +++ b/core_backend/app/question_answer/routers.py @@ -311,7 +311,7 @@ async def get_search_response( search_results = await get_similar_content_async( user_id=user_id, - question=query_refined.query_text, + question=query_refined.query_text, # use latest transformed version of the text n_similar=n_similar, asession=asession, metadata=metadata,
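
Taken together, the series adds a multi-turn `/chat` endpoint (PATCH 1) next to the single-turn `/search`. A minimal client sketch of the intended flow — the base URL, route prefix, and API key are placeholders, and bearer-token auth is assumed to match the other authenticated endpoints:

    import requests  # any HTTP client works; requests is illustrative

    BASE_URL = "http://localhost:8000"  # placeholder
    HEADERS = {"Authorization": "Bearer <API_KEY>"}  # placeholder

    # First turn: omit session_id and the server mints one (generate_secret_key).
    first = requests.post(
        f"{BASE_URL}/chat",
        json={"query_text": "What is AAQ?", "generate_llm_response": True},
        headers=HEADERS,
    ).json()

    # Follow-up turn: echo session_id back so get_session_history can rebuild
    # the conversation and pass it to the LLM as chat_history.
    followup = requests.post(
        f"{BASE_URL}/chat",
        json={
            "query_text": "How do I add new content?",
            "generate_llm_response": True,
            "session_id": first["session_id"],
        },
        headers=HEADERS,
    ).json()
    print(followup["llm_response"])

Because the minted session ID is a random string rather than an integer, PATCH 1 also migrates every `session_id` column from Integer to String.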