diff --git a/.secrets.baseline b/.secrets.baseline
index 0d8db7530..5c702eb6f 100644
--- a/.secrets.baseline
+++ b/.secrets.baseline
@@ -231,15 +231,6 @@
         "line_number": 63
       }
     ],
-    "core_backend/app/question_answer/schemas.py": [
-      {
-        "type": "Secret Keyword",
-        "filename": "core_backend/app/question_answer/schemas.py",
-        "hashed_secret": "5b8b7a620e54e681c584f5b5c89152773c10c253",
-        "is_verified": false,
-        "line_number": 67
-      }
-    ],
     "core_backend/migrations/versions/2023_09_16_c5a948963236_create_query_table.py": [
       {
         "type": "Hex High Entropy String",
@@ -430,7 +421,7 @@
         "filename": "core_backend/tests/api/test_dashboard_overview.py",
         "hashed_secret": "6367c48dd193d56ea7b0baad25b19455e529f5ee",
         "is_verified": false,
-        "line_number": 291
+        "line_number": 290
       }
     ],
     "core_backend/tests/api/test_dashboard_performance.py": [
@@ -590,5 +581,5 @@
       }
     ]
   },
-  "generated_at": "2024-08-23T09:41:17Z"
+  "generated_at": "2024-08-26T14:03:01Z"
 }
diff --git a/core_backend/add_dummy_data_to_db.py b/core_backend/add_dummy_data_to_db.py
index 1116a5810..bc8f4f075 100644
--- a/core_backend/add_dummy_data_to_db.py
+++ b/core_backend/add_dummy_data_to_db.py
@@ -185,7 +185,7 @@ def create_query_record(dt: datetime, session: Session) -> QueryDB:
 
     query_db = QueryDB(
         user_id=_USER_ID,
-        session_id=1,
+        session_id="1",
        feedback_secret_key="abc123",  # pragma: allowlist secret
        query_text="test query",
        query_generate_llm_response=False,
@@ -198,7 +198,7 @@ def create_query_record(dt: datetime, session: Session) -> QueryDB:
 
 
 def create_response_feedback_record(
-    dt: datetime, query_id: int, session_id: int, is_negative: bool, session: Session
+    dt: datetime, query_id: int, session_id: str, is_negative: bool, session: Session
 ) -> None:
     """Create a feedback record for a given datetime.
 
@@ -209,7 +209,7 @@ def create_response_feedback_record(
     query_id
         The ID of the query record.
     session_id
-        The ID of the session record.
+        The ID of the session record (a UUID string).
     is_negative
         Specifies whether the feedback is negative.
     session
@@ -235,7 +235,7 @@ def create_response_feedback_record(
 def create_content_feedback_record(
     dt: datetime,
     query_id: int,
-    session_id: int,
+    session_id: str,
     is_negative: bool,
     session: Session,
 ) -> None:
@@ -248,7 +248,7 @@ def create_content_feedback_record(
     query_id
         The ID of the query record.
     session_id
-        The ID of the session record.
+        The ID of the session record (a UUID string).
     is_negative
        Specifies whether the content feedback is negative.
     session
diff --git a/core_backend/app/contents/schemas.py b/core_backend/app/contents/schemas.py
index 9ea35a0cb..164b41592 100644
--- a/core_backend/app/contents/schemas.py
+++ b/core_backend/app/contents/schemas.py
@@ -1,7 +1,7 @@
 from datetime import datetime
-from typing import Annotated, List
+from typing import List
 
-from pydantic import BaseModel, ConfigDict, Field, StringConstraints
+from pydantic import BaseModel, ConfigDict, Field
 
 
 class ContentCreate(BaseModel):
@@ -9,14 +9,16 @@ class ContentCreate(BaseModel):
     Pydantic model for content creation request
     """
 
-    content_title: Annotated[str, StringConstraints(max_length=150)] = Field(
+    content_title: str = Field(
+        max_length=150,
         examples=["Example Content Title"],
     )
-    content_text: Annotated[str, StringConstraints(max_length=2000)] = Field(
-        examples=["This is an example content."]
+    content_text: str = Field(
+        max_length=2000,
+        examples=["This is an example content."],
     )
-    content_tags: list = Field(default=[], examples=[[1, 4]])
-    content_metadata: dict = Field(default={}, examples=[{"key": "optional_value"}])
+    content_tags: list = Field(default=[])
+    content_metadata: dict = Field(default={})
     is_archived: bool = False
 
     model_config = ConfigDict(
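Note on the `ContentCreate` change above: in Pydantic v2, `Field(max_length=...)` on a plain `str` enforces the same validation as the removed `Annotated[str, StringConstraints(...)]` form. A minimal sketch to confirm the behavior (the model name here is illustrative):

```python
from pydantic import BaseModel, Field, ValidationError


class ContentStub(BaseModel):
    # Mirrors ContentCreate.content_title from the diff above.
    content_title: str = Field(max_length=150, examples=["Example Content Title"])


try:
    ContentStub(content_title="x" * 151)  # one character over the limit
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # prints "string_too_long"
```

diff --git a/core_backend/app/llm_call/llm_rag.py b/core_backend/app/llm_call/llm_rag.py
index 349761aad..3f2d839d6 100644
--- a/core_backend/app/llm_call/llm_rag.py
+++ b/core_backend/app/llm_call/llm_rag.py
@@ -2,8 +2,6 @@
 Augmented Generation (RAG).
 """
 
-from typing import Optional
-
 from pydantic import ValidationError
 
 from ..config import LITELLM_MODEL_GENERATION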
@@ -15,23 +13,27 @@
 
 
 async def get_llm_rag_answer(
-    question: str,
+    question: str | list[dict[str, str]],
     context: str,
     original_language: IdentifiedLanguage,
-    metadata: Optional[dict] = None,
+    metadata: dict | None = None,
+    chat_history: list[dict[str, str]] | None = None,
 ) -> RAG:
     """Get an answer from the LLM model using RAG.
 
     Parameters
     ----------
     question
-        The question to ask the LLM model.
+        The question to ask the LLM model, or a list of chat history messages in
+        the form of {"content": str, "role": str}.
     context
         The context to provide to the LLM model.
     response_language
         The language of the response.
     metadata
         Additional metadata to provide to the LLM model.
+    chat_history
+        The previous chat history to provide to the LLM model, if it exists.
 
     Returns
     -------
@@ -45,6 +47,7 @@ async def get_llm_rag_answer(
     result = await _ask_llm_async(
         user_message=question,
         system_message=prompt,
+        chat_history=chat_history,
         litellm_model=LITELLM_MODEL_GENERATION,
         metadata=metadata,
         json=True,
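With the new `get_llm_rag_answer` signature, prior turns ride along with the latest question. A hypothetical call (all values illustrative; assumes `IdentifiedLanguage` has an `ENGLISH` member):

```python
# Sketch of a call with history; values are illustrative only.
rag = await get_llm_rag_answer(
    question="And how do I reset it?",
    context="1. Password resets\nUse the self-service portal to reset a password.",
    original_language=IdentifiedLanguage.ENGLISH,
    chat_history=[
        {"content": "How do I log in?", "role": "user"},
        {"content": "Use your registered email and password.", "role": "assistant"},
    ],
)
print(rag.answer)
```

diff --git a/core_backend/app/llm_call/process_input.py b/core_backend/app/llm_call/process_input.py
index 56b793ccc..ce65bc602 100644
--- a/core_backend/app/llm_call/process_input.py
+++ b/core_backend/app/llm_call/process_input.py
@@ -121,6 +121,7 @@ def _process_identified_language_response(
 
         error_response = QueryResponseError(
             query_id=response.query_id,
+            session_id=response.session_id,
             feedback_secret_key=response.feedback_secret_key,
             llm_response=response.llm_response,
             search_results=response.search_results,
@@ -206,6 +207,7 @@ async def _translate_question(
     else:
         error_response = QueryResponseError(
             query_id=response.query_id,
+            session_id=response.session_id,
             feedback_secret_key=response.feedback_secret_key,
             llm_response=response.llm_response,
             search_results=response.search_results,
@@ -275,6 +277,7 @@ async def _classify_safety(
     else:
         error_response = QueryResponseError(
             query_id=response.query_id,
+            session_id=response.session_id,
             feedback_secret_key=response.feedback_secret_key,
             llm_response=response.llm_response,
             search_results=response.search_results,
@@ -352,6 +355,7 @@ async def _paraphrase_question(
     else:
         error_response = QueryResponseError(
             query_id=response.query_id,
+            session_id=response.session_id,
             feedback_secret_key=response.feedback_secret_key,
             llm_response=response.llm_response,
             search_results=response.search_results,
diff --git a/core_backend/app/llm_call/process_output.py b/core_backend/app/llm_call/process_output.py
index 5e7d32a1d..8503411ef 100644
--- a/core_backend/app/llm_call/process_output.py
+++ b/core_backend/app/llm_call/process_output.py
@@ -45,39 +45,7 @@ class AlignScoreData(TypedDict):
     claim: str
 
 
-def generate_llm_response__after(func: Callable) -> Callable:
-    """
-    Decorator to generate the LLM response.
-
-    Only runs if the generate_llm_response flag is set to True.
-    Requires "search_results" and "original_language" in the response.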
-    """
-
-    @wraps(func)
-    async def wrapper(
-        query_refined: QueryRefined,
-        response: QueryResponse | QueryResponseError,
-        *args: Any,
-        **kwargs: Any,
-    ) -> QueryResponse | QueryResponseError:
-        """
-        Generate the LLM response
-        """
-        response = await func(query_refined, response, *args, **kwargs)
-
-        if not query_refined.generate_llm_response:
-            return response
-
-        metadata = create_langfuse_metadata(
-            query_id=response.query_id, user_id=query_refined.user_id
-        )
-        response = await _generate_llm_response(query_refined, response, metadata)
-        return response
-
-    return wrapper
-
-
-async def _generate_llm_response(
+async def generate_llm_query_response(
     query_refined: QueryRefined,
     response: QueryResponse,
     metadata: Optional[dict] = None,
@@ -99,12 +67,13 @@ async def generate_llm_query_response(
         return response
 
     context = get_context_string_from_search_results(response.search_results)
+
     rag_response = await get_llm_rag_answer(
-        # use the original query text
         question=query_refined.query_text_original,
         context=context,
         original_language=query_refined.original_language,
         metadata=metadata,
+        chat_history=response.chat_history,
     )
 
     if rag_response.answer != RAG_FAILURE_MESSAGE:
@@ -114,6 +83,7 @@
     else:
         response = QueryResponseError(
             query_id=response.query_id,
+            session_id=response.session_id,
             feedback_secret_key=response.feedback_secret_key,
             llm_response=None,
             search_results=response.search_results,
@@ -219,6 +189,7 @@ async def _check_align_score(
         )
         response = QueryResponseError(
             query_id=response.query_id,
+            session_id=response.session_id,
             feedback_secret_key=response.feedback_secret_key,
             llm_response=None,
             search_results=response.search_results,
@@ -311,6 +282,7 @@ async def wrapper(
         )
         response = QueryAudioResponse(
             query_id=response.query_id,
+            session_id=response.session_id,
             feedback_secret_key=response.feedback_secret_key,
             llm_response=response.llm_response,
             search_results=response.search_results,
@@ -361,6 +333,7 @@ async def _generate_tts_response(
         logger.error(f"Error generating TTS for query_id {response.query_id}: {e}")
         return QueryResponseError(
             query_id=response.query_id,
+            session_id=response.session_id,
             feedback_secret_key=response.feedback_secret_key,
             llm_response=response.llm_response,
             search_results=response.search_results,
diff --git a/core_backend/app/llm_call/utils.py b/core_backend/app/llm_call/utils.py
index 2f8a360c9..0ee4969e9 100644
--- a/core_backend/app/llm_call/utils.py
+++ b/core_backend/app/llm_call/utils.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
 from litellm import acompletion
 
 from ..config import LITELLM_API_KEY, LITELLM_ENDPOINT, LITELLM_MODEL_DEFAULT
@@ -9,11 +7,12 @@
 
 
 async def _ask_llm_async(
-    user_message: str,
+    user_message: str | list[dict[str, str]],
     system_message: str,
-    litellm_model: Optional[str] = LITELLM_MODEL_DEFAULT,
-    litellm_endpoint: Optional[str] = LITELLM_ENDPOINT,
-    metadata: Optional[dict] = None,
+    chat_history: list[dict[str, str]] | None = None,
+    litellm_model: str | None = LITELLM_MODEL_DEFAULT,
+    litellm_endpoint: str | None = LITELLM_ENDPOINT,
+    metadata: dict | None = None,
     json: bool = False,
 ) -> str:
     """
@@ -36,6 +35,10 @@ async def _ask_llm_async(
             "role": "user",
         },
     ]
+
+    if chat_history is not None:
+        messages = messages[:1] + chat_history + messages[1:]
+
     logger.info(f"LLM input: 'model': {litellm_model}, 'endpoint': {litellm_endpoint}")
 
     llm_response_raw = await acompletion(
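The `messages[:1] + chat_history + messages[1:]` splice in `_ask_llm_async` keeps the system prompt first and the newest user turn last, with history in between — the ordering chat-completion APIs expect. The same list surgery in isolation (values illustrative):

```python
messages = [
    {"content": "<system prompt>", "role": "system"},
    {"content": "And how do I reset it?", "role": "user"},
]
chat_history = [
    {"content": "How do I log in?", "role": "user"},
    {"content": "Use your registered email and password.", "role": "assistant"},
]

# System prompt stays first, history goes in the middle, the new turn stays last.
messages = messages[:1] + chat_history + messages[1:]
assert [m["role"] for m in messages] == ["system", "user", "assistant", "user"]
```

diff --git a/core_backend/app/question_answer/models.py b/core_backend/app/question_answer/models.py
index 1c62fc1ab..22065af5c 100644
--- a/core_backend/app/question_answer/models.py
+++ b/core_backend/app/question_answer/models.py
@@ -52,7 +52,7 @@ class QueryDB(Base):
     user_id: Mapped[int] = mapped_column(
         Integer, ForeignKey("user.user_id"), nullable=False
     )
-    session_id: Mapped[int] = mapped_column(Integer, nullable=True)
+    session_id: Mapped[str] = mapped_column(String, nullable=True)
     feedback_secret_key: Mapped[str] = mapped_column(String, nullable=False)
     query_text: Mapped[str] = mapped_column(String, nullable=False)
     query_generate_llm_response: Mapped[bool] = mapped_column(Boolean, nullable=False)
@@ -166,7 +166,7 @@ class QueryResponseDB(Base):
     user_id: Mapped[int] = mapped_column(
         Integer, ForeignKey("user.user_id"), nullable=False
     )
-    session_id: Mapped[int] = mapped_column(Integer, nullable=True)
+    session_id: Mapped[str] = mapped_column(String, nullable=True)
     search_results: Mapped[JSONDict] = mapped_column(JSON, nullable=False)
     tts_filepath: Mapped[str] = mapped_column(String, nullable=True)
     llm_response: Mapped[str] = mapped_column(String, nullable=True)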
@@ -261,6 +261,53 @@ async def save_query_response_to_db(
     return user_query_responses_db
 
 
+async def check_user_and_session_id(
+    user_id: int, session_id: str, asession: AsyncSession
+) -> bool:
+    """Check if the user_id and session_id are valid."""
+    stmt = select(QueryDB).where(
+        QueryDB.user_id == user_id, QueryDB.session_id == session_id
+    )
+    query_result = await asession.execute(stmt)
+    return query_result.scalar() is not None
+
+
+async def get_session_history(
+    user_id: int, session_id: str | None, asession: AsyncSession
+) -> list[dict[str, str]]:
+    """Get the session history for a given session ID."""
+    if session_id is None:
+        raise ValueError("Session ID cannot be None.")
+
+    query_stmt = (
+        select(QueryDB)
+        .where(QueryDB.user_id == user_id, QueryDB.session_id == session_id)
+        .order_by(QueryDB.query_datetime_utc)
+    )
+    query_results = (await asession.execute(query_stmt)).scalars().all()
+
+    response_stmt = select(QueryResponseDB).where(
+        QueryResponseDB.user_id == user_id, QueryResponseDB.session_id == session_id
+    )
+    response_results = (await asession.execute(response_stmt)).scalars().all()
+    response_dict = {response.query_id: response for response in response_results}
+
+    messages = []
+    for query in query_results:
+        if query.query_id in response_dict:
+            response = response_dict[query.query_id]
+            # Use the original query for history, since that is what the LLM saw;
+            # assumes "original_query" is always present in debug_info.
+            messages.append(
+                {"content": response.debug_info["original_query"], "role": "user"}
+            )
+            if response.llm_response is not None:
+                messages.append(
+                    {"content": response.llm_response, "role": "assistant"}
+                )
+        else:  # This shouldn't happen, but included for completeness.
+            messages.append({"content": query.query_text, "role": "user"})
+
+    return messages
+
+
 class QueryResponseContentDB(Base):
     """
     ORM for storing what content was returned for a given query.
@@ -275,7 +322,7 @@
     user_id: Mapped[int] = mapped_column(
         Integer, ForeignKey("user.user_id"), nullable=False
     )
-    session_id: Mapped[int] = mapped_column(Integer, nullable=True)
+    session_id: Mapped[str] = mapped_column(String, nullable=True)
     query_id: Mapped[int] = mapped_column(
         Integer, ForeignKey("query.query_id"), nullable=False
     )
@@ -312,7 +359,7 @@ def __repr__(self) -> str:
 
 async def save_content_for_query_to_db(
     user_id: int,
-    session_id: int | None,
+    session_id: str | None,
     query_id: int,
     contents: dict[int, QuerySearchResult] | None,
     asession: AsyncSession,
@@ -355,7 +402,7 @@ class ResponseFeedbackDB(Base):
     user_id: Mapped[int] = mapped_column(
         Integer, ForeignKey("user.user_id"), nullable=False
     )
-    session_id: Mapped[int] = mapped_column(Integer, nullable=True)
+    session_id: Mapped[str] = mapped_column(String, nullable=True)
     feedback_text: Mapped[str] = mapped_column(String, nullable=True)
     feedback_datetime_utc: Mapped[datetime] = mapped_column(
         DateTime(timezone=True), nullable=False
@@ -435,7 +482,7 @@ class ContentFeedbackDB(Base):
     user_id: Mapped[int] = mapped_column(
         Integer, ForeignKey("user.user_id"), nullable=False
     )
-    session_id: Mapped[int] = mapped_column(Integer, nullable=True)
+    session_id: Mapped[str] = mapped_column(String, nullable=True)
     feedback_text: Mapped[str] = mapped_column(String, nullable=True)
     feedback_datetime_utc: Mapped[datetime] = mapped_column(
         DateTime(timezone=True), nullable=False
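`get_session_history` stitches the query and response tables back into alternating user/assistant turns, keyed on `query_id` and ordered by query timestamp. For a two-turn session it would return something like the following (values illustrative; note the reliance on `debug_info["original_query"]` being populated for every response row):

```python
# Illustrative return value of get_session_history for a two-turn session:
history = [
    {"content": "How do I log in?", "role": "user"},
    {"content": "Use your registered email and password.", "role": "assistant"},
    {"content": "And how do I reset it?", "role": "user"},
    {"content": "Use the self-service portal.", "role": "assistant"},
]
```

diff --git a/core_backend/app/question_answer/routers.py b/core_backend/app/question_answer/routers.py
index 660243ebd..095206bc8 100644
--- a/core_backend/app/question_answer/routers.py
+++ b/core_backend/app/question_answer/routers.py
@@ -6,7 +6,7 @@
 from io import BytesIO
 from typing import Tuple
 
-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
+from fastapi import APIRouter, Depends, File, UploadFile, status
 from fastapi.responses import JSONResponse
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -27,15 +27,15 @@
 )
 from ..llm_call.process_output import (
     check_align_score__after,
-    generate_llm_response__after,
+    generate_llm_query_response,
     generate_tts__after,
 )
 from ..users.models import UserDB
 from ..utils import (
     create_langfuse_metadata,
     generate_random_filename,
+    generate_secret_key,
     get_file_extension_from_mime_type,
-    get_http_client,
     setup_logger,
     upload_file_to_gcs,
 )
@@ -43,6 +43,7 @@
 from .models import (
     QueryDB,
     check_secret_key_match,
+    get_session_history,
     save_content_feedback_to_db,
     save_content_for_query_to_db,
     save_query_response_to_db,
@@ -59,6 +60,8 @@
     ResponseFeedbackBase,
 )
 from .speech_components.external_voice_components import transcribe_audio
+from .speech_components.utils import post_to_speech
+from .utils import format_session_history_as_query
 
 logger = setup_logger()
 
@@ -77,62 +80,27 @@
 
 
 @router.post(
-    "/voice-search",
-    response_model=QueryAudioResponse,
+    "/search",
+    response_model=QueryResponse,
     responses={
         status.HTTP_400_BAD_REQUEST: {
             "model": QueryResponseError,
-            "description": "Bad Request",
-        },
-        status.HTTP_500_INTERNAL_SERVER_ERROR: {
-            "model": QueryResponseError,
-            "description": "Internal Server Error",
-        },
+            "description": "Guardrail failure",
+        }
     },
 )
-async def voice_search(
-    file: UploadFile = File(...),
+async def search(
+    user_query: QueryBase,
     asession: AsyncSession = Depends(get_async_session),
     user_db: UserDB = Depends(authenticate_key),
-) -> QueryAudioResponse | JSONResponse:
-    """
-    Endpoint to transcribe audio from the provided file,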
-    generate an LLM response, by default generate_tts is
-    set to true and return a public signed URL of an audio
-    file containing the spoken version of the generated response
+) -> QueryResponse | JSONResponse:
     """
-    file_stream = BytesIO(await file.read())
-
-    file_path = f"temp/{file.filename}"
-    with open(file_path, "wb") as f:
-        file_stream.seek(0)
-        f.write(file_stream.read())
-
-    file_stream.seek(0)
-
-    content_type = file.content_type
-
-    file_extension = get_file_extension_from_mime_type(content_type)
-    unique_filename = generate_random_filename(file_extension)
-
-    destination_blob_name = f"stt-voice-notes/{unique_filename}"
-
-    await upload_file_to_gcs(
-        GCS_SPEECH_BUCKET, file_stream, destination_blob_name, content_type
-    )
-
-    if CUSTOM_SPEECH_ENDPOINT is not None:
-        transcription = await post_to_speech(file_path, CUSTOM_SPEECH_ENDPOINT)
-        transcription_result = transcription["text"]
-
-    else:
-        transcription_result = await transcribe_audio(file_path)
+    Search endpoint finds the most similar content to the user query and optionally
+    generates a single-turn LLM response.
 
-    user_query = QueryBase(
-        generate_llm_response=True,
-        query_text=transcription_result,
-        query_metadata={},
-    )
+    If any guardrails fail, the embeddings search is still done and an error 400 is
+    returned that includes the search results as well as the details of the failure.
+    """
 
     (
         user_query_db,
@@ -142,10 +110,9 @@
         user_id=user_db.user_id,
         user_query=user_query,
         asession=asession,
-        generate_tts=True,
+        generate_tts=False,
     )
-
-    response = await search_base(
+    response = await get_search_response(
         query_refined=user_query_refined_template,
         response=response_template,
         user_id=user_db.user_id,
@@ -153,6 +120,13 @@
         asession=asession,
         exclude_archived=True,
     )
+
+    if user_query.generate_llm_response:
+        response = await get_generation_response(
+            query_refined=user_query_refined_template,
+            response=response,
+        )
+
     await save_query_response_to_db(user_query_db, response, asession)
     await increment_query_count(
         user_id=user_db.user_id,
@@ -161,17 +135,13 @@
     )
     await save_content_for_query_to_db(
         user_id=user_db.user_id,
-        query_id=response.query_id,
         session_id=user_query.session_id,
+        query_id=response.query_id,
         contents=response.search_results,
         asession=asession,
     )
 
-    if os.path.exists(file_path):
-        os.remove(file_path)
-    file_stream.close()
-
-    if type(response) is QueryAudioResponse:
+    if type(response) is QueryResponse:
         return response
     elif type(response) is QueryResponseError:
         return JSONResponse(
@@ -180,25 +150,12 @@
     else:
         return JSONResponse(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            content={"error": "Internal server error"},
+            content={"message": "Internal server error"},
         )
 
 
-async def post_to_speech(file_path: str, endpoint_url: str) -> dict:
-    """
-    Post request the file to the speech endpoint to get the transcription
-    """
-    async with get_http_client() as client:
-        async with client.post(endpoint_url, json={"file_path": file_path}) as response:
-            if response.status != 200:
-                error_content = await response.json()
-                logger.error(f"Error from CUSTOM_SPEECH_ENDPOINT: {error_content}")
-                raise HTTPException(status_code=response.status, detail=error_content)
-            return await response.json()
-
-
 @router.post(
-    "/search",
+    "/chat",
     response_model=QueryResponse,
     responses={
         status.HTTP_400_BAD_REQUEST: {
             "model": QueryResponseError,
@@ -207,18 +164,17 @@ async def post_to_speech(file_path: str, endpoint_url: str) -> dict:
         }
     },
 )
-async def search(
+async def chat(
     user_query: QueryBase,
     asession: AsyncSession = Depends(get_async_session),
     user_db: UserDB = Depends(authenticate_key),
 ) -> QueryResponse | JSONResponse:
-    """
-    Search endpoint finds the most similar content to the user query and optionally
-    generates a single-turn LLM response.
+    """
+    Chat endpoint manages a multi-turn conversation: if no session_id is supplied,
+    a new chat session is started, and the stored session history is passed to the
+    LLM when generating the response.
+    """
+    is_new_chat = user_query.session_id is None
 
-    If any guardrails fail, the embeddings search is still done and an error 400 is
-    returned that includes the search results as well as the details of the failure.
-    """
+    if is_new_chat:
+        session_id = generate_secret_key()
+        user_query.session_id = session_id
 
     (
         user_query_db,
@@ -230,21 +186,30 @@
         asession=asession,
         generate_tts=False,
     )
-    response = await search_base(
+
+    response = await get_search_response(
         query_refined=user_query_refined_template,
         response=response_template,
         user_id=user_db.user_id,
         n_similar=int(N_TOP_CONTENT),
         asession=asession,
         exclude_archived=True,
+        is_chat=True,
+    )
+
+    response = await get_generation_response(
+        query_refined=user_query_refined_template,
+        response=response,
     )
 
     await save_query_response_to_db(user_query_db, response, asession)
+
     await increment_query_count(
         user_id=user_db.user_id,
         contents=response.search_results,
         asession=asession,
     )
+
     await save_content_for_query_to_db(
         user_id=user_db.user_id,
         session_id=user_query.session_id,
@@ -266,20 +231,134 @@
     )
 
 
+@router.post(
+    "/voice-search",
+    response_model=QueryAudioResponse,
+    responses={
+        status.HTTP_400_BAD_REQUEST: {
+            "model": QueryResponseError,
+            "description": "Bad Request",
+        },
+        status.HTTP_500_INTERNAL_SERVER_ERROR: {
+            "model": QueryResponseError,
+            "description": "Internal Server Error",
+        },
+    },
+)
+async def voice_search(
+    file: UploadFile = File(...),
+    asession: AsyncSession = Depends(get_async_session),
+    user_db: UserDB = Depends(authenticate_key),
+) -> QueryAudioResponse | JSONResponse:
+    """
+    Endpoint that transcribes audio from the provided file and generates an LLM
+    response. By default, generate_tts is set to True, and the endpoint returns a
+    public signed URL to an audio file containing the spoken version of the
+    generated response.
+    """
+    file_stream = BytesIO(await file.read())
+
+    file_path = f"temp/{file.filename}"
+    with open(file_path, "wb") as f:
+        file_stream.seek(0)
+        f.write(file_stream.read())
+
+    file_stream.seek(0)
+
+    content_type = file.content_type
+
+    file_extension = get_file_extension_from_mime_type(content_type)
+    unique_filename = generate_random_filename(file_extension)
+
+    destination_blob_name = f"stt-voice-notes/{unique_filename}"
+
+    await upload_file_to_gcs(
+        GCS_SPEECH_BUCKET, file_stream, destination_blob_name, content_type
+    )
+
+    if CUSTOM_SPEECH_ENDPOINT is not None:
+        transcription = await post_to_speech(file_path, CUSTOM_SPEECH_ENDPOINT)
+        transcription_result = transcription["text"]
+
+    else:
+        transcription_result = await transcribe_audio(file_path)
+
+    user_query = QueryBase(
+        generate_llm_response=True,
+        query_text=transcription_result,
+        query_metadata={},
+    )
+
+    (
+        user_query_db,
+        user_query_refined_template,
+        response_template,
+    ) = await get_user_query_and_response(
+        user_id=user_db.user_id,
+        user_query=user_query,
+        asession=asession,
+        generate_tts=True,
+    )
+
+    response = await get_search_response(
+        query_refined=user_query_refined_template,
+        response=response_template,
+        user_id=user_db.user_id,
+        n_similar=int(N_TOP_CONTENT),
+        asession=asession,
+        exclude_archived=True,
+        is_chat=False,
+    )
+
+    if user_query.generate_llm_response:
+        response = await get_generation_response(
+            query_refined=user_query_refined_template,
+            response=response,
+        )
+
+    await save_query_response_to_db(user_query_db, response, asession)
+    await increment_query_count(
+        user_id=user_db.user_id,
+        contents=response.search_results,
+        asession=asession,
+    )
+    await save_content_for_query_to_db(
+        user_id=user_db.user_id,
+        query_id=response.query_id,
+        session_id=user_query.session_id,
+        contents=response.search_results,
+        asession=asession,
+    )
+
+    if os.path.exists(file_path):
+        os.remove(file_path)
+    file_stream.close()
+
+    if type(response) is QueryAudioResponse:
+        return response
+    elif type(response) is QueryResponseError:
+        return JSONResponse(
+            status_code=status.HTTP_400_BAD_REQUEST, content=response.model_dump()
+        )
+    else:
+        return JSONResponse(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            content={"error": "Internal server error"},
+        )
+
+
 @identify_language__before
 @classify_safety__before
 @translate_question__before
 @paraphrase_question__before
-@generate_tts__after
-@generate_llm_response__after
-@check_align_score__after
-async def search_base(
+async def get_search_response(
     query_refined: QueryRefined,
     response: QueryResponse,
     user_id: int,
     n_similar: int,
     asession: AsyncSession,
     exclude_archived: bool = True,
+    is_chat: bool = False,
 ) -> QueryResponse | QueryResponseError:
     """Get similar content and construct the LLM answer for the user query.
 
@@ -308,13 +387,26 @@
         An appropriate query response object.
     """
-
-    # always do the embeddings search even if some guardrails have failed
+    # No checks for errors:
+    # always do the embeddings search even if some guardrails have failed.
     metadata = create_langfuse_metadata(query_id=response.query_id, user_id=user_id)
 
+    if is_chat:
+        # May return an empty list if there is no chat history yet.
+        # Ideally we want to skip this entirely for a new chat.
+        messages = await get_session_history(
+            user_id=query_refined.user_id,
+            session_id=query_refined.session_id,
+            asession=asession,
+        )
+        question = format_session_history_as_query(messages)
+        response.chat_history = messages
+    else:
+        # Use the latest transformed version of the text.
+        question = query_refined.query_text
+
     search_results = await get_similar_content_async(
         user_id=user_id,
-        # use latest version of the text
-        question=query_refined.query_text,
+        question=question,
         n_similar=n_similar,
         asession=asession,
         metadata=metadata,
@@ -325,6 +417,34 @@
     return response
 
 
+@generate_tts__after
+@check_align_score__after
+async def get_generation_response(
+    query_refined: QueryRefined,
+    response: QueryResponse,
+) -> QueryResponse | QueryResponseError:
+    """
+    Generate a response using an LLM given a query with search results.
+
+    Only runs if the generate_llm_response flag is set to True.
+    Requires "search_results" and "original_language" in the response.
+
+    If it's a chat query, the chat history is passed to the LLM model.
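Taken together, the router changes give `/chat` the same shape as `/search` plus session threading: the first call omits `session_id`, the server mints one, and the client echoes it back on subsequent turns. A hypothetical client session (the base URL and bearer-token auth are assumptions, not part of this diff):

```python
import requests

BASE_URL = "http://localhost:8000"  # assumed deployment URL
HEADERS = {"Authorization": "Bearer <api-key>"}  # assumed auth scheme

# First turn: no session_id, so the server generates one.
first = requests.post(
    f"{BASE_URL}/chat",
    json={"query_text": "How do I log in?", "generate_llm_response": True},
    headers=HEADERS,
).json()

# Follow-up turn: echo the session_id back so the history is retrieved.
second = requests.post(
    f"{BASE_URL}/chat",
    json={
        "query_text": "And how do I reset it?",
        "generate_llm_response": True,
        "session_id": first["session_id"],
    },
    headers=HEADERS,
).json()
```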
+    """
+    if not query_refined.generate_llm_response:
+        return response
+
+    metadata = create_langfuse_metadata(
+        query_id=response.query_id, user_id=query_refined.user_id
+    )
+
+    response = await generate_llm_query_response(
+        query_refined=query_refined, response=response, metadata=metadata
+    )
+
+    return response
+
+
 async def get_user_query_and_response(
     user_id: int,
     user_query: QueryBase,
diff --git a/core_backend/app/question_answer/schemas.py b/core_backend/app/question_answer/schemas.py
index 128b1792f..7437e705e 100644
--- a/core_backend/app/question_answer/schemas.py
+++ b/core_backend/app/question_answer/schemas.py
@@ -1,7 +1,8 @@
 from enum import Enum
-from typing import Dict, Optional
+from typing import Dict
 
 from pydantic import BaseModel, ConfigDict, Field
+from pydantic.json_schema import SkipJsonSchema
 
 from ..llm_call.llm_prompts import IdentifiedLanguage
 from ..schemas import FeedbackSentiment, QuerySearchResult
@@ -13,9 +14,10 @@ class QueryBase(BaseModel):
     """
 
     query_text: str = Field(..., examples=["What is AAQ?"])
-    session_id: Optional[int] = None
     generate_llm_response: bool = Field(False)
     query_metadata: dict = Field({}, examples=[{"some_key": "some_value"}])
+    # TODO: create SearchQueryBase and ChatQueryBase and set QueryBase as a union
+    session_id: SkipJsonSchema[str | None] = Field(default=None, exclude=True)
 
     model_config = ConfigDict(from_attributes=True)
 
@@ -54,38 +56,32 @@ class QueryResponse(BaseModel):
     """
 
     query_id: int = Field(..., examples=[1])
-    session_id: Optional[int] = None
+    session_id: str | None = Field(None, examples=["session-id-12345-abcde"])
     feedback_secret_key: str = Field(..., examples=["secret-key-12345-abcde"])
     llm_response: str | None = Field(None, examples=["Example LLM response"])
     search_results: Dict[int, QuerySearchResult] | None = Field(
-        None,
         examples=[
             {
-                "query_id": 1,
-                "session_id": 1,
-                "feedback_secret_key": "secret-key-12345-abcde",
-                "llm_response": "Example LLM response "
-                "(null if generate_llm_response is false)",
-                "search_results": {
-                    "0": {
-                        "title": "Example content title",
-                        "text": "Example content text",
-                        "id": 23,
-                        "distance": 0.1,
-                    },
-                    "1": {
-                        "title": "Another example content title",
-                        "text": "Another example content text",
-                        "id": 12,
-                        "distance": 0.2,
-                    },
+                "0": {
+                    "title": "Example content title",
+                    "text": "Example content text",
+                    "id": 23,
+                    "distance": 0.1,
+                },
+                "1": {
+                    "title": "Another example content title",
+                    "text": "Another example content text",
+                    "id": 12,
+                    "distance": 0.2,
                 },
-                "debug_info": {"example": "debug-info"},
             }
         ],
     )
     debug_info: dict = Field({}, examples=[{"example": "debug-info"}])
+    chat_history: list[dict[str, str]] | None = Field(
+        None, examples=[[{"role": "user", "content": "Hello"}]]
+    )
 
     model_config = ConfigDict(from_attributes=True)
 
@@ -110,7 +106,7 @@ class QueryResponseError(QueryResponse):
     """
 
     error_type: ErrorType = Field(..., examples=["example_error"])
-    error_message: Optional[str] = Field(None, examples=["Example error message"])
+    error_message: str | None = Field(None, examples=["Example error message"])
 
     model_config = ConfigDict(from_attributes=True)
 
@@ -122,11 +118,11 @@ class ResponseFeedbackBase(BaseModel):
     """
 
     query_id: int = Field(..., examples=[1])
-    session_id: Optional[int] = None
+    session_id: SkipJsonSchema[str | None] = None
     feedback_sentiment: FeedbackSentiment = Field(
         FeedbackSentiment.UNKNOWN, examples=["positive"]
     )
-    feedback_text: Optional[str] = Field(None, examples=["This is helpful"])
+    feedback_text: str | None = Field(None, examples=["This is helpful"])
     feedback_secret_key: str = Field(..., examples=["secret-key-12345-abcde"])
 
     model_config = ConfigDict(from_attributes=True)
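Per the updated `QueryResponse` schema, a successful `/chat` response now carries the string `session_id` plus, for chat queries, the accumulated `chat_history` (all values illustrative):

```python
# Illustrative /chat response body, following the QueryResponse fields above:
response_body = {
    "query_id": 42,
    "session_id": "session-id-12345-abcde",
    "feedback_secret_key": "secret-key-12345-abcde",
    "llm_response": "Use the self-service portal to reset a password.",
    "search_results": {
        "0": {"title": "Password resets", "text": "...", "id": 23, "distance": 0.1}
    },
    "debug_info": {},
    "chat_history": [
        {"role": "user", "content": "How do I log in?"},
        {"role": "assistant", "content": "Use your registered email and password."},
    ],
}
```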
diff --git a/core_backend/app/question_answer/speech_components/utils.py b/core_backend/app/question_answer/speech_components/utils.py
index 7fdcb1372..2947f22f9 100644
--- a/core_backend/app/question_answer/speech_components/utils.py
+++ b/core_backend/app/question_answer/speech_components/utils.py
@@ -1,9 +1,10 @@
 import os
 
+from fastapi import HTTPException
 from pydub import AudioSegment
 
 from ...llm_call.llm_prompts import IdentifiedLanguage
-from ...utils import setup_logger
+from ...utils import get_http_client, setup_logger
 
 logger = setup_logger("Voice utils")
 
@@ -80,3 +81,16 @@ def set_wav_specifications(wav_filename: str) -> str:
     logger.info(f"Updated file created: {updated_wav_filename}")
 
     return updated_wav_filename
+
+
+async def post_to_speech(file_path: str, endpoint_url: str) -> dict:
+    """
+    POST the file path to the speech endpoint and return the transcription.
+    """
+    async with get_http_client() as client:
+        async with client.post(endpoint_url, json={"file_path": file_path}) as response:
+            if response.status != 200:
+                error_content = await response.json()
+                logger.error(f"Error from CUSTOM_SPEECH_ENDPOINT: {error_content}")
+                raise HTTPException(status_code=response.status, detail=error_content)
+            return await response.json()
diff --git a/core_backend/app/question_answer/utils.py b/core_backend/app/question_answer/utils.py
index a1ab8a666..3d9b7e323 100644
--- a/core_backend/app/question_answer/utils.py
+++ b/core_backend/app/question_answer/utils.py
@@ -16,3 +16,9 @@ def get_context_string_from_search_results(
         context_list.append(f"{key}. {result.title}\n{result.text}")
     context_string = "\n\n".join(context_list)
     return context_string
+
+
+def format_session_history_as_query(messages: list[dict[str, str]]) -> str:
+    """Format the session history as a single query string."""
+    history = "\n".join([message["content"] for message in messages])
+    return history
diff --git a/core_backend/migrations/versions/2024_08_26_1bbcc2584fde_change_session_id_to_string.py b/core_backend/migrations/versions/2024_08_26_1bbcc2584fde_change_session_id_to_string.py
new file mode 100644
index 000000000..e839e56a4
--- /dev/null
+++ b/core_backend/migrations/versions/2024_08_26_1bbcc2584fde_change_session_id_to_string.py
@@ -0,0 +1,41 @@
+"""change session_id to string
+
+Revision ID: 1bbcc2584fde
+Revises: c571cf9aae63
+Create Date: 2024-08-26 19:40:08.259316
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "1bbcc2584fde"  # pragma: allowlist secret
+down_revision: Union[str, None] = "c571cf9aae63"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+tables_with_session_id = [
+    "query",
+    "query_response",
+    "query_response_feedback",
+    "query_response_content",
+    "content_feedback",
+]
+
+
+def upgrade() -> None:
+    for table in tables_with_session_id:
+        op.alter_column(
+            table, "session_id", existing_type=sa.Integer(), type_=sa.String()
+        )
+
+
+def downgrade() -> None:
+    for table in tables_with_session_id:
+        # TODO: if the session_id can't be cast, set it to null first.
+        op.alter_column(
+            table, "session_id", existing_type=sa.String(), type_=sa.Integer()
+        )
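The `downgrade()` TODO above is a real hazard on PostgreSQL: once session IDs are arbitrary strings, `ALTER COLUMN ... TYPE INTEGER` has no implicit cast and will fail. One possible treatment (a sketch only, not part of this PR, and PostgreSQL-specific) is to null out non-numeric values and cast explicitly:

```python
def downgrade() -> None:
    for table in tables_with_session_id:
        # Sketch: drop values that cannot be cast back to integers, then cast
        # explicitly via USING (PostgreSQL-specific).
        op.execute(
            f"UPDATE {table} SET session_id = NULL WHERE session_id !~ '^[0-9]+$'"
        )
        op.alter_column(
            table,
            "session_id",
            existing_type=sa.String(),
            type_=sa.Integer(),
            postgresql_using="session_id::integer",
        )
```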
diff --git a/core_backend/tests/api/test_dashboard_overview.py b/core_backend/tests/api/test_dashboard_overview.py
index 54e85d5de..45773ed3d 100644
--- a/core_backend/tests/api/test_dashboard_overview.py
+++ b/core_backend/tests/api/test_dashboard_overview.py
@@ -287,7 +287,6 @@ async def queries(
         for i in range(count):
             query = QueryDB(
                 user_id=1,
-                session_id=1,
                 feedback_secret_key="abc123",
                 query_text=f"test_{day}_{i}",
                 query_generate_llm_response=False,
@@ -321,7 +320,6 @@ async def queries_hour(self, asession: AsyncSession) -> AsyncGenerator[None, Non
         for i in range(count):
             query = QueryDB(
                 user_id=1,
-                session_id=1,
                 feedback_secret_key="abc123",
                 query_text=f"test_{hour}_{i}",
                 query_generate_llm_response=False,
@@ -512,7 +510,6 @@ async def create_query_and_feedback(
         for i in range(n_positive + n_negative + n_neutral):
             query = QueryDB(
                 user_id=user_id,
-                session_id=1,
                 feedback_secret_key="abc123",
                 query_text="test message",
                 query_generate_llm_response=False,