Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanup dependent default resolution #465

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
introduce engine for API (#434)
pmeier authored Jun 26, 2024
commit 4d37e33859f56c0a7b4c1b8e9bcae9b1d5eb2867
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -162,7 +162,7 @@ ignore_missing_imports = true

[[tool.mypy.overrides]]
module = [
"ragna.deploy._api.orm",
"ragna.deploy._orm",
]
# Our ORM schema doesn't really work with mypy. There are some other ways to define it
# to play ball. We should do that in the future.
12 changes: 12 additions & 0 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import abc
import datetime
import enum
import functools
import inspect
import uuid
from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union

import pydantic
@@ -157,6 +159,8 @@ def __init__(
*,
role: MessageRole = MessageRole.SYSTEM,
sources: Optional[list[Source]] = None,
id: Optional[uuid.UUID] = None,
timestamp: Optional[datetime.datetime] = None,
) -> None:
if isinstance(content, str):
self._content: str = content
@@ -166,6 +170,14 @@ def __init__(
self.role = role
self.sources = sources or []

if id is None:
id = uuid.uuid4()
self.id = id

if timestamp is None:
timestamp = datetime.datetime.utcnow()
self.timestamp = timestamp

async def __aiter__(self) -> AsyncIterator[str]:
if hasattr(self, "_content"):
yield self._content
87 changes: 70 additions & 17 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
import inspect
import uuid
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
@@ -12,21 +13,24 @@
Iterable,
Iterator,
Optional,
Type,
TypeVar,
Union,
cast,
)

import pydantic
from fastapi import status
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool

from ._components import Assistant, Component, Message, MessageRole, SourceStorage
from ._document import Document, LocalDocument
from ._utils import RagnaException, default_user, merge_models

if TYPE_CHECKING:
from ragna.deploy import Config

T = TypeVar("T")
C = TypeVar("C", bound=Component)
C = TypeVar("C", bound=Component, covariant=True)


class Rag(Generic[C]):
@@ -41,20 +45,69 @@ class Rag(Generic[C]):
```
"""

def __init__(self) -> None:
self._components: dict[Type[C], C] = {}
def __init__(
self,
*,
config: Optional[Config] = None,
ignore_unavailable_components: bool = False,
) -> None:
self._components: dict[type[C], C] = {}
self._display_name_map: dict[str, type[C]] = {}

if config is not None:
self._preload_components(
config=config,
ignore_unavailable_components=ignore_unavailable_components,
)

def _preload_components(
self, *, config: Config, ignore_unavailable_components: bool
) -> None:
for components in [config.source_storages, config.assistants]:
components = cast(list[type[Component]], components)
at_least_one = False
for component in components:
loaded_component = self._load_component(
component, # type: ignore[arg-type]
ignore_unavailable=ignore_unavailable_components,
)
if loaded_component is None:
print(
f"Ignoring {component.display_name()}, because it is not available."
)
else:
at_least_one = True

if not at_least_one:
raise RagnaException(
"No component available",
components=[component.display_name() for component in components],
)

def _load_component(
self, component: Union[Type[C], C], *, ignore_unavailable: bool = False
self, component: Union[C, type[C], str], *, ignore_unavailable: bool = False
) -> Optional[C]:
cls: Type[C]
cls: type[C]
instance: Optional[C]

if isinstance(component, Component):
instance = cast(C, component)
cls = type(instance)
elif isinstance(component, type) and issubclass(component, Component):
cls = component
instance = None
elif isinstance(component, str):
try:
cls = self._display_name_map[component]
except KeyError:
raise RagnaException(
"Unknown component",
display_name=component,
help="Did you forget to create the Rag() instance with a config?",
http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
http_detail=f"Unknown component '{component}'",
) from None

instance = None
else:
raise RagnaException
@@ -71,31 +124,33 @@ def _load_component(
instance = cls()

self._components[cls] = instance
self._display_name_map[cls.display_name()] = cls

return self._components[cls]

def chat(
self,
*,
documents: Iterable[Any],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
source_storage: Union[SourceStorage, type[SourceStorage], str],
assistant: Union[Assistant, type[Assistant], str],
**params: Any,
) -> Chat:
"""Create a new [ragna.core.Chat][].

Args:
documents: Documents to use. If any item is not a [ragna.core.Document][],
[ragna.core.LocalDocument.from_path][] is invoked on it.
FIXME
source_storage: Source storage to use.
assistant: Assistant to use.
**params: Additional parameters passed to the source storage and assistant.
"""
return Chat(
self,
documents=documents,
source_storage=source_storage,
assistant=assistant,
source_storage=cast(SourceStorage, self._load_component(source_storage)), # type: ignore[arg-type]
assistant=cast(Assistant, self._load_component(assistant)), # type: ignore[arg-type]
**params,
)

@@ -146,17 +201,15 @@ def __init__(
rag: Rag,
*,
documents: Iterable[Any],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
source_storage: SourceStorage,
assistant: Assistant,
**params: Any,
) -> None:
self._rag = rag

self.documents = self._parse_documents(documents)
self.source_storage = cast(
SourceStorage, self._rag._load_component(source_storage)
)
self.assistant = cast(Assistant, self._rag._load_component(assistant))
self.source_storage = source_storage
self.assistant = assistant

special_params = SpecialChatParams().model_dump()
special_params.update(params)
@@ -306,6 +359,6 @@ async def __aenter__(self) -> Chat:
return self

async def __aexit__(
self, exc_type: Type[Exception], exc: Exception, traceback: str
self, exc_type: type[Exception], exc: Exception, traceback: str
) -> None:
pass
163 changes: 163 additions & 0 deletions ragna/deploy/_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import uuid
from typing import Annotated, AsyncIterator, cast

import aiofiles
import pydantic
from fastapi import (
APIRouter,
Body,
Depends,
Form,
HTTPException,
UploadFile,
)
from fastapi.responses import StreamingResponse

import ragna
import ragna.core
from ragna._compat import anext
from ragna.core._utils import default_user
from ragna.deploy import Config

from . import _schemas as schemas
from ._engine import Engine


def make_router(config: Config, engine: Engine) -> APIRouter:
router = APIRouter(tags=["API"])

def get_user() -> str:
return default_user()

UserDependency = Annotated[str, Depends(get_user)]

# TODO: the document endpoints do not go through the engine, because they'll change
# quite drastically when the UI no longer depends on the API

_database = engine._database

@router.post("/document")
async def create_document_upload_info(
user: UserDependency,
name: Annotated[str, Body(..., embed=True)],
) -> schemas.DocumentUpload:
with _database.get_session() as session:
document = schemas.Document(name=name)
metadata, parameters = await config.document.get_upload_info(
config=config, user=user, id=document.id, name=document.name
)
document.metadata = metadata
_database.add_document(
session, user=user, document=document, metadata=metadata
)
return schemas.DocumentUpload(parameters=parameters, document=document)

# TODO: Add UI support and documentation for this endpoint (#406)
@router.post("/documents")
async def create_documents_upload_info(
user: UserDependency,
names: Annotated[list[str], Body(..., embed=True)],
) -> list[schemas.DocumentUpload]:
with _database.get_session() as session:
document_metadata_collection = []
document_upload_collection = []
for name in names:
document = schemas.Document(name=name)
metadata, parameters = await config.document.get_upload_info(
config=config, user=user, id=document.id, name=document.name
)
document.metadata = metadata
document_metadata_collection.append((document, metadata))
document_upload_collection.append(
schemas.DocumentUpload(parameters=parameters, document=document)
)

_database.add_documents(
session,
user=user,
document_metadata_collection=document_metadata_collection,
)
return document_upload_collection

# TODO: Add new endpoint for batch uploading documents (#407)
@router.put("/document")
async def upload_document(
token: Annotated[str, Form()], file: UploadFile
) -> schemas.Document:
if not issubclass(config.document, ragna.core.LocalDocument):
raise HTTPException(
status_code=400,
detail="Ragna configuration does not support local upload",
)
with _database.get_session() as session:
user, id = ragna.core.LocalDocument.decode_upload_token(token)
document = _database.get_document(session, user=user, id=id)

core_document = cast(
ragna.core.LocalDocument, engine._to_core.document(document)
)
core_document.path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(core_document.path, "wb") as document_file:
while content := await file.read(1024):
await document_file.write(content)

return document

@router.get("/components")
def get_components(_: UserDependency) -> schemas.Components:
return engine.get_components()

@router.post("/chats")
async def create_chat(
user: UserDependency,
chat_metadata: schemas.ChatMetadata,
) -> schemas.Chat:
return engine.create_chat(user=user, chat_metadata=chat_metadata)

@router.get("/chats")
async def get_chats(user: UserDependency) -> list[schemas.Chat]:
return engine.get_chats(user=user)

@router.get("/chats/{id}")
async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat:
return engine.get_chat(user=user, id=id)

@router.post("/chats/{id}/prepare")
async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message:
return await engine.prepare_chat(user=user, id=id)

@router.post("/chats/{id}/answer")
async def answer(
user: UserDependency,
id: uuid.UUID,
prompt: Annotated[str, Body(..., embed=True)],
stream: Annotated[bool, Body(..., embed=True)] = False,
) -> schemas.Message:
message_stream = engine.answer_stream(user=user, chat_id=id, prompt=prompt)
answer = await anext(message_stream)

if not stream:
content_chunks = [chunk.content async for chunk in message_stream]
answer.content += "".join(content_chunks)
return answer

async def message_chunks() -> AsyncIterator[schemas.Message]:
yield answer
async for chunk in message_stream:
yield chunk

async def to_jsonl(
models: AsyncIterator[pydantic.BaseModel],
) -> AsyncIterator[str]:
async for model in models:
yield f"{model.model_dump_json()}\n"

return StreamingResponse( # type: ignore[return-value]
to_jsonl(message_chunks())
)

@router.delete("/chats/{id}")
async def delete_chat(user: UserDependency, id: uuid.UUID) -> None:
engine.delete_chat(user=user, id=id)

return router
1 change: 0 additions & 1 deletion ragna/deploy/_api/__init__.py

This file was deleted.

318 changes: 0 additions & 318 deletions ragna/deploy/_api/core.py

This file was deleted.

270 changes: 0 additions & 270 deletions ragna/deploy/_api/database.py

This file was deleted.

14 changes: 7 additions & 7 deletions ragna/deploy/_core.py
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@

from ._api import make_router as make_api_router
from ._config import Config
from ._engine import Engine
from ._ui import app as make_ui_app
from ._utils import handle_localhost_origins, redirect, set_redirect_root_path

@@ -70,14 +71,13 @@ def server_available():
allow_headers=["*"],
)

engine = Engine(
config=config,
ignore_unavailable_components=ignore_unavailable_components,
)

if api:
app.include_router(
make_api_router(
config,
ignore_unavailable_components=ignore_unavailable_components,
),
prefix="/api",
)
app.include_router(make_api_router(config, engine), prefix="/api")

if ui:
panel_app = make_ui_app(config=config)
274 changes: 274 additions & 0 deletions ragna/deploy/_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
from __future__ import annotations

import uuid
from typing import Any, Optional
from urllib.parse import urlsplit

from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session, joinedload, sessionmaker

from ragna.core import RagnaException

from . import _orm as orm
from . import _schemas as schemas


class Database:
def __init__(self, url: str) -> None:
components = urlsplit(url)
if components.scheme == "sqlite":
connect_args = dict(check_same_thread=False)
else:
connect_args = dict()
engine = create_engine(url, connect_args=connect_args)
orm.Base.metadata.create_all(bind=engine)

self.get_session = sessionmaker(autocommit=False, autoflush=False, bind=engine)

self._to_orm = SchemaToOrmConverter()
self._to_schema = OrmToSchemaConverter()

def _get_user(self, session: Session, *, username: str) -> orm.User:
user: Optional[orm.User] = session.execute(
select(orm.User).where(orm.User.name == username)
).scalar_one_or_none()

if user is None:
# Add a new user if the current username is not registered yet. Since this
# is behind the authentication layer, we don't need any extra security here.
user = orm.User(id=uuid.uuid4(), name=username)
session.add(user)
session.commit()

return user

def add_document(
self,
session: Session,
*,
user: str,
document: schemas.Document,
metadata: dict[str, Any],
) -> None:
session.add(
orm.Document(
id=document.id,
user_id=self._get_user(session, username=user).id,
name=document.name,
metadata_=metadata,
)
)
session.commit()

def add_documents(
self,
session: Session,
*,
user: str,
document_metadata_collection: list[tuple[schemas.Document, dict[str, Any]]],
) -> None:
"""
Add multiple documents to the database.
This function allows adding multiple documents at once by calling `add_all`. This is
important when there is non-negligible latency attached to each database operation.
"""
documents = [
orm.Document(
id=document.id,
user_id=self._get_user(session, username=user).id,
name=document.name,
metadata_=metadata,
)
for document, metadata in document_metadata_collection
]
session.add_all(documents)
session.commit()

def get_document(
self, session: Session, *, user: str, id: uuid.UUID
) -> schemas.Document:
document = session.execute(
select(orm.Document).where(
(orm.Document.user_id == self._get_user(session, username=user).id)
& (orm.Document.id == id)
)
).scalar_one_or_none()
return self._to_schema.document(document)

def add_chat(self, session: Session, *, user: str, chat: schemas.Chat) -> None:
document_ids = {document.id for document in chat.metadata.documents}
# FIXME also check if the user is allowed to access the documents?
documents = (
session.execute(
select(orm.Document).where(orm.Document.id.in_(document_ids))
)
.scalars()
.all()
)
if len(documents) != len(document_ids):
raise RagnaException(
str(document_ids - {document.id for document in documents})
)

orm_chat = self._to_orm.chat(
chat,
user_id=self._get_user(session, username=user).id,
# We have to pass the documents here, because SQLAlchemy does not allow a
# second instance of orm.Document with the same primary key in the session.
documents=documents,
)
session.add(orm_chat)
session.commit()

def _select_chat(self, *, eager: bool = False) -> Any:
selector = select(orm.Chat)
if eager:
selector = selector.options( # type: ignore[attr-defined]
joinedload(orm.Chat.messages).joinedload(orm.Message.sources),
joinedload(orm.Chat.documents),
)
return selector

def get_chats(self, session: Session, *, user: str) -> list[schemas.Chat]:
return [
self._to_schema.chat(chat)
for chat in session.execute(
self._select_chat(eager=True).where(
orm.Chat.user_id == self._get_user(session, username=user).id
)
)
.scalars()
.unique()
.all()
]

def _get_orm_chat(
self, session: Session, *, user: str, id: uuid.UUID, eager: bool = False
) -> orm.Chat:
chat: Optional[orm.Chat] = (
session.execute(
self._select_chat(eager=eager).where(
(orm.Chat.id == id)
& (orm.Chat.user_id == self._get_user(session, username=user).id)
)
)
.unique()
.scalar_one_or_none()
)
if chat is None:
raise RagnaException()
return chat

def get_chat(self, session: Session, *, user: str, id: uuid.UUID) -> schemas.Chat:
return self._to_schema.chat(
(self._get_orm_chat(session, user=user, id=id, eager=True))
)

def update_chat(self, session: Session, user: str, chat: schemas.Chat) -> None:
orm_chat = self._to_orm.chat(
chat, user_id=self._get_user(session, username=user).id
)
session.merge(orm_chat)
session.commit()

def delete_chat(self, session: Session, user: str, id: uuid.UUID) -> None:
orm_chat = self._get_orm_chat(session, user=user, id=id)
session.delete(orm_chat) # type: ignore[no-untyped-call]
session.commit()


class SchemaToOrmConverter:
def document(
self, document: schemas.Document, *, user_id: uuid.UUID
) -> orm.Document:
return orm.Document(
id=document.id,
user_id=user_id,
name=document.name,
metadata_=document.metadata,
)

def source(self, source: schemas.Source) -> orm.Source:
return orm.Source(
id=source.id,
document_id=source.document.id,
location=source.location,
content=source.content,
num_tokens=source.num_tokens,
)

def message(self, message: schemas.Message, *, chat_id: uuid.UUID) -> orm.Message:
return orm.Message(
id=message.id,
chat_id=chat_id,
content=message.content,
role=message.role,
sources=[self.source(source) for source in message.sources],
timestamp=message.timestamp,
)

def chat(
self,
chat: schemas.Chat,
*,
user_id: uuid.UUID,
documents: Optional[list[orm.Document]] = None,
) -> orm.Chat:
if documents is None:
documents = [
self.document(document, user_id=user_id)
for document in chat.metadata.documents
]
return orm.Chat(
id=chat.id,
user_id=user_id,
name=chat.metadata.name,
documents=documents,
source_storage=chat.metadata.source_storage,
assistant=chat.metadata.assistant,
params=chat.metadata.params,
messages=[
self.message(message, chat_id=chat.id) for message in chat.messages
],
prepared=chat.prepared,
)


class OrmToSchemaConverter:
def document(self, document: orm.Document) -> schemas.Document:
return schemas.Document(
id=document.id, name=document.name, metadata=document.metadata_
)

def source(self, source: orm.Source) -> schemas.Source:
return schemas.Source(
id=source.id,
document=self.document(source.document),
location=source.location,
content=source.content,
num_tokens=source.num_tokens,
)

def message(self, message: orm.Message) -> schemas.Message:
return schemas.Message(
id=message.id,
role=message.role, # type: ignore[arg-type]
content=message.content,
sources=[self.source(source) for source in message.sources],
timestamp=message.timestamp,
)

def chat(self, chat: orm.Chat) -> schemas.Chat:
return schemas.Chat(
id=chat.id,
metadata=schemas.ChatMetadata(
name=chat.name,
documents=[self.document(document) for document in chat.documents],
source_storage=chat.source_storage,
assistant=chat.assistant,
params=chat.params,
),
messages=[self.message(message) for message in chat.messages],
prepared=chat.prepared,
)
205 changes: 205 additions & 0 deletions ragna/deploy/_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import uuid
from typing import Any, AsyncIterator, Optional, Type

from ragna import Rag, core
from ragna._compat import aiter, anext
from ragna.core._rag import SpecialChatParams
from ragna.deploy import Config

from . import _schemas as schemas
from ._database import Database


class Engine:
def __init__(self, *, config: Config, ignore_unavailable_components: bool) -> None:
self._config = config

self._database = Database(url=config.database_url)

self._rag: Rag = Rag(
config=config,
ignore_unavailable_components=ignore_unavailable_components,
)

self._to_core = SchemaToCoreConverter(config=config, rag=self._rag)
self._to_schema = CoreToSchemaConverter()

def _get_component_json_schema(
self,
component: Type[core.Component],
) -> dict[str, dict[str, Any]]:
json_schema = component._protocol_model().model_json_schema()
# FIXME: there is likely a better way to exclude certain fields builtin in
# pydantic
for special_param in SpecialChatParams.model_fields:
if (
"properties" in json_schema
and special_param in json_schema["properties"]
):
del json_schema["properties"][special_param]
if "required" in json_schema and special_param in json_schema["required"]:
json_schema["required"].remove(special_param)
return json_schema

def get_components(self) -> schemas.Components:
return schemas.Components(
documents=sorted(self._config.document.supported_suffixes()),
source_storages=[
self._get_component_json_schema(source_storage)
for source_storage in self._rag._components.keys()
if issubclass(source_storage, core.SourceStorage)
],
assistants=[
self._get_component_json_schema(assistant)
for assistant in self._rag._components.keys()
if issubclass(assistant, core.Assistant)
],
)

def create_chat(
self, *, user: str, chat_metadata: schemas.ChatMetadata
) -> schemas.Chat:
chat = schemas.Chat(metadata=chat_metadata)

# Although we don't need the actual core.Chat here, this just performs the input
# validation.
self._to_core.chat(chat, user=user)

with self._database.get_session() as session:
self._database.add_chat(session, user=user, chat=chat)

return chat

def get_chats(self, *, user: str) -> list[schemas.Chat]:
with self._database.get_session() as session:
return self._database.get_chats(session, user=user)

def get_chat(self, *, user: str, id: uuid.UUID) -> schemas.Chat:
with self._database.get_session() as session:
return self._database.get_chat(session, user=user, id=id)

async def prepare_chat(self, *, user: str, id: uuid.UUID) -> schemas.Message:
core_chat = self._to_core.chat(self.get_chat(user=user, id=id), user=user)
core_message = await core_chat.prepare()

with self._database.get_session() as session:
self._database.update_chat(
session, chat=self._to_schema.chat(core_chat), user=user
)

return self._to_schema.message(core_message)

async def answer_stream(
self, *, user: str, chat_id: uuid.UUID, prompt: str
) -> AsyncIterator[schemas.Message]:
core_chat = self._to_core.chat(self.get_chat(user=user, id=chat_id), user=user)
core_message = await core_chat.answer(prompt, stream=True)

content_stream = aiter(core_message)
content_chunk = await anext(content_stream)
message = self._to_schema.message(core_message, content_override=content_chunk)
yield message

# Avoid sending the sources multiple times
message_chunk = message.model_copy(update=dict(sources=None))
async for content_chunk in content_stream:
message_chunk.content = content_chunk
yield message_chunk

with self._database.get_session() as session:
self._database.update_chat(
session, chat=self._to_schema.chat(core_chat), user=user
)

def delete_chat(self, *, user: str, id: uuid.UUID) -> None:
with self._database.get_session() as session:
self._database.delete_chat(session, user=user, id=id)


class SchemaToCoreConverter:
def __init__(self, *, config: Config, rag: Rag) -> None:
self._config = config
self._rag = rag

def document(self, document: schemas.Document) -> core.Document:
return self._config.document(
id=document.id,
name=document.name,
metadata=document.metadata,
)

def source(self, source: schemas.Source) -> core.Source:
return core.Source(
id=source.id,
document=self.document(source.document),
location=source.location,
content=source.content,
num_tokens=source.num_tokens,
)

def message(self, message: schemas.Message) -> core.Message:
return core.Message(
message.content,
role=message.role,
sources=[self.source(source) for source in message.sources],
)

def chat(self, chat: schemas.Chat, *, user: str) -> core.Chat:
core_chat = self._rag.chat(
documents=[self.document(document) for document in chat.metadata.documents],
source_storage=chat.metadata.source_storage,
assistant=chat.metadata.assistant,
user=user,
chat_id=chat.id,
chat_name=chat.metadata.name,
**chat.metadata.params,
)
core_chat._messages = [self.message(message) for message in chat.messages]
core_chat._prepared = chat.prepared

return core_chat


class CoreToSchemaConverter:
def document(self, document: core.Document) -> schemas.Document:
return schemas.Document(
id=document.id,
name=document.name,
metadata=document.metadata,
)

def source(self, source: core.Source) -> schemas.Source:
return schemas.Source(
id=source.id,
document=self.document(source.document),
location=source.location,
content=source.content,
num_tokens=source.num_tokens,
)

def message(
self, message: core.Message, *, content_override: Optional[str] = None
) -> schemas.Message:
return schemas.Message(
id=message.id,
content=content_override or message.content,
role=message.role,
sources=[self.source(source) for source in message.sources],
timestamp=message.timestamp,
)

def chat(self, chat: core.Chat) -> schemas.Chat:
params = chat.params.copy()
del params["user"]
return schemas.Chat(
id=params.pop("chat_id"),
metadata=schemas.ChatMetadata(
name=params.pop("chat_name"),
source_storage=chat.source_storage.display_name(),
assistant=chat.assistant.display_name(),
params=params,
documents=[self.document(document) for document in chat.documents],
),
messages=[self.message(message) for message in chat._messages],
prepared=chat._prepared,
)
File renamed without changes.
26 changes: 1 addition & 25 deletions ragna/deploy/_api/schemas.py → ragna/deploy/_schemas.py
Original file line number Diff line number Diff line change
@@ -18,13 +18,7 @@ class Components(BaseModel):
class Document(BaseModel):
id: uuid.UUID = Field(default_factory=uuid.uuid4)
name: str

@classmethod
def from_core(cls, document: ragna.core.Document) -> Document:
return cls(
id=document.id,
name=document.name,
)
metadata: dict[str, Any] = Field(default_factory=dict)


class DocumentUpload(BaseModel):
@@ -40,16 +34,6 @@ class Source(BaseModel):
content: str
num_tokens: int

@classmethod
def from_core(cls, source: ragna.core.Source) -> Source:
return cls(
id=source.id,
document=Document.from_core(source.document),
location=source.location,
content=source.content,
num_tokens=source.num_tokens,
)


class Message(BaseModel):
id: uuid.UUID = Field(default_factory=uuid.uuid4)
@@ -58,14 +42,6 @@ class Message(BaseModel):
sources: list[Source] = Field(default_factory=list)
timestamp: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)

@classmethod
def from_core(cls, message: ragna.core.Message) -> Message:
return cls(
content=message.content,
role=message.role,
sources=[Source.from_core(source) for source in message.sources],
)


class ChatMetadata(BaseModel):
name: str
41 changes: 41 additions & 0 deletions tests/deploy/api/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import contextlib
import json

import httpx
import pytest


@pytest.fixture(scope="package", autouse=True)
def enhance_raise_for_status(package_mocker):
raise_for_status = httpx.Response.raise_for_status

def enhanced_raise_for_status(self):
__tracebackhide__ = True

try:
return raise_for_status(self)
except httpx.HTTPStatusError as error:
content = None
with contextlib.suppress(Exception):
content = error.response.read()
content = content.decode()
content = "\n" + json.dumps(json.loads(content), indent=2)

if content is None:
raise error

message = f"{error}\nResponse content: {content}"
raise httpx.HTTPStatusError(
message, request=error.request, response=error.response
) from None

yield package_mocker.patch(
".".join(
[
httpx.Response.__module__,
httpx.Response.__name__,
raise_for_status.__name__,
]
),
new=enhanced_raise_for_status,
)