Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanup dependent default resolution #465

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
introduce engine for API (#434)
pmeier authored Jun 26, 2024
commit 4d37e33859f56c0a7b4c1b8e9bcae9b1d5eb2867
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -162,7 +162,7 @@ ignore_missing_imports = true

[[tool.mypy.overrides]]
module = [
"ragna.deploy._api.orm",
"ragna.deploy._orm",
]
# Our ORM schema doesn't really work with mypy. There are some other ways to define it
# to play ball. We should do that in the future.
12 changes: 12 additions & 0 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import abc
import datetime
import enum
import functools
import inspect
import uuid
from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union

import pydantic
@@ -157,6 +159,8 @@ def __init__(
*,
role: MessageRole = MessageRole.SYSTEM,
sources: Optional[list[Source]] = None,
id: Optional[uuid.UUID] = None,
timestamp: Optional[datetime.datetime] = None,
) -> None:
if isinstance(content, str):
self._content: str = content
@@ -166,6 +170,14 @@ def __init__(
self.role = role
self.sources = sources or []

if id is None:
id = uuid.uuid4()
self.id = id

if timestamp is None:
timestamp = datetime.datetime.utcnow()
self.timestamp = timestamp

async def __aiter__(self) -> AsyncIterator[str]:
if hasattr(self, "_content"):
yield self._content
87 changes: 70 additions & 17 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
import inspect
import uuid
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
@@ -12,21 +13,24 @@
Iterable,
Iterator,
Optional,
Type,
TypeVar,
Union,
cast,
)

import pydantic
from fastapi import status
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool

from ._components import Assistant, Component, Message, MessageRole, SourceStorage
from ._document import Document, LocalDocument
from ._utils import RagnaException, default_user, merge_models

if TYPE_CHECKING:
from ragna.deploy import Config

T = TypeVar("T")
C = TypeVar("C", bound=Component)
C = TypeVar("C", bound=Component, covariant=True)


class Rag(Generic[C]):
@@ -41,20 +45,69 @@ class Rag(Generic[C]):
```
"""

def __init__(self) -> None:
self._components: dict[Type[C], C] = {}
def __init__(
self,
*,
config: Optional[Config] = None,
ignore_unavailable_components: bool = False,
) -> None:
self._components: dict[type[C], C] = {}
self._display_name_map: dict[str, type[C]] = {}

if config is not None:
self._preload_components(
config=config,
ignore_unavailable_components=ignore_unavailable_components,
)

def _preload_components(
self, *, config: Config, ignore_unavailable_components: bool
) -> None:
for components in [config.source_storages, config.assistants]:
components = cast(list[type[Component]], components)
at_least_one = False
for component in components:
loaded_component = self._load_component(
component, # type: ignore[arg-type]
ignore_unavailable=ignore_unavailable_components,
)
if loaded_component is None:
print(
f"Ignoring {component.display_name()}, because it is not available."
)
else:
at_least_one = True

if not at_least_one:
raise RagnaException(
"No component available",
components=[component.display_name() for component in components],
)

def _load_component(
self, component: Union[Type[C], C], *, ignore_unavailable: bool = False
self, component: Union[C, type[C], str], *, ignore_unavailable: bool = False
) -> Optional[C]:
cls: Type[C]
cls: type[C]
instance: Optional[C]

if isinstance(component, Component):
instance = cast(C, component)
cls = type(instance)
elif isinstance(component, type) and issubclass(component, Component):
cls = component
instance = None
elif isinstance(component, str):
try:
cls = self._display_name_map[component]
except KeyError:
raise RagnaException(
"Unknown component",
display_name=component,
help="Did you forget to create the Rag() instance with a config?",
http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
http_detail=f"Unknown component '{component}'",
) from None

instance = None
else:
raise RagnaException
@@ -71,31 +124,33 @@ def _load_component(
instance = cls()

self._components[cls] = instance
self._display_name_map[cls.display_name()] = cls

return self._components[cls]

def chat(
self,
*,
documents: Iterable[Any],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
source_storage: Union[SourceStorage, type[SourceStorage], str],
assistant: Union[Assistant, type[Assistant], str],
**params: Any,
) -> Chat:
"""Create a new [ragna.core.Chat][].

Args:
documents: Documents to use. If any item is not a [ragna.core.Document][],
[ragna.core.LocalDocument.from_path][] is invoked on it.
FIXME
source_storage: Source storage to use.
assistant: Assistant to use.
**params: Additional parameters passed to the source storage and assistant.
"""
return Chat(
self,
documents=documents,
source_storage=source_storage,
assistant=assistant,
source_storage=cast(SourceStorage, self._load_component(source_storage)), # type: ignore[arg-type]
assistant=cast(Assistant, self._load_component(assistant)), # type: ignore[arg-type]
**params,
)

@@ -146,17 +201,15 @@ def __init__(
rag: Rag,
*,
documents: Iterable[Any],
source_storage: Union[Type[SourceStorage], SourceStorage],
assistant: Union[Type[Assistant], Assistant],
source_storage: SourceStorage,
assistant: Assistant,
**params: Any,
) -> None:
self._rag = rag

self.documents = self._parse_documents(documents)
self.source_storage = cast(
SourceStorage, self._rag._load_component(source_storage)
)
self.assistant = cast(Assistant, self._rag._load_component(assistant))
self.source_storage = source_storage
self.assistant = assistant

special_params = SpecialChatParams().model_dump()
special_params.update(params)
@@ -306,6 +359,6 @@ async def __aenter__(self) -> Chat:
return self

async def __aexit__(
self, exc_type: Type[Exception], exc: Exception, traceback: str
self, exc_type: type[Exception], exc: Exception, traceback: str
) -> None:
pass
163 changes: 163 additions & 0 deletions ragna/deploy/_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import uuid
from typing import Annotated, AsyncIterator, cast

import aiofiles
import pydantic
from fastapi import (
APIRouter,
Body,
Depends,
Form,
HTTPException,
UploadFile,
)
from fastapi.responses import StreamingResponse

import ragna
import ragna.core
from ragna._compat import anext
from ragna.core._utils import default_user
from ragna.deploy import Config

from . import _schemas as schemas
from ._engine import Engine


def make_router(config: Config, engine: Engine) -> APIRouter:
router = APIRouter(tags=["API"])

def get_user() -> str:
return default_user()

UserDependency = Annotated[str, Depends(get_user)]

# TODO: the document endpoints do not go through the engine, because they'll change
# quite drastically when the UI no longer depends on the API

_database = engine._database

@router.post("/document")
async def create_document_upload_info(
user: UserDependency,
name: Annotated[str, Body(..., embed=True)],
) -> schemas.DocumentUpload:
with _database.get_session() as session:
document = schemas.Document(name=name)
metadata, parameters = await config.document.get_upload_info(
config=config, user=user, id=document.id, name=document.name
)
document.metadata = metadata
_database.add_document(
session, user=user, document=document, metadata=metadata
)
return schemas.DocumentUpload(parameters=parameters, document=document)

# TODO: Add UI support and documentation for this endpoint (#406)
@router.post("/documents")
async def create_documents_upload_info(
user: UserDependency,
names: Annotated[list[str], Body(..., embed=True)],
) -> list[schemas.DocumentUpload]:
with _database.get_session() as session:
document_metadata_collection = []
document_upload_collection = []
for name in names:
document = schemas.Document(name=name)
metadata, parameters = await config.document.get_upload_info(
config=config, user=user, id=document.id, name=document.name
)
document.metadata = metadata
document_metadata_collection.append((document, metadata))
document_upload_collection.append(
schemas.DocumentUpload(parameters=parameters, document=document)
)

_database.add_documents(
session,
user=user,
document_metadata_collection=document_metadata_collection,
)
return document_upload_collection

# TODO: Add new endpoint for batch uploading documents (#407)
@router.put("/document")
async def upload_document(
token: Annotated[str, Form()], file: UploadFile
) -> schemas.Document:
if not issubclass(config.document, ragna.core.LocalDocument):
raise HTTPException(
status_code=400,
detail="Ragna configuration does not support local upload",
)
with _database.get_session() as session:
user, id = ragna.core.LocalDocument.decode_upload_token(token)
document = _database.get_document(session, user=user, id=id)

core_document = cast(
ragna.core.LocalDocument, engine._to_core.document(document)
)
core_document.path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(core_document.path, "wb") as document_file:
while content := await file.read(1024):
await document_file.write(content)

return document

@router.get("/components")
def get_components(_: UserDependency) -> schemas.Components:
return engine.get_components()

@router.post("/chats")
async def create_chat(
user: UserDependency,
chat_metadata: schemas.ChatMetadata,
) -> schemas.Chat:
return engine.create_chat(user=user, chat_metadata=chat_metadata)

@router.get("/chats")
async def get_chats(user: UserDependency) -> list[schemas.Chat]:
return engine.get_chats(user=user)

@router.get("/chats/{id}")
async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat:
return engine.get_chat(user=user, id=id)

@router.post("/chats/{id}/prepare")
async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message:
return await engine.prepare_chat(user=user, id=id)

@router.post("/chats/{id}/answer")
async def answer(
user: UserDependency,
id: uuid.UUID,
prompt: Annotated[str, Body(..., embed=True)],
stream: Annotated[bool, Body(..., embed=True)] = False,
) -> schemas.Message:
message_stream = engine.answer_stream(user=user, chat_id=id, prompt=prompt)
answer = await anext(message_stream)

if not stream:
content_chunks = [chunk.content async for chunk in message_stream]
answer.content += "".join(content_chunks)
return answer

async def message_chunks() -> AsyncIterator[schemas.Message]:
yield answer
async for chunk in message_stream:
yield chunk

async def to_jsonl(
models: AsyncIterator[pydantic.BaseModel],
) -> AsyncIterator[str]:
async for model in models:
yield f"{model.model_dump_json()}\n"

return StreamingResponse( # type: ignore[return-value]
to_jsonl(message_chunks())
)

@router.delete("/chats/{id}")
async def delete_chat(user: UserDependency, id: uuid.UUID) -> None:
engine.delete_chat(user=user, id=id)

return router
1 change: 0 additions & 1 deletion ragna/deploy/_api/__init__.py

This file was deleted.

Loading