diff --git a/.github/actions/setup-env/action.yml b/.github/actions/setup-env/action.yml index 21a3742a..2aa56478 100644 --- a/.github/actions/setup-env/action.yml +++ b/.github/actions/setup-env/action.yml @@ -18,7 +18,7 @@ runs: with: miniforge-variant: Mambaforge miniforge-version: latest - activate-environment: ragna-dev + activate-environment: ragna-deploy-dev - name: Display conda info shell: bash -el {0} @@ -57,6 +57,13 @@ runs: shell: bash -el {0} run: playwright install + - name: Install dev dependencies + shell: bash -el {0} + run: | + pip install \ + git+https://github.com/bokeh/bokeh-fastapi.git@main \ + git+https://github.com/holoviz/panel@7377c9e99bef0d32cbc65e94e908e365211f4421 + - name: Install ragna shell: bash -el {0} run: | @@ -66,7 +73,7 @@ runs: else PROJECT_PATH='.' fi - pip install --editable "${PROJECT_PATH}" + pip install --verbose --editable "${PROJECT_PATH}" - name: Display development environment shell: bash -el {0} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5ff4147d..bee7f2d3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,7 +38,9 @@ jobs: matrix: os: - ubuntu-latest - - windows-latest + # FIXME + # Building panel from source on Windows does not work through pip + # - windows-latest - macos-latest python-version: ["3.9"] include: diff --git a/environment-dev.yml b/environment-dev.yml index e42fe77d..19b565df 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -1,4 +1,4 @@ -name: ragna-dev +name: ragna-deploy-dev channels: - conda-forge dependencies: diff --git a/pyproject.toml b/pyproject.toml index ccdc84e9..70627e4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [build-system] requires = [ - "setuptools>=45", - "setuptools_scm[toml]>=6.2", + "setuptools>=64", + "setuptools_scm[toml]>=8", ] build-backend = "setuptools.build_meta" @@ -23,11 +23,14 @@ requires-python = ">=3.9" dependencies = [ "aiofiles", "emoji", + "eval_type_backport; python_version<'3.10'", "fastapi", "httpx", "importlib_metadata>=4.6; python_version<'3.10'", "packaging", - "panel==1.4.4", + # FIXME: pin them to released versions + "bokeh-fastapi", + "panel", "pydantic>=2", "pydantic-core", "pydantic-settings>=2", @@ -141,6 +144,9 @@ disallow_incomplete_defs = false [[tool.mypy.overrides]] module = [ + # FIXME: the package should be typed + "bokeh_fastapi", + "bokeh_fastapi.handler", "docx", "fitz", "ijson", @@ -149,12 +155,13 @@ module = [ "pptx", "pyarrow", "sentence_transformers", + "traitlets", ] ignore_missing_imports = true [[tool.mypy.overrides]] module = [ - "ragna.deploy._api.orm", + "ragna.deploy._orm", ] # Our ORM schema doesn't really work with mypy. There are some other ways to define it # to play ball. We should do that in the future. diff --git a/ragna/__main__.py b/ragna/__main__.py index 1435bf2a..4acdf3a3 100644 --- a/ragna/__main__.py +++ b/ragna/__main__.py @@ -1,4 +1,4 @@ -from ragna.deploy._cli import app +from ragna._cli import app if __name__ == "__main__": app() diff --git a/ragna/deploy/_api/__init__.py b/ragna/_cli/__init__.py similarity index 100% rename from ragna/deploy/_api/__init__.py rename to ragna/_cli/__init__.py diff --git a/ragna/deploy/_cli/config.py b/ragna/_cli/config.py similarity index 88% rename from ragna/deploy/_cli/config.py rename to ragna/_cli/config.py index 621d1ef9..73044ed9 100644 --- a/ragna/deploy/_cli/config.py +++ b/ragna/_cli/config.py @@ -197,7 +197,7 @@ def _handle_unmet_requirements(components: Iterable[Type[Component]]) -> None: return rich.print( - "You have selected components, which have additional requirements that are" + "You have selected components, which have additional requirements that are " "currently not met." ) unmet_requirements_by_type = _split_requirements(unmet_requirements) @@ -251,51 +251,37 @@ def _wizard_common() -> Config: ).unsafe_ask() ) - for sub_config, title in [(config.api, "REST API"), (config.ui, "web UI")]: - sub_config.hostname = questionary.text( # type: ignore[attr-defined] - f"What hostname do you want to bind the the Ragna {title} to?", - default=sub_config.hostname, # type: ignore[attr-defined] - qmark=QMARK, - ).unsafe_ask() - - sub_config.port = int( # type: ignore[attr-defined] - questionary.text( - f"What port do you want to bind the the Ragna {title} to?", - default=str(sub_config.port), # type: ignore[attr-defined] - qmark=QMARK, - ).unsafe_ask() - ) - - config.api.database_url = questionary.text( - "What is the URL of the SQL database?", - default=Config(local_root=config.local_root).api.database_url, + config.hostname = questionary.text( + "What hostname do you want to bind the the Ragna server to?", + default=config.hostname, qmark=QMARK, ).unsafe_ask() - config.api.url = questionary.text( - "At which URL will the Ragna REST API be served?", - default=Config( - api=dict( # type: ignore[arg-type] - hostname=config.api.hostname, - port=config.api.port, - ) - ).api.url, - qmark=QMARK, - ).unsafe_ask() + config.port = int( + questionary.text( + "What port do you want to bind the the Ragna server to?", + default=str(config.port), + qmark=QMARK, + ).unsafe_ask() + ) - config.api.origins = config.ui.origins = [ + config.origins = [ questionary.text( - "At which URL will the Ragna web UI be served?", + "At which URL will Ragna be served?", default=Config( - ui=dict( # type: ignore[arg-type] - hostname=config.ui.hostname, - port=config.ui.port, - ) - ).api.origins[0], + hostname=config.hostname, + port=config.port, + ).origins[0], qmark=QMARK, ).unsafe_ask() ] + config.database_url = questionary.text( + "What is the URL of the SQL database?", + default=Config(local_root=config.local_root).database_url, + qmark=QMARK, + ).unsafe_ask() + return config diff --git a/ragna/_cli/core.py b/ragna/_cli/core.py new file mode 100644 index 00000000..64d89e25 --- /dev/null +++ b/ragna/_cli/core.py @@ -0,0 +1,122 @@ +from pathlib import Path +from typing import Annotated, Optional + +import rich +import typer +import uvicorn + +import ragna +from ragna.deploy._core import make_app + +from .config import ConfigOption, check_config, init_config + +app = typer.Typer( + name="Ragna", + invoke_without_command=True, + no_args_is_help=True, + add_completion=False, + pretty_exceptions_enable=False, +) + + +def version_callback(value: bool) -> None: + if value: + rich.print(f"ragna {ragna.__version__} from {ragna.__path__[0]}") + raise typer.Exit() + + +@app.callback() +def _main( + version: Annotated[ + Optional[bool], + typer.Option( + "--version", callback=version_callback, help="Show version and exit." + ), + ] = None, +) -> None: + pass + + +@app.command(help="Start a wizard to build a Ragna configuration interactively.") +def init( + *, + output_path: Annotated[ + Path, + typer.Option( + "-o", + "--output-file", + metavar="OUTPUT_PATH", + default_factory=lambda: Path.cwd() / "ragna.toml", + show_default="./ragna.toml", + help="Write configuration to .", + ), + ], + force: Annotated[ + bool, + typer.Option( + "-f", "--force", help="Overwrite an existing file at ." + ), + ] = False, +) -> None: + config, output_path, force = init_config(output_path=output_path, force=force) + config.to_file(output_path, force=force) + + +@app.command(help="Check the availability of components.") +def check(config: ConfigOption = "./ragna.toml") -> None: # type: ignore[assignment] + is_available = check_config(config) + raise typer.Exit(int(not is_available)) + + +@app.command(help="Deploy Ragna REST API and web UI.") +def deploy( + *, + config: ConfigOption = "./ragna.toml", # type: ignore[assignment] + api: Annotated[ + bool, + typer.Option( + "--api/--no-api", + help="Deploy the Ragna REST API.", + ), + ] = True, + ui: Annotated[ + bool, + typer.Option( + help="Deploy the Ragna web UI.", + ), + ] = True, + ignore_unavailable_components: Annotated[ + bool, + typer.Option( + help=( + "Ignore components that are not available, " + "i.e. their requirements are not met. " + ) + ), + ] = False, + open_browser: Annotated[ + Optional[bool], + typer.Option( + help="Open a browser when Ragna is deployed.", + show_default="value of ui / no-ui", + ), + ] = None, +) -> None: + if not (api or ui): + raise Exception + + if open_browser is None: + open_browser = ui + + uvicorn.run( + lambda: make_app( + config, + ui=ui, + api=api, + ignore_unavailable_components=ignore_unavailable_components, + open_browser=open_browser, + ), + factory=True, + host=config.hostname, + port=config.port, + ) diff --git a/ragna/_utils.py b/ragna/_utils.py index c6f8e5cc..6ef5eb5c 100644 --- a/ragna/_utils.py +++ b/ragna/_utils.py @@ -5,7 +5,6 @@ import threading from pathlib import Path from typing import Any, Callable, Optional, Union -from urllib.parse import SplitResult, urlsplit, urlunsplit _LOCAL_ROOT = ( Path(os.environ.get("RAGNA_LOCAL_ROOT", "~/.cache/ragna")).expanduser().resolve() @@ -28,7 +27,7 @@ def local_root(path: Optional[Union[str, Path]] = None) -> Path: path: If passed, this is set as new local root directory. Returns: - Ragnas local root directory. + Ragna's local root directory. """ global _LOCAL_ROOT if path is not None: @@ -59,37 +58,6 @@ def fix_module(globals: dict[str, Any]) -> None: obj.__module__ = globals["__package__"] -def _replace_hostname(split_result: SplitResult, hostname: str) -> SplitResult: - # This is a separate function, since hostname is not an element of the SplitResult - # namedtuple, but only a property. Thus, we need to replace the netloc item, from - # which the hostname is generated. - if split_result.port is None: - netloc = hostname - else: - netloc = f"{hostname}:{split_result.port}" - return split_result._replace(netloc=netloc) - - -def handle_localhost_origins(origins: list[str]) -> list[str]: - # Since localhost is an alias for 127.0.0.1, we allow both so users and developers - # don't need to worry about it. - localhost_origins = { - components.hostname: components - for url in origins - if (components := urlsplit(url)).hostname in {"127.0.0.1", "localhost"} - } - if "127.0.0.1" in localhost_origins: - origins.append( - urlunsplit(_replace_hostname(localhost_origins["127.0.0.1"], "localhost")) - ) - elif "localhost" in localhost_origins: - origins.append( - urlunsplit(_replace_hostname(localhost_origins["localhost"], "127.0.0.1")) - ) - - return origins - - def timeout_after( seconds: float = 30, *, message: str = "" ) -> Callable[[Callable], Callable]: diff --git a/ragna/core/__init__.py b/ragna/core/__init__.py index 0f4b4bdf..44449775 100644 --- a/ragna/core/__init__.py +++ b/ragna/core/__init__.py @@ -34,7 +34,6 @@ from ._document import ( Document, DocumentHandler, - DocumentUploadParameters, DocxDocumentHandler, LocalDocument, Page, diff --git a/ragna/core/_components.py b/ragna/core/_components.py index bff49790..ecec015b 100644 --- a/ragna/core/_components.py +++ b/ragna/core/_components.py @@ -1,9 +1,11 @@ from __future__ import annotations import abc +import datetime import enum import functools import inspect +import uuid from typing import ( AsyncIterable, AsyncIterator, @@ -182,6 +184,8 @@ def __init__( *, role: MessageRole = MessageRole.SYSTEM, sources: Optional[list[Source]] = None, + id: Optional[uuid.UUID] = None, + timestamp: Optional[datetime.datetime] = None, ) -> None: if isinstance(content, str): self._content: str = content @@ -191,6 +195,14 @@ def __init__( self.role = role self.sources = sources or [] + if id is None: + id = uuid.uuid4() + self.id = id + + if timestamp is None: + timestamp = datetime.datetime.utcnow() + self.timestamp = timestamp + async def __aiter__(self) -> AsyncIterator[str]: if hasattr(self, "_content"): yield self._content diff --git a/ragna/core/_document.py b/ragna/core/_document.py index 878742eb..7a1cef7f 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -2,26 +2,17 @@ import abc import io -import os -import secrets -import time import uuid +from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterator, Optional, Type, TypeVar, Union +from typing import Any, AsyncIterator, Iterator, Optional, Type, TypeVar, Union -import jwt +import aiofiles from pydantic import BaseModel -from ._utils import PackageRequirement, RagnaException, Requirement, RequirementsMixin - -if TYPE_CHECKING: - from ragna.deploy import Config +import ragna - -class DocumentUploadParameters(BaseModel): - method: str - url: str - data: dict +from ._utils import PackageRequirement, RagnaException, Requirement, RequirementsMixin class Document(RequirementsMixin, abc.ABC): @@ -62,16 +53,6 @@ def get_handler(name: str) -> DocumentHandler: return handler - @classmethod - @abc.abstractmethod - async def get_upload_info( - cls, *, config: Config, user: str, id: uuid.UUID, name: str - ) -> tuple[dict[str, Any], DocumentUploadParameters]: - pass - - @abc.abstractmethod - def is_readable(self) -> bool: ... - @abc.abstractmethod def read(self) -> bytes: ... @@ -88,12 +69,25 @@ class LocalDocument(Document): [ragna.core.LocalDocument.from_path][]. """ + def __init__( + self, + *, + id: Optional[uuid.UUID] = None, + name: str, + metadata: dict[str, Any], + handler: Optional[DocumentHandler] = None, + ): + super().__init__(id=id, name=name, metadata=metadata, handler=handler) + if "path" not in self.metadata: + metadata["path"] = str(ragna.local_root() / "documents" / str(self.id)) + @classmethod def from_path( cls, path: Union[str, Path], *, id: Optional[uuid.UUID] = None, + name: Optional[str] = None, metadata: Optional[dict[str, Any]] = None, handler: Optional[DocumentHandler] = None, ) -> LocalDocument: @@ -102,6 +96,7 @@ def from_path( Args: path: Local path to the file. id: ID of the document. If omitted, one is generated. + name: Name of the document. If omitted, defaults to the name of the `path`. metadata: Optional metadata of the document. handler: Document handler. If omitted, a builtin handler is selected based on the suffix of the `path`. @@ -118,60 +113,34 @@ def from_path( ) path = Path(path).expanduser().resolve() + if name is None: + name = path.name metadata["path"] = str(path) - return cls(id=id, name=path.name, metadata=metadata, handler=handler) + return cls(id=id, name=name, metadata=metadata, handler=handler) - @property + @cached_property def path(self) -> Path: return Path(self.metadata["path"]) - def is_readable(self) -> bool: - return self.path.exists() - - def read(self) -> bytes: - with open(self.path, "rb") as stream: - return stream.read() - - _JWT_SECRET = os.environ.get( - "RAGNA_API_DOCUMENT_UPLOAD_SECRET", secrets.token_urlsafe(32)[:32] - ) - _JWT_ALGORITHM = "HS256" - - @classmethod - async def get_upload_info( - cls, *, config: Config, user: str, id: uuid.UUID, name: str - ) -> tuple[dict[str, Any], DocumentUploadParameters]: - url = f"{config.api.url}/document" - data = { - "token": jwt.encode( - payload={ - "user": user, - "id": str(id), - "exp": time.time() + 5 * 60, - }, - key=cls._JWT_SECRET, - algorithm=cls._JWT_ALGORITHM, - ) - } - metadata = {"path": str(config.local_root / "documents" / str(id))} - return metadata, DocumentUploadParameters(method="PUT", url=url, data=data) - - @classmethod - def decode_upload_token(cls, token: str) -> tuple[str, uuid.UUID]: - try: - payload = jwt.decode( - token, key=cls._JWT_SECRET, algorithms=[cls._JWT_ALGORITHM] - ) - except jwt.InvalidSignatureError: + async def _write(self, stream: AsyncIterator[bytes]) -> None: + if self.path.exists(): raise RagnaException( - "Token invalid", http_status_code=401, http_detail=RagnaException.EVENT + "File already exists", path=self.path, http_detail=RagnaException.EVENT ) - except jwt.ExpiredSignatureError: + + async with aiofiles.open(self.path, "wb") as file: + async for content in stream: + await file.write(content) + + def read(self) -> bytes: + if not self.path.is_file(): raise RagnaException( - "Token expired", http_status_code=401, http_detail=RagnaException.EVENT + "File does not exist", path=self.path, http_detail=RagnaException.EVENT ) - return payload["user"], uuid.UUID(payload["id"]) + + with open(self.path, "rb") as file: + return file.read() class Page(BaseModel): diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index 15154ea2..d963c15b 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -7,6 +7,7 @@ import uuid from collections import defaultdict from typing import ( + TYPE_CHECKING, Any, AsyncIterator, Awaitable, @@ -15,7 +16,6 @@ Iterable, Iterator, Optional, - Type, TypeVar, Union, cast, @@ -23,14 +23,18 @@ import pydantic import pydantic_core +from fastapi import status from starlette.concurrency import iterate_in_threadpool, run_in_threadpool from ._components import Assistant, Component, Message, MessageRole, SourceStorage from ._document import Document, LocalDocument from ._utils import RagnaException, default_user, merge_models +if TYPE_CHECKING: + from ragna.deploy import Config + T = TypeVar("T") -C = TypeVar("C", bound=Component) +C = TypeVar("C", bound=Component, covariant=True) class Rag(Generic[C]): @@ -45,13 +49,49 @@ class Rag(Generic[C]): ``` """ - def __init__(self) -> None: - self._components: dict[Type[C], C] = {} + def __init__( + self, + *, + config: Optional[Config] = None, + ignore_unavailable_components: bool = False, + ) -> None: + self._components: dict[type[C], C] = {} + self._display_name_map: dict[str, type[C]] = {} + + if config is not None: + self._preload_components( + config=config, + ignore_unavailable_components=ignore_unavailable_components, + ) + + def _preload_components( + self, *, config: Config, ignore_unavailable_components: bool + ) -> None: + for components in [config.source_storages, config.assistants]: + components = cast(list[type[Component]], components) + at_least_one = False + for component in components: + loaded_component = self._load_component( + component, # type: ignore[arg-type] + ignore_unavailable=ignore_unavailable_components, + ) + if loaded_component is None: + print( + f"Ignoring {component.display_name()}, because it is not available." + ) + else: + at_least_one = True + + if not at_least_one: + raise RagnaException( + "No component available", + components=[component.display_name() for component in components], + ) def _load_component( - self, component: Union[Type[C], C], *, ignore_unavailable: bool = False + self, component: Union[C, type[C], str], *, ignore_unavailable: bool = False ) -> Optional[C]: - cls: Type[C] + cls: type[C] instance: Optional[C] if isinstance(component, Component): @@ -59,6 +99,19 @@ def _load_component( cls = type(instance) elif isinstance(component, type) and issubclass(component, Component): cls = component + instance = None + elif isinstance(component, str): + try: + cls = self._display_name_map[component] + except KeyError: + raise RagnaException( + "Unknown component", + display_name=component, + help="Did you forget to create the Rag() instance with a config?", + http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + http_detail=f"Unknown component '{component}'", + ) from None + instance = None else: raise RagnaException @@ -75,6 +128,7 @@ def _load_component( instance = cls() self._components[cls] = instance + self._display_name_map[cls.display_name()] = cls return self._components[cls] @@ -82,8 +136,8 @@ def chat( self, *, documents: Iterable[Any], - source_storage: Union[Type[SourceStorage], SourceStorage], - assistant: Union[Type[Assistant], Assistant], + source_storage: Union[SourceStorage, type[SourceStorage], str], + assistant: Union[Assistant, type[Assistant], str], **params: Any, ) -> Chat: """Create a new [ragna.core.Chat][]. @@ -91,6 +145,7 @@ def chat( Args: documents: Documents to use. If any item is not a [ragna.core.Document][], [ragna.core.LocalDocument.from_path][] is invoked on it. + FIXME source_storage: Source storage to use. assistant: Assistant to use. **params: Additional parameters passed to the source storage and assistant. @@ -98,8 +153,8 @@ def chat( return Chat( self, documents=documents, - source_storage=source_storage, - assistant=assistant, + source_storage=cast(SourceStorage, self._load_component(source_storage)), # type: ignore[arg-type] + assistant=cast(Assistant, self._load_component(assistant)), # type: ignore[arg-type] **params, ) @@ -150,17 +205,15 @@ def __init__( rag: Rag, *, documents: Iterable[Any], - source_storage: Union[Type[SourceStorage], SourceStorage], - assistant: Union[Type[Assistant], Assistant], + source_storage: SourceStorage, + assistant: Assistant, **params: Any, ) -> None: self._rag = rag self.documents = self._parse_documents(documents) - self.source_storage = cast( - SourceStorage, self._rag._load_component(source_storage) - ) - self.assistant = cast(Assistant, self._rag._load_component(assistant)) + self.source_storage = source_storage + self.assistant = assistant special_params = SpecialChatParams().model_dump() special_params.update(params) @@ -238,20 +291,14 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message: return answer def _parse_documents(self, documents: Iterable[Any]) -> list[Document]: - documents_ = [] - for document in documents: - if not isinstance(document, Document): - document = LocalDocument.from_path(document) - - if not document.is_readable(): - raise RagnaException( - "Document not readable", - document=document, - http_status_code=404, - ) - - documents_.append(document) - return documents_ + return [ + ( + document + if isinstance(document, Document) + else LocalDocument.from_path(document) + ) + for document in documents + ] def _unpack_chat_params( self, params: dict[str, Any] @@ -404,6 +451,6 @@ async def __aenter__(self) -> Chat: return self async def __aexit__( - self, exc_type: Type[Exception], exc: Exception, traceback: str + self, exc_type: type[Exception], exc: Exception, traceback: str ) -> None: pass diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py new file mode 100644 index 00000000..4de2737c --- /dev/null +++ b/ragna/deploy/_api.py @@ -0,0 +1,112 @@ +import uuid +from typing import Annotated, AsyncIterator + +import pydantic +from fastapi import ( + APIRouter, + Body, + Depends, + UploadFile, +) +from fastapi.responses import StreamingResponse + +from ragna._compat import anext +from ragna.core._utils import default_user + +from . import _schemas as schemas +from ._engine import Engine + + +def make_router(engine: Engine) -> APIRouter: + router = APIRouter(tags=["API"]) + + def get_user() -> str: + return default_user() + + UserDependency = Annotated[str, Depends(get_user)] + + @router.post("/documents") + def register_documents( + user: UserDependency, document_registrations: list[schemas.DocumentRegistration] + ) -> list[schemas.Document]: + return engine.register_documents( + user=user, document_registrations=document_registrations + ) + + @router.put("/documents") + async def upload_documents( + user: UserDependency, documents: list[UploadFile] + ) -> None: + def make_content_stream(file: UploadFile) -> AsyncIterator[bytes]: + async def content_stream() -> AsyncIterator[bytes]: + while content := await file.read(16 * 1024): + yield content + + return content_stream() + + await engine.store_documents( + user=user, + ids_and_streams=[ + (uuid.UUID(document.filename), make_content_stream(document)) + for document in documents + ], + ) + + @router.get("/components") + def get_components(_: UserDependency) -> schemas.Components: + return engine.get_components() + + @router.post("/chats") + async def create_chat( + user: UserDependency, + chat_creation: schemas.ChatCreation, + ) -> schemas.Chat: + return engine.create_chat(user=user, chat_creation=chat_creation) + + @router.get("/chats") + async def get_chats(user: UserDependency) -> list[schemas.Chat]: + return engine.get_chats(user=user) + + @router.get("/chats/{id}") + async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat: + return engine.get_chat(user=user, id=id) + + @router.post("/chats/{id}/prepare") + async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message: + return await engine.prepare_chat(user=user, id=id) + + @router.post("/chats/{id}/answer") + async def answer( + user: UserDependency, + id: uuid.UUID, + prompt: Annotated[str, Body(..., embed=True)], + stream: Annotated[bool, Body(..., embed=True)] = False, + ) -> schemas.Message: + message_stream = engine.answer_stream(user=user, chat_id=id, prompt=prompt) + answer = await anext(message_stream) + + if not stream: + content_chunks = [chunk.content async for chunk in message_stream] + answer.content += "".join(content_chunks) + return answer + + async def message_chunks() -> AsyncIterator[schemas.Message]: + yield answer + async for chunk in message_stream: + yield chunk + + async def to_jsonl( + models: AsyncIterator[pydantic.BaseModel], + ) -> AsyncIterator[str]: + async for model in models: + yield f"{model.model_dump_json()}\n" + + return StreamingResponse( # type: ignore[return-value] + to_jsonl(message_chunks()) + ) + + @router.delete("/chats/{id}") + async def delete_chat(user: UserDependency, id: uuid.UUID) -> None: + engine.delete_chat(user=user, id=id) + + return router diff --git a/ragna/deploy/_api/core.py b/ragna/deploy/_api/core.py deleted file mode 100644 index 78ee6154..00000000 --- a/ragna/deploy/_api/core.py +++ /dev/null @@ -1,347 +0,0 @@ -import contextlib -import uuid -from typing import Annotated, Any, AsyncIterator, Iterator, Type, cast - -import aiofiles -from fastapi import ( - Body, - Depends, - FastAPI, - Form, - HTTPException, - Request, - UploadFile, - status, -) -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, StreamingResponse -from pydantic import BaseModel - -import ragna -import ragna.core -from ragna._compat import aiter, anext -from ragna._utils import handle_localhost_origins -from ragna.core import Assistant, Component, Rag, RagnaException, SourceStorage -from ragna.core._rag import SpecialChatParams -from ragna.deploy import Config - -from . import database, schemas - - -def app(*, config: Config, ignore_unavailable_components: bool) -> FastAPI: - ragna.local_root(config.local_root) - - rag = Rag() # type: ignore[var-annotated] - components_map: dict[str, Component] = {} - for components in [config.source_storages, config.assistants]: - components = cast(list[Type[Component]], components) - at_least_one = False - for component in components: - loaded_component = rag._load_component( - component, ignore_unavailable=ignore_unavailable_components - ) - if loaded_component is None: - print( - f"Ignoring {component.display_name()}, because it is not available." - ) - else: - at_least_one = True - components_map[component.display_name()] = loaded_component - - if not at_least_one: - raise RagnaException( - "No component available", - components=[component.display_name() for component in components], - ) - - def get_component(display_name: str) -> Component: - component = components_map.get(display_name) - if component is None: - raise RagnaException( - "Unknown component", - display_name=display_name, - http_status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - http_detail=RagnaException.MESSAGE, - ) - - return component - - app = FastAPI( - title="ragna", - version=ragna.__version__, - root_path=config.api.root_path, - ) - app.add_middleware( - CORSMiddleware, - allow_origins=handle_localhost_origins(config.api.origins), - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - @app.exception_handler(RagnaException) - async def ragna_exception_handler( - request: Request, exc: RagnaException - ) -> JSONResponse: - if exc.http_detail is RagnaException.EVENT: - detail = exc.event - elif exc.http_detail is RagnaException.MESSAGE: - detail = str(exc) - else: - detail = cast(str, exc.http_detail) - return JSONResponse( - status_code=exc.http_status_code, - content={"error": {"message": detail}}, - ) - - @app.get("/") - async def version() -> str: - return ragna.__version__ - - authentication = config.authentication() - - @app.post("/token") - async def create_token(request: Request) -> str: - return await authentication.create_token(request) - - UserDependency = Annotated[str, Depends(authentication.get_user)] - - def _get_component_json_schema( - component: Type[Component], - ) -> dict[str, dict[str, Any]]: - json_schema = component._protocol_model().model_json_schema() - # FIXME: there is likely a better way to exclude certain fields builtin in - # pydantic - for special_param in SpecialChatParams.model_fields: - if ( - "properties" in json_schema - and special_param in json_schema["properties"] - ): - del json_schema["properties"][special_param] - if "required" in json_schema and special_param in json_schema["required"]: - json_schema["required"].remove(special_param) - return json_schema - - @app.get("/components") - async def get_components(_: UserDependency) -> schemas.Components: - return schemas.Components( - documents=sorted(config.document.supported_suffixes()), - source_storages=[ - _get_component_json_schema(type(source_storage)) - for source_storage in components_map.values() - if isinstance(source_storage, SourceStorage) - ], - assistants=[ - _get_component_json_schema(type(assistant)) - for assistant in components_map.values() - if isinstance(assistant, Assistant) - ], - ) - - make_session = database.get_sessionmaker(config.api.database_url) - - @contextlib.contextmanager - def get_session() -> Iterator[database.Session]: - with make_session() as session: # type: ignore[attr-defined] - yield session - - @app.post("/document") - async def create_document_upload_info( - user: UserDependency, - name: Annotated[str, Body(..., embed=True)], - ) -> schemas.DocumentUpload: - with get_session() as session: - document = schemas.Document(name=name) - metadata, parameters = await config.document.get_upload_info( - config=config, user=user, id=document.id, name=document.name - ) - database.add_document( - session, user=user, document=document, metadata=metadata - ) - return schemas.DocumentUpload(parameters=parameters, document=document) - - # TODO: Add UI support and documentation for this endpoint (#406) - @app.post("/documents") - async def create_documents_upload_info( - user: UserDependency, - names: Annotated[list[str], Body(..., embed=True)], - ) -> list[schemas.DocumentUpload]: - with get_session() as session: - document_metadata_collection = [] - document_upload_collection = [] - for name in names: - document = schemas.Document(name=name) - metadata, parameters = await config.document.get_upload_info( - config=config, user=user, id=document.id, name=document.name - ) - document_metadata_collection.append((document, metadata)) - document_upload_collection.append( - schemas.DocumentUpload(parameters=parameters, document=document) - ) - - database.add_documents( - session, - user=user, - document_metadata_collection=document_metadata_collection, - ) - return document_upload_collection - - # TODO: Add new endpoint for batch uploading documents (#407) - @app.put("/document") - async def upload_document( - token: Annotated[str, Form()], file: UploadFile - ) -> schemas.Document: - if not issubclass(config.document, ragna.core.LocalDocument): - raise HTTPException( - status_code=400, - detail="Ragna configuration does not support local upload", - ) - with get_session() as session: - user, id = ragna.core.LocalDocument.decode_upload_token(token) - document, metadata = database.get_document(session, user=user, id=id) - - core_document = ragna.core.LocalDocument( - id=document.id, name=document.name, metadata=metadata - ) - core_document.path.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.open(core_document.path, "wb") as document_file: - while content := await file.read(1024): - await document_file.write(content) - - return document - - def schema_to_core_chat( - session: database.Session, *, user: str, chat: schemas.Chat - ) -> ragna.core.Chat: - core_chat = rag.chat( - documents=[ - config.document( - id=document.id, - name=document.name, - metadata=database.get_document( - session, - user=user, - id=document.id, - )[1], - ) - for document in chat.metadata.documents - ], - source_storage=get_component(chat.metadata.source_storage), # type: ignore[arg-type] - assistant=get_component(chat.metadata.assistant), # type: ignore[arg-type] - user=user, - chat_id=chat.id, - chat_name=chat.metadata.name, - **chat.metadata.params, - ) - core_chat._messages = [message.to_core() for message in chat.messages] - core_chat._prepared = chat.prepared - - return core_chat - - @app.post("/chats") - async def create_chat( - user: UserDependency, - chat_metadata: schemas.ChatMetadata, - ) -> schemas.Chat: - with get_session() as session: - chat = schemas.Chat(metadata=chat_metadata) - - # Although we don't need the actual ragna.core.Chat object here, - # we use it to validate the documents and metadata. - schema_to_core_chat(session, user=user, chat=chat) - - database.add_chat(session, user=user, chat=chat) - return chat - - @app.get("/chats") - async def get_chats(user: UserDependency) -> list[schemas.Chat]: - with get_session() as session: - return database.get_chats(session, user=user) - - @app.get("/chats/{id}") - async def get_chat(user: UserDependency, id: uuid.UUID) -> schemas.Chat: - with get_session() as session: - return database.get_chat(session, user=user, id=id) - - @app.post("/chats/{id}/prepare") - async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message: - with get_session() as session: - chat = database.get_chat(session, user=user, id=id) - - core_chat = schema_to_core_chat(session, user=user, chat=chat) - - welcome = schemas.Message.from_core(await core_chat.prepare()) - - chat.prepared = True - chat.messages.append(welcome) - database.update_chat(session, user=user, chat=chat) - - return welcome - - @app.post("/chats/{id}/answer") - async def answer( - user: UserDependency, - id: uuid.UUID, - prompt: Annotated[str, Body(..., embed=True)], - stream: Annotated[bool, Body(..., embed=True)] = False, - ) -> schemas.Message: - with get_session() as session: - chat = database.get_chat(session, user=user, id=id) - core_chat = schema_to_core_chat(session, user=user, chat=chat) - - core_answer = await core_chat.answer(prompt, stream=stream) - sources = [schemas.Source.from_core(source) for source in core_answer.sources] - chat.messages.append( - schemas.Message( - content=prompt, role=ragna.core.MessageRole.USER, sources=sources - ) - ) - - if stream: - - async def message_chunks() -> AsyncIterator[BaseModel]: - core_answer_stream = aiter(core_answer) - content_chunk = await anext(core_answer_stream) - - answer = schemas.Message( - content=content_chunk, - role=core_answer.role, - sources=sources, - ) - yield answer - - # Avoid sending the sources multiple times - answer_chunk = answer.model_copy(update=dict(sources=None)) - content_chunks = [answer_chunk.content] - async for content_chunk in core_answer_stream: - content_chunks.append(content_chunk) - answer_chunk.content = content_chunk - yield answer_chunk - - with get_session() as session: - answer.content = "".join(content_chunks) - chat.messages.append(answer) - database.update_chat(session, user=user, chat=chat) - - async def to_jsonl(models: AsyncIterator[Any]) -> AsyncIterator[str]: - async for model in models: - yield f"{model.model_dump_json()}\n" - - return StreamingResponse( # type: ignore[return-value] - to_jsonl(message_chunks()) - ) - else: - answer = schemas.Message.from_core(core_answer) - - with get_session() as session: - chat.messages.append(answer) - database.update_chat(session, user=user, chat=chat) - - return answer - - @app.delete("/chats/{id}") - async def delete_chat(user: UserDependency, id: uuid.UUID) -> None: - with get_session() as session: - database.delete_chat(session, user=user, id=id) - - return app diff --git a/ragna/deploy/_api/database.py b/ragna/deploy/_api/database.py deleted file mode 100644 index 2a61b048..00000000 --- a/ragna/deploy/_api/database.py +++ /dev/null @@ -1,270 +0,0 @@ -from __future__ import annotations - -import functools -import uuid -from typing import Any, Callable, Optional, cast -from urllib.parse import urlsplit - -from sqlalchemy import create_engine, select -from sqlalchemy.orm import Session, joinedload -from sqlalchemy.orm import sessionmaker as _sessionmaker - -from ragna.core import RagnaException - -from . import orm, schemas - - -def get_sessionmaker(database_url: str) -> Callable[[], Session]: - components = urlsplit(database_url) - if components.scheme == "sqlite": - connect_args = dict(check_same_thread=False) - else: - connect_args = dict() - engine = create_engine(database_url, connect_args=connect_args) - orm.Base.metadata.create_all(bind=engine) - return _sessionmaker(autocommit=False, autoflush=False, bind=engine) - - -@functools.lru_cache(maxsize=1024) -def _get_user_id(session: Session, username: str) -> uuid.UUID: - user: Optional[orm.User] = session.execute( - select(orm.User).where(orm.User.name == username) - ).scalar_one_or_none() - - if user is None: - # Add a new user if the current username is not registered yet. Since this is - # behind the authentication layer, we don't need any extra security here. - user = orm.User(id=uuid.uuid4(), name=username) - session.add(user) - session.commit() - - return cast(uuid.UUID, user.id) - - -def add_document( - session: Session, *, user: str, document: schemas.Document, metadata: dict[str, Any] -) -> None: - session.add( - orm.Document( - id=document.id, - user_id=_get_user_id(session, user), - name=document.name, - metadata_=metadata, - ) - ) - session.commit() - - -def add_documents( - session: Session, - *, - user: str, - document_metadata_collection: list[tuple[schemas.Document, dict[str, Any]]], -) -> None: - """ - Add multiple documents to the database. - - This function allows adding multiple documents at once by calling `add_all`. This is - important when there is non-negligible latency attached to each database operation. - """ - user_id = _get_user_id(session, user) - documents = [ - orm.Document( - id=document.id, - user_id=user_id, - name=document.name, - metadata_=metadata, - ) - for document, metadata in document_metadata_collection - ] - session.add_all(documents) - session.commit() - - -def _orm_to_schema_document(document: orm.Document) -> schemas.Document: - return schemas.Document(id=document.id, name=document.name) - - -@functools.lru_cache(maxsize=1024) -def get_document( - session: Session, *, user: str, id: uuid.UUID -) -> tuple[schemas.Document, dict[str, Any]]: - document = session.execute( - select(orm.Document).where( - (orm.Document.user_id == _get_user_id(session, user)) - & (orm.Document.id == id) - ) - ).scalar_one_or_none() - return _orm_to_schema_document(document), document.metadata_ - - -def add_chat(session: Session, *, user: str, chat: schemas.Chat) -> None: - document_ids = {document.id for document in chat.metadata.documents} - documents = ( - session.execute(select(orm.Document).where(orm.Document.id.in_(document_ids))) - .scalars() - .all() - ) - if len(documents) != len(document_ids): - raise RagnaException( - str(set(document_ids) - {document.id for document in documents}) - ) - session.add( - orm.Chat( - id=chat.id, - user_id=_get_user_id(session, user), - name=chat.metadata.name, - documents=documents, - source_storage=chat.metadata.source_storage, - assistant=chat.metadata.assistant, - params=chat.metadata.params, - prepared=chat.prepared, - ) - ) - session.commit() - - -def _orm_to_schema_chat(chat: orm.Chat) -> schemas.Chat: - documents = [ - schemas.Document(id=document.id, name=document.name) - for document in chat.documents - ] - messages = [ - schemas.Message( - id=message.id, - role=message.role, - content=message.content, - sources=[ - schemas.Source( - id=source.id, - document=_orm_to_schema_document(source.document), - location=source.location, - content=source.content, - num_tokens=source.num_tokens, - ) - for source in message.sources - ], - timestamp=message.timestamp, - ) - for message in chat.messages - ] - return schemas.Chat( - id=chat.id, - metadata=schemas.ChatMetadata( - name=chat.name, - documents=documents, - source_storage=chat.source_storage, - assistant=chat.assistant, - params=chat.params, - ), - messages=messages, - prepared=chat.prepared, - ) - - -def _select_chat(*, eager: bool = False) -> Any: - selector = select(orm.Chat) - if eager: - selector = selector.options( # type: ignore[attr-defined] - joinedload(orm.Chat.messages).joinedload(orm.Message.sources), - joinedload(orm.Chat.documents), - ) - return selector - - -def get_chats(session: Session, *, user: str) -> list[schemas.Chat]: - return [ - _orm_to_schema_chat(chat) - for chat in session.execute( - _select_chat(eager=True).where( - orm.Chat.user_id == _get_user_id(session, user) - ) - ) - .scalars() - .unique() - .all() - ] - - -def _get_orm_chat( - session: Session, *, user: str, id: uuid.UUID, eager: bool = False -) -> orm.Chat: - chat: Optional[orm.Chat] = ( - session.execute( - _select_chat(eager=eager).where( - (orm.Chat.id == id) & (orm.Chat.user_id == _get_user_id(session, user)) - ) - ) - .unique() - .scalar_one_or_none() - ) - if chat is None: - raise RagnaException() - return chat - - -def get_chat(session: Session, *, user: str, id: uuid.UUID) -> schemas.Chat: - return _orm_to_schema_chat(_get_orm_chat(session, user=user, id=id, eager=True)) - - -def _schema_to_orm_source(session: Session, source: schemas.Source) -> orm.Source: - orm_source: Optional[orm.Source] = session.execute( - select(orm.Source).where(orm.Source.id == source.id) - ).scalar_one_or_none() - - if orm_source is None: - orm_source = orm.Source( - id=source.id, - document_id=source.document.id, - location=source.location, - content=source.content, - num_tokens=source.num_tokens, - ) - session.add(orm_source) - session.commit() - session.refresh(orm_source) - - return orm_source - - -def _schema_to_orm_message( - session: Session, chat_id: uuid.UUID, message: schemas.Message -) -> orm.Message: - orm_message: Optional[orm.Message] = session.execute( - select(orm.Message).where(orm.Message.id == message.id) - ).scalar_one_or_none() - if orm_message is None: - orm_message = orm.Message( - id=message.id, - chat_id=chat_id, - content=message.content, - role=message.role, - sources=[ - _schema_to_orm_source(session, source=source) - for source in message.sources - ], - timestamp=message.timestamp, - ) - session.add(orm_message) - session.commit() - session.refresh(orm_message) - - return orm_message - - -def update_chat(session: Session, user: str, chat: schemas.Chat) -> None: - orm_chat = _get_orm_chat(session, user=user, id=chat.id) - - orm_chat.prepared = chat.prepared - orm_chat.messages = [ # type: ignore[assignment] - _schema_to_orm_message(session, chat_id=chat.id, message=message) - for message in chat.messages - ] - - session.commit() - - -def delete_chat(session: Session, user: str, id: uuid.UUID) -> None: - orm_chat = _get_orm_chat(session, user=user, id=id) - session.delete(orm_chat) # type: ignore[no-untyped-call] - session.commit() diff --git a/ragna/deploy/_api/schemas.py b/ragna/deploy/_api/schemas.py deleted file mode 100644 index 37471c69..00000000 --- a/ragna/deploy/_api/schemas.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import annotations - -import datetime -import uuid -from typing import Any - -from pydantic import BaseModel, Field - -import ragna.core - - -class Components(BaseModel): - documents: list[str] - source_storages: list[dict[str, Any]] - assistants: list[dict[str, Any]] - - -class Document(BaseModel): - id: uuid.UUID = Field(default_factory=uuid.uuid4) - name: str - - @classmethod - def from_core(cls, document: ragna.core.Document) -> Document: - return cls( - id=document.id, - name=document.name, - ) - - def to_core(self) -> ragna.core.Document: - return ragna.core.LocalDocument( - id=self.id, - name=self.name, - # TEMP: setting an empty metadata dict for now. - # Will be resolved as part of the "managed ragna" work: - # https://github.com/Quansight/ragna/issues/256 - metadata={}, - ) - - -class DocumentUpload(BaseModel): - parameters: ragna.core.DocumentUploadParameters - document: Document - - -class Source(BaseModel): - # See orm.Source on why this is not a UUID - id: str - document: Document - location: str - content: str - num_tokens: int - - @classmethod - def from_core(cls, source: ragna.core.Source) -> Source: - return cls( - id=source.id, - document=Document.from_core(source.document), - location=source.location, - content=source.content, - num_tokens=source.num_tokens, - ) - - def to_core(self) -> ragna.core.Source: - return ragna.core.Source( - id=self.id, - document=self.document.to_core(), - location=self.location, - content=self.content, - num_tokens=self.num_tokens, - ) - - -class Message(BaseModel): - id: uuid.UUID = Field(default_factory=uuid.uuid4) - content: str - role: ragna.core.MessageRole - sources: list[Source] = Field(default_factory=list) - timestamp: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) - - @classmethod - def from_core(cls, message: ragna.core.Message) -> Message: - return cls( - content=message.content, - role=message.role, - sources=[Source.from_core(source) for source in message.sources], - ) - - def to_core(self) -> ragna.core.Message: - return ragna.core.Message( - content=self.content, - role=self.role, - sources=[source.to_core() for source in self.sources], - ) - - -class ChatMetadata(BaseModel): - name: str - source_storage: str - assistant: str - params: dict - documents: list[Document] - - -class Chat(BaseModel): - id: uuid.UUID = Field(default_factory=uuid.uuid4) - metadata: ChatMetadata - messages: list[Message] = Field(default_factory=list) - prepared: bool = False diff --git a/ragna/deploy/_cli/__init__.py b/ragna/deploy/_cli/__init__.py deleted file mode 100644 index 93eefb4d..00000000 --- a/ragna/deploy/_cli/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .core import app diff --git a/ragna/deploy/_cli/core.py b/ragna/deploy/_cli/core.py deleted file mode 100644 index fa493c21..00000000 --- a/ragna/deploy/_cli/core.py +++ /dev/null @@ -1,173 +0,0 @@ -import subprocess -import sys -import time -from pathlib import Path -from typing import Annotated, Optional - -import httpx -import rich -import typer -import uvicorn - -import ragna -from ragna._utils import timeout_after -from ragna.deploy._api import app as api_app -from ragna.deploy._ui import app as ui_app - -from .config import ConfigOption, check_config, init_config - -app = typer.Typer( - name="ragna", - invoke_without_command=True, - no_args_is_help=True, - add_completion=False, - pretty_exceptions_enable=False, -) - - -def version_callback(value: bool) -> None: - if value: - rich.print(f"ragna {ragna.__version__} from {ragna.__path__[0]}") - raise typer.Exit() - - -@app.callback() -def _main( - version: Annotated[ - Optional[bool], - typer.Option( - "--version", callback=version_callback, help="Show version and exit." - ), - ] = None, -) -> None: - pass - - -@app.command(help="Start a wizard to build a Ragna configuration interactively.") -def init( - *, - output_path: Annotated[ - Path, - typer.Option( - "-o", - "--output-file", - metavar="OUTPUT_PATH", - default_factory=lambda: Path.cwd() / "ragna.toml", - show_default="./ragna.toml", - help="Write configuration to .", - ), - ], - force: Annotated[ - bool, - typer.Option( - "-f", "--force", help="Overwrite an existing file at ." - ), - ] = False, -) -> None: - config, output_path, force = init_config(output_path=output_path, force=force) - config.to_file(output_path, force=force) - - -@app.command(help="Check the availability of components.") -def check(config: ConfigOption = "./ragna.toml") -> None: # type: ignore[assignment] - is_available = check_config(config) - raise typer.Exit(int(not is_available)) - - -@app.command(help="Start the REST API.") -def api( - *, - config: ConfigOption = "./ragna.toml", # type: ignore[assignment] - ignore_unavailable_components: Annotated[ - bool, - typer.Option( - help=( - "Ignore components that are not available, " - "i.e. their requirements are not met. " - ) - ), - ] = False, -) -> None: - uvicorn.run( - api_app( - config=config, ignore_unavailable_components=ignore_unavailable_components - ), - host=config.api.hostname, - port=config.api.port, - ) - - -@app.command(help="Start the web UI.") -def ui( - *, - config: ConfigOption = "./ragna.toml", # type: ignore[assignment] - start_api: Annotated[ - Optional[bool], - typer.Option( - help="Start the ragna REST API alongside the web UI in a subprocess.", - show_default="Start if the API is not served at the configured URL.", - ), - ] = None, - ignore_unavailable_components: Annotated[ - bool, - typer.Option( - help=( - "Ignore components that are not available, " - "i.e. their requirements are not met. " - "This option as no effect if --no-start-api is used." - ) - ), - ] = False, - open_browser: Annotated[ - bool, - typer.Option(help="Open the web UI in the browser when it is started."), - ] = True, -) -> None: - def check_api_available() -> bool: - try: - return httpx.get(config.api.url).is_success - except httpx.ConnectError: - return False - - if start_api is None: - start_api = not check_api_available() - - if start_api: - process = subprocess.Popen( - [ - sys.executable, - "-m", - "ragna", - "api", - "--config", - config.__ragna_cli_config_path__, # type: ignore[attr-defined] - f"--{'' if ignore_unavailable_components else 'no-'}ignore-unavailable-components", - ], - stdout=sys.stdout, - stderr=sys.stderr, - ) - else: - process = None - - try: - if process is not None: - - @timeout_after(60) - def wait_for_api() -> None: - while not check_api_available(): - time.sleep(0.5) - - try: - wait_for_api() - except TimeoutError: - rich.print( - "Failed to start the API in 60 seconds. " - "Please start it manually with [bold]ragna api[/bold]." - ) - raise typer.Exit(1) - - ui_app(config=config, open_browser=open_browser).serve() # type: ignore[no-untyped-call] - finally: - if process is not None: - process.kill() - process.communicate() diff --git a/ragna/deploy/_config.py b/ragna/deploy/_config.py index fa7a01f9..486713ee 100644 --- a/ragna/deploy/_config.py +++ b/ragna/deploy/_config.py @@ -2,17 +2,12 @@ import itertools from pathlib import Path -from typing import Annotated, Any, Callable, Generic, Type, TypeVar, Union +from typing import Annotated, Any, Callable, Type, Union import tomlkit import tomlkit.container import tomlkit.items -from pydantic import ( - AfterValidator, - Field, - ImportString, - model_validator, -) +from pydantic import AfterValidator, Field, ImportString, model_validator from pydantic_settings import ( BaseSettings, PydanticBaseSettingsSource, @@ -25,34 +20,31 @@ from ._authentication import Authentication -T = TypeVar("T") +class DependentDefaultValue: + def __init__(self, resolve: Callable[[Config], Any]): + self.resolve = resolve -class AfterConfigValidateDefault(Generic[T]): - """This class exists for a specific use case: - - We have values for which we need the validated config to compute the default, - e.g. the API default origins can only be computed after we know the UI hostname - and port. - - We want to use the plain annotations rather than allowing a sentinel type, e.g. - `str` vs. `Optional[str]`. - """ +_RESERVED_PARAMS = ["default", "default_factory", "validate_default"] - def __init__(self, make_default: Callable[[Config], T]) -> None: - self.make_default = make_default - @classmethod - def make(cls, make_default: Callable[[Config], T]) -> Any: - """Creates a default sentinel that is resolved after the config is validated. +def DependentDefaultField(resolve: Callable[[Config], Any], **kwargs: Any) -> Any: + if any(param in kwargs for param in _RESERVED_PARAMS): + reserved_params = ", ".join(repr(param) for param in _RESERVED_PARAMS) + raise Exception( + f"The parameters {reserved_params} are reserved " f"and cannot be passed." + ) + return Field( + default=DependentDefaultValue(resolve), validate_default=False, **kwargs + ) - Args: - make_default: Callable that takes the validated config and returns the - resolved value. - """ - return Field(default=cls(make_default), validate_default=False) +class Config(BaseSettings): + """Ragna configuration""" + + model_config = SettingsConfigDict(env_prefix="ragna_") -class ConfigBase(BaseSettings): @classmethod def settings_customise_sources( cls, @@ -76,64 +68,13 @@ def settings_customise_sources( # 4. Default return env_settings, init_settings - def _resolve_default_sentinels(self, config: Config) -> None: + @model_validator(mode="after") + def _resolve_dependent_default_values(self) -> Config: for name, info in self.model_fields.items(): value = getattr(self, name) - if isinstance(value, ConfigBase): - value._resolve_default_sentinels(config) - elif isinstance(value, AfterConfigValidateDefault): - setattr(self, name, value.make_default(config)) - - def __str__(self) -> str: - toml = tomlkit.item(self.model_dump(mode="json")) - self._set_multiline_array(toml) - return toml.as_string() - - def _set_multiline_array(self, item: tomlkit.items.Item) -> None: - if isinstance(item, tomlkit.items.Array): - item.multiline(True) - - if not isinstance(item, tomlkit.items.Table): - return - - container = item.value - for child in itertools.chain( - (value for _, value in container.body), container.value.values() - ): - self._set_multiline_array(child) - - -def make_default_origins(config: Config) -> list[str]: - return [f"http://{config.ui.hostname}:{config.ui.port}"] - - -class ApiConfig(ConfigBase): - model_config = SettingsConfigDict(env_prefix="ragna_api_") - - hostname: str = "127.0.0.1" - port: int = 31476 - root_path: str = "" - url: str = AfterConfigValidateDefault.make( - lambda config: f"http://{config.api.hostname}:{config.api.port}{config.api.root_path}", - ) - database_url: str = AfterConfigValidateDefault.make( - lambda config: f"sqlite:///{config.local_root}/ragna.db", - ) - origins: list[str] = AfterConfigValidateDefault.make(make_default_origins) - - -class UiConfig(ConfigBase): - model_config = SettingsConfigDict(env_prefix="ragna_ui_") - - hostname: str = "127.0.0.1" - port: int = 31477 - origins: list[str] = AfterConfigValidateDefault.make(make_default_origins) - - -class Config(ConfigBase): - """Ragna configuration""" - - model_config = SettingsConfigDict(env_prefix="ragna_") + if isinstance(value, DependentDefaultValue): + setattr(self, name, value.resolve(self)) + return self local_root: Annotated[Path, AfterValidator(make_directory)] = Field( default_factory=ragna.local_root @@ -151,13 +92,38 @@ class Config(ConfigBase): "ragna.assistants.RagnaDemoAssistant" # type: ignore[list-item] ] - api: ApiConfig = Field(default_factory=ApiConfig) - ui: UiConfig = Field(default_factory=UiConfig) + hostname: str = "127.0.0.1" + port: int = 31476 + root_path: str = "" + origins: list[str] = DependentDefaultField( + lambda config: [f"http://{config.hostname}:{config.port}"] + ) - @model_validator(mode="after") - def _validate_model(self) -> Config: - self._resolve_default_sentinels(self) - return self + database_url: str = DependentDefaultField( + lambda config: f"sqlite:///{config.local_root}/ragna.db", + ) + + @property + def _url(self) -> str: + return f"http://{self.hostname}:{self.port}{self.root_path}" + + def __str__(self) -> str: + toml = tomlkit.item(self.model_dump(mode="json")) + self._set_multiline_array(toml) + return toml.as_string() + + def _set_multiline_array(self, item: tomlkit.items.Item) -> None: + if isinstance(item, tomlkit.items.Array): + item.multiline(True) + + if not isinstance(item, tomlkit.items.Table): + return + + container = item.value + for child in itertools.chain( + (value for _, value in container.body), container.value.values() + ): + self._set_multiline_array(child) @classmethod def from_file(cls, path: Union[str, Path]) -> Config: diff --git a/ragna/deploy/_core.py b/ragna/deploy/_core.py new file mode 100644 index 00000000..44c672a8 --- /dev/null +++ b/ragna/deploy/_core.py @@ -0,0 +1,112 @@ +import contextlib +import threading +import time +import webbrowser +from typing import AsyncContextManager, AsyncIterator, Callable, Optional, cast + +import httpx +from fastapi import FastAPI, Request, status +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response + +import ragna +from ragna.core import RagnaException + +from ._api import make_router as make_api_router +from ._config import Config +from ._engine import Engine +from ._ui import app as make_ui_app +from ._utils import handle_localhost_origins, redirect, set_redirect_root_path + + +def make_app( + config: Config, + *, + api: bool, + ui: bool, + ignore_unavailable_components: bool, + open_browser: bool, +) -> FastAPI: + set_redirect_root_path(config.root_path) + + lifespan: Optional[Callable[[FastAPI], AsyncContextManager]] + if open_browser: + + @contextlib.asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncIterator[None]: + def target() -> None: + client = httpx.Client(base_url=config._url) + + def server_available() -> bool: + try: + return client.get("/health").is_success + except httpx.ConnectError: + return False + + while not server_available(): + time.sleep(0.1) + + webbrowser.open(config._url) + + # We are starting the browser on a thread, because the server can only + # become available _after_ the yield below. By setting daemon=True, the + # thread will automatically terminated together with the main thread. This + # is only relevant when the server never becomes available, e.g. if an error + # occurs. In this case our thread would be stuck in an endless loop. + thread = threading.Thread(target=target, daemon=True) + thread.start() + yield + + else: + lifespan = None + + app = FastAPI(title="Ragna", version=ragna.__version__, lifespan=lifespan) + + app.add_middleware( + CORSMiddleware, + allow_origins=handle_localhost_origins(config.origins), + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + engine = Engine( + config=config, + ignore_unavailable_components=ignore_unavailable_components, + ) + + if api: + app.include_router(make_api_router(engine), prefix="/api") + + if ui: + panel_app = make_ui_app(engine) + panel_app.serve_with_fastapi(app, endpoint="/ui") + + @app.get("/", include_in_schema=False) + async def base_redirect() -> Response: + return redirect("/ui" if ui else "/docs") + + @app.get("/health") + async def health() -> Response: + return Response(b"", status_code=status.HTTP_200_OK) + + @app.get("/version") + async def version() -> str: + return ragna.__version__ + + @app.exception_handler(RagnaException) + async def ragna_exception_handler( + request: Request, exc: RagnaException + ) -> JSONResponse: + if exc.http_detail is RagnaException.EVENT: + detail = exc.event + elif exc.http_detail is RagnaException.MESSAGE: + detail = str(exc) + else: + detail = cast(str, exc.http_detail) + return JSONResponse( + status_code=exc.http_status_code, + content={"error": {"message": detail}}, + ) + + return app diff --git a/ragna/deploy/_database.py b/ragna/deploy/_database.py new file mode 100644 index 00000000..529fa3b6 --- /dev/null +++ b/ragna/deploy/_database.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import uuid +from typing import Any, Collection, Optional +from urllib.parse import urlsplit + +from sqlalchemy import create_engine, select +from sqlalchemy.orm import Session, joinedload, sessionmaker + +from ragna.core import RagnaException + +from . import _orm as orm +from . import _schemas as schemas + + +class Database: + def __init__(self, url: str) -> None: + components = urlsplit(url) + if components.scheme == "sqlite": + connect_args = dict(check_same_thread=False) + else: + connect_args = dict() + engine = create_engine(url, connect_args=connect_args) + orm.Base.metadata.create_all(bind=engine) + + self.get_session = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + self._to_orm = SchemaToOrmConverter() + self._to_schema = OrmToSchemaConverter() + + def _get_user(self, session: Session, *, username: str) -> orm.User: + user: Optional[orm.User] = session.execute( + select(orm.User).where(orm.User.name == username) + ).scalar_one_or_none() + + if user is None: + # Add a new user if the current username is not registered yet. Since this + # is behind the authentication layer, we don't need any extra security here. + user = orm.User(id=uuid.uuid4(), name=username) + session.add(user) + session.commit() + + return user + + def add_documents( + self, + session: Session, + *, + user: str, + documents: list[schemas.Document], + ) -> None: + user_id = self._get_user(session, username=user).id + session.add_all( + [self._to_orm.document(document, user_id=user_id) for document in documents] + ) + session.commit() + + def _get_orm_documents( + self, session: Session, *, user: str, ids: Collection[uuid.UUID] + ) -> list[orm.Document]: + # FIXME also check if the user is allowed to access the documents + # FIXME: maybe just take the user id to avoid getting it twice in add_chat? + documents = ( + session.execute(select(orm.Document).where(orm.Document.id.in_(ids))) + .scalars() + .all() + ) + if len(documents) != len(ids): + raise RagnaException( + str(set(ids) - {document.id for document in documents}) + ) + + return documents # type: ignore[no-any-return] + + def get_documents( + self, session: Session, *, user: str, ids: Collection[uuid.UUID] + ) -> list[schemas.Document]: + return [ + self._to_schema.document(document) + for document in self._get_orm_documents(session, user=user, ids=ids) + ] + + def add_chat(self, session: Session, *, user: str, chat: schemas.Chat) -> None: + orm_chat = self._to_orm.chat( + chat, user_id=self._get_user(session, username=user).id + ) + # We need to merge and not add here, because the documents are already in the DB + session.merge(orm_chat) + session.commit() + + def _select_chat(self, *, eager: bool = False) -> Any: + selector = select(orm.Chat) + if eager: + selector = selector.options( # type: ignore[attr-defined] + joinedload(orm.Chat.messages).joinedload(orm.Message.sources), + joinedload(orm.Chat.documents), + ) + return selector + + def get_chats(self, session: Session, *, user: str) -> list[schemas.Chat]: + return [ + self._to_schema.chat(chat) + for chat in session.execute( + self._select_chat(eager=True).where( + orm.Chat.user_id == self._get_user(session, username=user).id + ) + ) + .scalars() + .unique() + .all() + ] + + def _get_orm_chat( + self, session: Session, *, user: str, id: uuid.UUID, eager: bool = False + ) -> orm.Chat: + chat: Optional[orm.Chat] = ( + session.execute( + self._select_chat(eager=eager).where( + (orm.Chat.id == id) + & (orm.Chat.user_id == self._get_user(session, username=user).id) + ) + ) + .unique() + .scalar_one_or_none() + ) + if chat is None: + raise RagnaException() + return chat + + def get_chat(self, session: Session, *, user: str, id: uuid.UUID) -> schemas.Chat: + return self._to_schema.chat( + (self._get_orm_chat(session, user=user, id=id, eager=True)) + ) + + def update_chat(self, session: Session, user: str, chat: schemas.Chat) -> None: + orm_chat = self._to_orm.chat( + chat, user_id=self._get_user(session, username=user).id + ) + session.merge(orm_chat) + session.commit() + + def delete_chat(self, session: Session, user: str, id: uuid.UUID) -> None: + orm_chat = self._get_orm_chat(session, user=user, id=id) + session.delete(orm_chat) # type: ignore[no-untyped-call] + session.commit() + + +class SchemaToOrmConverter: + def document( + self, document: schemas.Document, *, user_id: uuid.UUID + ) -> orm.Document: + return orm.Document( + id=document.id, + user_id=user_id, + name=document.name, + metadata_=document.metadata, + ) + + def source(self, source: schemas.Source) -> orm.Source: + return orm.Source( + id=source.id, + document_id=source.document.id, + location=source.location, + content=source.content, + num_tokens=source.num_tokens, + ) + + def message(self, message: schemas.Message, *, chat_id: uuid.UUID) -> orm.Message: + return orm.Message( + id=message.id, + chat_id=chat_id, + content=message.content, + role=message.role, + sources=[self.source(source) for source in message.sources], + timestamp=message.timestamp, + ) + + def chat( + self, + chat: schemas.Chat, + *, + user_id: uuid.UUID, + ) -> orm.Chat: + return orm.Chat( + id=chat.id, + user_id=user_id, + name=chat.name, + documents=[ + self.document(document, user_id=user_id) for document in chat.documents + ], + source_storage=chat.source_storage, + assistant=chat.assistant, + params=chat.params, + messages=[ + self.message(message, chat_id=chat.id) for message in chat.messages + ], + prepared=chat.prepared, + ) + + +class OrmToSchemaConverter: + def document(self, document: orm.Document) -> schemas.Document: + return schemas.Document( + id=document.id, name=document.name, metadata=document.metadata_ + ) + + def source(self, source: orm.Source) -> schemas.Source: + return schemas.Source( + id=source.id, + document=self.document(source.document), + location=source.location, + content=source.content, + num_tokens=source.num_tokens, + ) + + def message(self, message: orm.Message) -> schemas.Message: + return schemas.Message( + id=message.id, + role=message.role, # type: ignore[arg-type] + content=message.content, + sources=[self.source(source) for source in message.sources], + timestamp=message.timestamp, + ) + + def chat(self, chat: orm.Chat) -> schemas.Chat: + return schemas.Chat( + id=chat.id, + name=chat.name, + documents=[self.document(document) for document in chat.documents], + source_storage=chat.source_storage, + assistant=chat.assistant, + params=chat.params, + messages=[self.message(message) for message in chat.messages], + prepared=chat.prepared, + ) diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py new file mode 100644 index 00000000..2209a61f --- /dev/null +++ b/ragna/deploy/_engine.py @@ -0,0 +1,262 @@ +import uuid +from typing import Any, AsyncIterator, Optional, Type, cast + +from fastapi import status as http_status_code + +import ragna +from ragna import Rag, core +from ragna._compat import aiter, anext +from ragna._utils import make_directory +from ragna.core import RagnaException +from ragna.core._rag import SpecialChatParams +from ragna.deploy import Config + +from . import _schemas as schemas +from ._database import Database + + +class Engine: + def __init__(self, *, config: Config, ignore_unavailable_components: bool) -> None: + self._config = config + ragna.local_root(config.local_root) + self._documents_root = make_directory(config.local_root / "documents") + self.supports_store_documents = issubclass( + self._config.document, ragna.core.LocalDocument + ) + + self._database = Database(url=config.database_url) + + self._rag = Rag( # type: ignore[var-annotated] + config=config, + ignore_unavailable_components=ignore_unavailable_components, + ) + + self._to_core = SchemaToCoreConverter(config=self._config, rag=self._rag) + self._to_schema = CoreToSchemaConverter() + + def _get_component_json_schema( + self, + component: Type[core.Component], + ) -> dict[str, dict[str, Any]]: + json_schema = component._protocol_model().model_json_schema() + # FIXME: there is likely a better way to exclude certain fields builtin in + # pydantic + for special_param in SpecialChatParams.model_fields: + if ( + "properties" in json_schema + and special_param in json_schema["properties"] + ): + del json_schema["properties"][special_param] + if "required" in json_schema and special_param in json_schema["required"]: + json_schema["required"].remove(special_param) + return json_schema + + def get_components(self) -> schemas.Components: + return schemas.Components( + documents=sorted(self._config.document.supported_suffixes()), + source_storages=[ + self._get_component_json_schema(source_storage) + for source_storage in self._rag._components.keys() + if issubclass(source_storage, core.SourceStorage) + ], + assistants=[ + self._get_component_json_schema(assistant) + for assistant in self._rag._components.keys() + if issubclass(assistant, core.Assistant) + ], + ) + + def register_documents( + self, *, user: str, document_registrations: list[schemas.DocumentRegistration] + ) -> list[schemas.Document]: + # We create core.Document's first, because they might update the metadata + core_documents = [ + self._config.document( + name=registration.name, metadata=registration.metadata + ) + for registration in document_registrations + ] + documents = [self._to_schema.document(document) for document in core_documents] + + with self._database.get_session() as session: + self._database.add_documents(session, user=user, documents=documents) + + return documents + + async def store_documents( + self, + *, + user: str, + ids_and_streams: list[tuple[uuid.UUID, AsyncIterator[bytes]]], + ) -> None: + if not self.supports_store_documents: + raise RagnaException( + "Ragna configuration does not support local upload", + http_status_code=http_status_code.HTTP_400_BAD_REQUEST, + ) + + ids, streams = zip(*ids_and_streams) + + with self._database.get_session() as session: + documents = self._database.get_documents(session, user=user, ids=ids) + + for document, stream in zip(documents, streams): + core_document = cast( + ragna.core.LocalDocument, self._to_core.document(document) + ) + await core_document._write(stream) + + def create_chat( + self, *, user: str, chat_creation: schemas.ChatCreation + ) -> schemas.Chat: + params = chat_creation.model_dump() + document_ids = params.pop("document_ids") + with self._database.get_session() as session: + documents = self._database.get_documents( + session, user=user, ids=document_ids + ) + + chat = schemas.Chat(documents=documents, **params) + + # Although we don't need the actual core.Chat here, this performs the input + # validation. + self._to_core.chat(chat, user=user) + + with self._database.get_session() as session: + self._database.add_chat(session, user=user, chat=chat) + + return chat + + def get_chats(self, *, user: str) -> list[schemas.Chat]: + with self._database.get_session() as session: + return self._database.get_chats(session, user=user) + + def get_chat(self, *, user: str, id: uuid.UUID) -> schemas.Chat: + with self._database.get_session() as session: + return self._database.get_chat(session, user=user, id=id) + + async def prepare_chat(self, *, user: str, id: uuid.UUID) -> schemas.Message: + core_chat = self._to_core.chat(self.get_chat(user=user, id=id), user=user) + core_message = await core_chat.prepare() + + with self._database.get_session() as session: + self._database.update_chat( + session, chat=self._to_schema.chat(core_chat), user=user + ) + + return self._to_schema.message(core_message) + + async def answer_stream( + self, *, user: str, chat_id: uuid.UUID, prompt: str + ) -> AsyncIterator[schemas.Message]: + core_chat = self._to_core.chat(self.get_chat(user=user, id=chat_id), user=user) + core_message = await core_chat.answer(prompt, stream=True) + + content_stream = aiter(core_message) + content_chunk = await anext(content_stream) + message = self._to_schema.message(core_message, content_override=content_chunk) + yield message + + # Avoid sending the sources multiple times + message_chunk = message.model_copy(update=dict(sources=None)) + async for content_chunk in content_stream: + message_chunk.content = content_chunk + yield message_chunk + + with self._database.get_session() as session: + self._database.update_chat( + session, chat=self._to_schema.chat(core_chat), user=user + ) + + def delete_chat(self, *, user: str, id: uuid.UUID) -> None: + with self._database.get_session() as session: + self._database.delete_chat(session, user=user, id=id) + + +class SchemaToCoreConverter: + def __init__(self, *, config: Config, rag: Rag) -> None: + self._config = config + self._rag = rag + + def document(self, document: schemas.Document) -> core.Document: + return self._config.document( + id=document.id, + name=document.name, + metadata=document.metadata, + ) + + def source(self, source: schemas.Source) -> core.Source: + return core.Source( + id=source.id, + document=self.document(source.document), + location=source.location, + content=source.content, + num_tokens=source.num_tokens, + ) + + def message(self, message: schemas.Message) -> core.Message: + return core.Message( + message.content, + role=message.role, + sources=[self.source(source) for source in message.sources], + ) + + def chat(self, chat: schemas.Chat, *, user: str) -> core.Chat: + core_chat = self._rag.chat( + user=user, + chat_id=chat.id, + chat_name=chat.name, + documents=[self.document(document) for document in chat.documents], + source_storage=chat.source_storage, + assistant=chat.assistant, + **chat.params, + ) + core_chat._messages = [self.message(message) for message in chat.messages] + core_chat._prepared = chat.prepared + + return core_chat + + +class CoreToSchemaConverter: + def document(self, document: core.Document) -> schemas.Document: + return schemas.Document( + id=document.id, + name=document.name, + metadata=document.metadata, + ) + + def source(self, source: core.Source) -> schemas.Source: + return schemas.Source( + id=source.id, + document=self.document(source.document), + location=source.location, + content=source.content, + num_tokens=source.num_tokens, + ) + + def message( + self, message: core.Message, *, content_override: Optional[str] = None + ) -> schemas.Message: + return schemas.Message( + id=message.id, + content=( + content_override if content_override is not None else message.content + ), + role=message.role, + sources=[self.source(source) for source in message.sources], + timestamp=message.timestamp, + ) + + def chat(self, chat: core.Chat) -> schemas.Chat: + params = chat.params.copy() + del params["user"] + return schemas.Chat( + id=params.pop("chat_id"), + name=params.pop("chat_name"), + documents=[self.document(document) for document in chat.documents], + source_storage=chat.source_storage.display_name(), + assistant=chat.assistant.display_name(), + params=params, + messages=[self.message(message) for message in chat._messages], + prepared=chat._prepared, + ) diff --git a/ragna/deploy/_api/orm.py b/ragna/deploy/_orm.py similarity index 100% rename from ragna/deploy/_api/orm.py rename to ragna/deploy/_orm.py diff --git a/ragna/deploy/_schemas.py b/ragna/deploy/_schemas.py new file mode 100644 index 00000000..cc5490b7 --- /dev/null +++ b/ragna/deploy/_schemas.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import datetime +import uuid +from typing import Any + +from pydantic import BaseModel, Field + +import ragna.core + + +class Components(BaseModel): + documents: list[str] + source_storages: list[dict[str, Any]] + assistants: list[dict[str, Any]] + + +class DocumentRegistration(BaseModel): + name: str + metadata: dict[str, Any] = Field(default_factory=dict) + + +class Document(BaseModel): + id: uuid.UUID = Field(default_factory=uuid.uuid4) + name: str + metadata: dict[str, Any] + + +class Source(BaseModel): + # See orm.Source on why this is not a UUID + id: str + document: Document + location: str + content: str + num_tokens: int + + +class Message(BaseModel): + id: uuid.UUID = Field(default_factory=uuid.uuid4) + content: str + role: ragna.core.MessageRole + sources: list[Source] = Field(default_factory=list) + timestamp: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) + + +class ChatCreation(BaseModel): + name: str + document_ids: list[uuid.UUID] + source_storage: str + assistant: str + params: dict[str, Any] = Field(default_factory=dict) + + +class Chat(BaseModel): + id: uuid.UUID = Field(default_factory=uuid.uuid4) + name: str + documents: list[Document] + source_storage: str + assistant: str + params: dict[str, Any] + messages: list[Message] = Field(default_factory=list) + prepared: bool = False diff --git a/ragna/deploy/_ui/api_wrapper.py b/ragna/deploy/_ui/api_wrapper.py index f96375de..170e8bbd 100644 --- a/ragna/deploy/_ui/api_wrapper.py +++ b/ragna/deploy/_ui/api_wrapper.py @@ -1,104 +1,53 @@ -import json +import uuid from datetime import datetime import emoji -import httpx import param +from ragna.core._utils import default_user +from ragna.deploy import _schemas as schemas +from ragna.deploy._engine import Engine -class RagnaAuthTokenExpiredException(Exception): - """Just a wrapper around Exception""" - pass - - -# The goal is this class is to provide ready-to-use functions to interact with the API class ApiWrapper(param.Parameterized): - auth_token = param.String(default=None) - - def __init__(self, api_url, **params): - self.client = httpx.AsyncClient(base_url=api_url, timeout=60) - - super().__init__(**params) - - try: - # If no auth token is provided, we use the API base URL and only test the API is up. - # else, we test the API is up *and* the token is valid. - endpoint = ( - api_url + "/components" if self.auth_token is not None else api_url - ) - httpx.get( - endpoint, headers={"Authorization": f"Bearer {self.auth_token}"} - ).raise_for_status() - - except httpx.HTTPStatusError as e: - # unauthorized - the token is invalid - if e.response.status_code == 401: - raise RagnaAuthTokenExpiredException("Unauthorized") - else: - raise e - - async def auth(self, username, password): - self.auth_token = ( - ( - await self.client.post( - "/token", - data={"username": username, "password": password}, - ) - ) - .raise_for_status() - .json() - ) - - return True - - @param.depends("auth_token", watch=True, on_init=True) - def update_auth_header(self): - self.client.headers["Authorization"] = f"Bearer {self.auth_token}" + def __init__(self, engine: Engine): + super().__init__() + self._user = default_user() + self._engine = engine async def get_chats(self): - json_data = (await self.client.get("/chats")).raise_for_status().json() + json_data = [ + chat.model_dump(mode="json") + for chat in self._engine.get_chats(user=self._user) + ] for chat in json_data: chat["messages"] = [self.improve_message(msg) for msg in chat["messages"]] return json_data async def answer(self, chat_id, prompt): - async with self.client.stream( - "POST", - f"/chats/{chat_id}/answer", - json={"prompt": prompt, "stream": True}, - ) as response: - async for data in response.aiter_lines(): - yield self.improve_message(json.loads(data)) + async for message in self._engine.answer_stream( + user=self._user, chat_id=uuid.UUID(chat_id), prompt=prompt + ): + yield self.improve_message(message.model_dump(mode="json")) async def get_components(self): - return (await self.client.get("/components")).raise_for_status().json() - - # Upload and related functions - def upload_endpoints(self): - return { - "informations_endpoint": f"{self.client.base_url}/document", - } + return self._engine.get_components().model_dump(mode="json") async def start_and_prepare( self, name, documents, source_storage, assistant, params ): - response = await self.client.post( - "/chats", - json={ - "name": name, - "documents": documents, - "source_storage": source_storage, - "assistant": assistant, - "params": params, - }, + chat = self._engine.create_chat( + user=self._user, + chat_creation=schemas.ChatCreation( + name=name, + document_ids=[document.id for document in documents], + source_storage=source_storage, + assistant=assistant, + params=params, + ), ) - chat = response.raise_for_status().json() - - response = await self.client.post(f"/chats/{chat['id']}/prepare", timeout=None) - response.raise_for_status() - - return chat["id"] + await self._engine.prepare_chat(user=self._user, id=chat.id) + return str(chat.id) def improve_message(self, msg): msg["timestamp"] = datetime.strptime(msg["timestamp"], "%Y-%m-%dT%H:%M:%S.%f") diff --git a/ragna/deploy/_ui/app.py b/ragna/deploy/_ui/app.py index ed418271..052ff36d 100644 --- a/ragna/deploy/_ui/app.py +++ b/ragna/deploy/_ui/app.py @@ -1,18 +1,16 @@ from pathlib import Path -from urllib.parse import urlsplit +from typing import cast import panel as pn import param +from fastapi import FastAPI +from fastapi.staticfiles import StaticFiles -from ragna._utils import handle_localhost_origins -from ragna.deploy import Config +from ragna.deploy._engine import Engine from . import js from . import styles as ui -from .api_wrapper import ApiWrapper, RagnaAuthTokenExpiredException -from .auth_page import AuthPage -from .js_utils import redirect_script -from .logout_page import LogoutPage +from .api_wrapper import ApiWrapper from .main_page import MainPage pn.extension( @@ -24,17 +22,13 @@ class App(param.Parameterized): - def __init__(self, *, hostname, port, api_url, origins, open_browser): + def __init__(self, engine: Engine): super().__init__() # Apply the design modifiers to the panel components # It returns all the CSS files of the modifiers self.css_filepaths = ui.apply_design_modifiers() - self.hostname = hostname - self.port = port - self.api_url = api_url - self.origins = origins - self.open_browser = open_browser + self._engine = engine def get_template(self): # A bit hacky, but works. @@ -79,81 +73,71 @@ def get_template(self): return template def index_page(self): - if "auth_token" not in pn.state.cookies: - return redirect_script(remove="", append="auth") - - try: - api_wrapper = ApiWrapper( - api_url=self.api_url, auth_token=pn.state.cookies["auth_token"] - ) - except RagnaAuthTokenExpiredException: - # If the token has expired / is invalid, we redirect to the logout page. - # The logout page will delete the cookie and redirect to the auth page. - return redirect_script(remove="", append="logout") + api_wrapper = ApiWrapper(self._engine) template = self.get_template() main_page = MainPage(api_wrapper=api_wrapper, template=template) template.main.append(main_page) return template - def auth_page(self): - # If the user is already authenticated, we receive the auth token in the cookie. - # in that case, redirect to the index page. - if "auth_token" in pn.state.cookies: - # Usually, we do a redirect this way : - # >>> pn.state.location.param.update(reload=True, pathname="/") - # But it only works once the page is fully loaded. - # So we render a javascript redirect instead. - return redirect_script(remove="auth") + def health_page(self): + return pn.pane.HTML("

Ok

") - template = self.get_template() - auth_page = AuthPage(api_wrapper=ApiWrapper(api_url=self.api_url)) - template.main.append(auth_page) - return template + def add_panel_app(self, server, panel_app_fn, endpoint): + # FIXME: this code will ultimately be distributed as part of panel + from functools import partial - def logout_page(self): - template = self.get_template() - logout_page = LogoutPage(api_wrapper=ApiWrapper(api_url=self.api_url)) - template.main.append(logout_page) - return template + import panel as pn + from bokeh.application import Application + from bokeh.application.handlers.function import FunctionHandler + from bokeh_fastapi import BokehFastAPI + from bokeh_fastapi.handler import WSHandler + from fastapi.responses import FileResponse + from panel.io.document import extra_socket_handlers + from panel.io.resources import COMPONENT_PATH + from panel.io.server import ComponentResourceHandler + from panel.io.state import set_curdoc - def health_page(self): - return pn.pane.HTML("

Ok

") + def dispatch_fastapi(conn, events=None, msg=None): + if msg is None: + msg = conn.protocol.create("PATCH-DOC", events) + return [conn._socket.send_message(msg)] + + extra_socket_handlers[WSHandler] = dispatch_fastapi + + def panel_app(doc): + doc.on_event("document_ready", partial(pn.state._schedule_on_load, doc)) - def serve(self): - all_pages = { - "/": self.index_page, - "/auth": self.auth_page, - "/logout": self.logout_page, - "/health": self.health_page, - } - titles = {"/": "Home"} - - pn.serve( - all_pages, - titles=titles, - address=self.hostname, - port=self.port, - admin=True, - start=True, - location=True, - show=self.open_browser, - keep_alive=30 * 1000, # 30s - autoreload=True, - profiler="pyinstrument", - allow_websocket_origin=[urlsplit(origin).netloc for origin in self.origins], - static_dirs={ - dir: str(Path(__file__).parent / dir) - for dir in ["css", "imgs", "resources"] - }, + with set_curdoc(doc): + panel_app = panel_app_fn() + panel_app.server_doc(doc) + + handler = FunctionHandler(panel_app) + application = Application(handler) + + BokehFastAPI({endpoint: application}, server=server) + + @server.get( + f"/{COMPONENT_PATH.rstrip('/')}" + "/{path:path}", include_in_schema=False ) + def get_component_resource(path: str): + # ComponentResourceHandler.parse_url_path only ever accesses + # self._resource_attrs, which fortunately is a class attribute. Thus, we can + # get away with using the method without actually instantiating the class + self_ = cast(ComponentResourceHandler, ComponentResourceHandler) + resolved_path = ComponentResourceHandler.parse_url_path(self_, path) + return FileResponse(resolved_path) + + def serve_with_fastapi(self, app: FastAPI, endpoint: str): + self.add_panel_app(app, self.index_page, endpoint) + + for dir in ["css", "imgs"]: + app.mount( + f"/{dir}", + StaticFiles(directory=str(Path(__file__).parent / dir)), + name=dir, + ) -def app(*, config: Config, open_browser: bool) -> App: - return App( - hostname=config.ui.hostname, - port=config.ui.port, - api_url=config.api.url, - origins=handle_localhost_origins(config.ui.origins), - open_browser=open_browser, - ) +def app(engine: Engine) -> App: + return App(engine) diff --git a/ragna/deploy/_ui/auth_page.py b/ragna/deploy/_ui/auth_page.py deleted file mode 100644 index 4df5098c..00000000 --- a/ragna/deploy/_ui/auth_page.py +++ /dev/null @@ -1,91 +0,0 @@ -import panel as pn -import param - - -class AuthPage(pn.viewable.Viewer, param.Parameterized): - feedback_message = param.String(default=None) - - custom_js = param.String(default="") - - def __init__(self, api_wrapper, **params): - super().__init__(**params) - self.api_wrapper = api_wrapper - - self.main_layout = None - - self.login_input = pn.widgets.TextInput( - name="Email", - css_classes=["auth_login_input"], - ) - self.password_input = pn.widgets.PasswordInput( - name="Password", - css_classes=["auth_password_input"], - ) - - async def perform_login(self, event=None): - self.main_layout.loading = True - - home_path = pn.state.location.pathname.rstrip("/").rstrip("auth") - try: - authed = await self.api_wrapper.auth( - self.login_input.value, self.password_input.value - ) - - if authed: - # Sets the cookie on the JS side - self.custom_js = f""" document.cookie = "auth_token={self.api_wrapper.auth_token}; path:{home_path}"; """ - - except Exception: - authed = False - - if authed: - # perform redirect - pn.state.location.param.update(reload=True, pathname=home_path) - else: - self.feedback_message = "Authentication failed. Please retry." - - self.main_layout.loading = False - - @pn.depends("feedback_message") - def display_error_message(self): - if self.feedback_message is None: - return None - else: - return pn.pane.HTML( - f"""
{self.feedback_message}
""", - css_classes=["auth_error"], - ) - - @pn.depends("custom_js") - def wrapped_custom_js(self): - return pn.pane.HTML( - f""" - Log In", - css_classes=["auth_title"], - ), - self.display_error_message, - self.login_input, - self.password_input, - pn.pane.HTML("
"), - login_button, - css_classes=["auth_page_main_layout"], - ) - - return self.main_layout diff --git a/ragna/deploy/_ui/central_view.py b/ragna/deploy/_ui/central_view.py index 173c6d55..00de3e8e 100644 --- a/ragna/deploy/_ui/central_view.py +++ b/ragna/deploy/_ui/central_view.py @@ -153,7 +153,7 @@ def _update_placeholder(self): show_timestamp=False, ) - def _build_message(self, *args, **kwargs) -> RagnaChatMessage | None: + def _build_message(self, *args, **kwargs) -> Optional[RagnaChatMessage]: message = super()._build_message(*args, **kwargs) if message is None: return None @@ -189,11 +189,11 @@ def on_click_chat_info_wrapper(self, event): pills = "".join( [ f"""
{d['name']}
""" - for d in self.current_chat["metadata"]["documents"] + for d in self.current_chat["documents"] ] ) - grid_height = len(self.current_chat["metadata"]["documents"]) // 3 + grid_height = len(self.current_chat["documents"]) // 3 markdown = "\n".join( [ @@ -202,14 +202,14 @@ def on_click_chat_info_wrapper(self, event): f"
{pills}

\n\n", "----", "**Source Storage**", - f"""{self.current_chat['metadata']['source_storage']}\n""", + f"""{self.current_chat['source_storage']}\n""", "----", "**Assistant**", - f"""{self.current_chat['metadata']['assistant']}\n""", + f"""{self.current_chat['assistant']}\n""", "**Advanced configuration**", *[ f"- **{key.replace('_', ' ').title()}**: {value}" - for key, value in self.current_chat["metadata"]["params"].items() + for key, value in self.current_chat["params"].items() ], ] ) @@ -275,7 +275,7 @@ def get_user_from_role(self, role: Literal["system", "user", "assistant"]) -> st elif role == "user": return cast(str, self.user) elif role == "assistant": - return cast(str, self.current_chat["metadata"]["assistant"]) + return cast(str, self.current_chat["assistant"]) else: raise RuntimeError @@ -301,12 +301,15 @@ async def chat_callback( message.clipboard_button.value = message.content_pane.object message.assistant_toolbar.visible = True - except Exception: + except Exception as exc: + import traceback + yield RagnaChatMessage( - ( - "Sorry, something went wrong. " - "If this problem persists, please contact your administrator." - ), + # ( + # "Sorry, something went wrong. " + # "If this problem persists, please contact your administrator." + # ), + "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)), role="system", user=self.get_user_from_role("system"), ) @@ -358,7 +361,7 @@ def header(self): current_chat_name = "" if self.current_chat is not None: - current_chat_name = self.current_chat["metadata"]["name"] + current_chat_name = self.current_chat["name"] chat_name_header = pn.pane.HTML( f"

{current_chat_name}

", @@ -370,9 +373,9 @@ def header(self): if ( self.current_chat is not None and "metadata" in self.current_chat - and "documents" in self.current_chat["metadata"] + and "documents" in self.current_chat ): - doc_names = [d["name"] for d in self.current_chat["metadata"]["documents"]] + doc_names = [d["name"] for d in self.current_chat["documents"]] # FIXME: Instead of setting a hard limit of 20 documents here, this should # scale automatically with the width of page @@ -385,7 +388,9 @@ def header(self): chat_documents_pills.append(pill) - self.chat_info_button.name = f"{self.current_chat['metadata']['assistant']} | {self.current_chat['metadata']['source_storage']}" + self.chat_info_button.name = ( + f"{self.current_chat['assistant']} | {self.current_chat['source_storage']}" + ) return pn.Row( chat_name_header, diff --git a/ragna/deploy/_ui/components/file_uploader.py b/ragna/deploy/_ui/components/file_uploader.py index c4fcbaae..1568bc61 100644 --- a/ragna/deploy/_ui/components/file_uploader.py +++ b/ragna/deploy/_ui/components/file_uploader.py @@ -17,10 +17,9 @@ class FileUploader(ReactiveHTML, Widget): # type: ignore[misc] title = param.String(default="") - def __init__(self, allowed_documents, token, informations_endpoint, **params): + def __init__(self, allowed_documents, informations_endpoint, **params): super().__init__(**params) - self.token = token self.informations_endpoint = informations_endpoint self.after_upload_callback = None @@ -56,7 +55,7 @@ def perform_upload(self, event=None, after_upload_callback=None): self.custom_js = ( final_callback_js + random_id - + f"""upload( self.get_upload_files(), '{self.token}', '{self.informations_endpoint}', final_callback) """ + + f"""upload( self.get_upload_files(), '{self.informations_endpoint}', final_callback) """ ) _child_config = { @@ -134,9 +133,8 @@ def perform_upload(self, event=None, after_upload_callback=None):
diff --git a/ragna/deploy/_ui/css/auth/button.css b/ragna/deploy/_ui/css/auth/button.css deleted file mode 100644 index e8a6ad3d..00000000 --- a/ragna/deploy/_ui/css/auth/button.css +++ /dev/null @@ -1,5 +0,0 @@ -:host(.auth_login_button) { - width: 100%; - margin-left: 0px; - margin-right: 0px; -} diff --git a/ragna/deploy/_ui/css/auth/column.css b/ragna/deploy/_ui/css/auth/column.css deleted file mode 100644 index d7bb60aa..00000000 --- a/ragna/deploy/_ui/css/auth/column.css +++ /dev/null @@ -1,23 +0,0 @@ -:host(.auth_page_main_layout) { - background-color: white; - border-radius: 5px; - box-shadow: lightgray 0px 0px 10px; - padding: 0 25px 0 25px; - - width: 30%; - min-width: 360px; - max-width: 430px; - - margin-left: auto; - margin-right: auto; - margin-top: 10%; -} - -:host(.auth_page_main_layout) > div { - margin-bottom: 10px; - margin-top: 10px; -} - -:host(.auth_page_main_layout) .bk-panel-models-layout-Column { - width: 100%; -} diff --git a/ragna/deploy/_ui/css/auth/html.css b/ragna/deploy/_ui/css/auth/html.css deleted file mode 100644 index 16c4490f..00000000 --- a/ragna/deploy/_ui/css/auth/html.css +++ /dev/null @@ -1,24 +0,0 @@ -:host(.auth_error) { - width: 100%; - margin-left: 0px; - margin-right: 0px; -} - -:host(.auth_error) div.auth_error { - width: 100%; - color: red; - text-align: center; - font-weight: 600; - font-size: 16px; -} - -:host(.auth_title) { - width: 100%; - margin-left: 0px; - margin-right: 0px; - text-align: center; -} -:host(.auth_title) h1 { - font-weight: 600; - font-size: 24px; -} diff --git a/ragna/deploy/_ui/css/auth/textinput.css b/ragna/deploy/_ui/css/auth/textinput.css deleted file mode 100644 index b6c16ce9..00000000 --- a/ragna/deploy/_ui/css/auth/textinput.css +++ /dev/null @@ -1,18 +0,0 @@ -:host(.auth_login_input), -:host(.auth_password_input) { - width: 100%; - margin-left: 0px; - margin-right: 0px; -} - -:host(.auth_login_input) label, -:host(.auth_password_input) label { - font-weight: 600; - font-size: 16px; -} - -:host(.auth_login_input) input, -:host(.auth_password_input) input { - background-color: white !important; - height: 38px; -} diff --git a/ragna/deploy/_ui/js_utils.py b/ragna/deploy/_ui/js_utils.py deleted file mode 100644 index 4497de2d..00000000 --- a/ragna/deploy/_ui/js_utils.py +++ /dev/null @@ -1,62 +0,0 @@ -import panel as pn - - -def preformat(text): - """Allows {{key}} to be used for formatting in textcthat already uses - curly braces. First switch this into something else, replace curlies - with double curlies, and then switch back to regular braces - """ - text = text.replace("{{", "<<<").replace("}}", ">>>") - text = text.replace("{", "{{").replace("}", "}}") - text = text.replace("<<<", "{").replace(">>>", "}") - return text - - -def redirect_script(remove, append="/", remove_auth_cookie=False): - """ - This function returns a js script to redirect to correct url. - :param remove: string to remove from the end of the url - :param append: string to append at the end of the url - :param remove_auth_cookie: boolean, will clear auth_token cookie when true. - :return: string javascript script - - Examples: - ========= - - # This will remove nothing from the end of the url and will - # add auth to it, so /foo/bar/car/ becomes /foo/bar/car/auth - >>> redirect_script(remove="", append="auth") - - # This will remove nothing from the end of the url and will - # add auth to it, so /foo/bar/car/ becomes /foo/bar/car/logout - >>> redirect_script(remove="", append="logout") - - # This will remove "auth" from the end of the url and will add / to it - # so /foo/bar/car/auth becomes /foo/bar/car/ - >>> redirect_script(remove="auth", append="/") - """ - js_script = preformat( - r""" - - """ - ) - - return pn.pane.HTML( - js_script.format( - remove=remove, - append=append, - remove_auth_cookie=str(remove_auth_cookie).lower(), - ) - ) diff --git a/ragna/deploy/_ui/left_sidebar.py b/ragna/deploy/_ui/left_sidebar.py index 379acef5..267a3b77 100644 --- a/ragna/deploy/_ui/left_sidebar.py +++ b/ragna/deploy/_ui/left_sidebar.py @@ -72,7 +72,7 @@ def __panel__(self): self.chat_buttons = [] for chat in self.chats: button = pn.widgets.Button( - name=chat["metadata"]["name"], + name=chat["name"], css_classes=["chat_button"], ) button.on_click(lambda event, c=chat: self.on_click_chat_wrapper(event, c)) diff --git a/ragna/deploy/_ui/logout_page.py b/ragna/deploy/_ui/logout_page.py deleted file mode 100644 index d86b77ab..00000000 --- a/ragna/deploy/_ui/logout_page.py +++ /dev/null @@ -1,21 +0,0 @@ -import panel as pn -import param - -from ragna.deploy._ui.js_utils import redirect_script - - -class LogoutPage(pn.viewable.Viewer, param.Parameterized): - def __init__(self, api_wrapper, **params): - super().__init__(**params) - self.api_wrapper = api_wrapper - - self.api_wrapper.auth_token = None - - def __panel__(self): - # Usually, we do a redirect this way : - # >>> pn.state.location.param.update(reload=True, pathname="/") - # But it only works once the page is fully loaded. - # So we render a javascript redirect instead. - - # To remove the token from the cookie, we have to force its expiry date to the past. - return redirect_script(remove="logout", append="/", remove_auth_cookie=True) diff --git a/ragna/deploy/_ui/modal_configuration.py b/ragna/deploy/_ui/modal_configuration.py index 70b02731..a8bb44e7 100644 --- a/ragna/deploy/_ui/modal_configuration.py +++ b/ragna/deploy/_ui/modal_configuration.py @@ -1,11 +1,13 @@ from datetime import datetime, timedelta, timezone +from typing import AsyncIterator import panel as pn import param +from ragna.deploy import _schemas as schemas + from . import js from . import styles as ui -from .components.file_uploader import FileUploader def get_default_chat_name(timezone_offset=None): @@ -82,16 +84,11 @@ def __init__(self, api_wrapper, **params): self.api_wrapper = api_wrapper - upload_endpoints = self.api_wrapper.upload_endpoints() - self.chat_name_input = pn.widgets.TextInput.from_param( self.param.chat_name, ) - self.document_uploader = FileUploader( - [], # the allowed documents are set in the model_section function - self.api_wrapper.auth_token, - upload_endpoints["informations_endpoint"], - ) + # FIXME: accept + self.document_uploader = pn.widgets.FileInput(multiple=True) # Most widgets (including those that use from_param) should be placed after the super init call self.cancel_button = pn.widgets.Button( @@ -115,12 +112,38 @@ def __init__(self, api_wrapper, **params): self.got_timezone = False - def did_click_on_start_chat_button(self, event): - if not self.document_uploader.can_proceed_to_upload(): + async def did_click_on_start_chat_button(self, event): + if not self.document_uploader.value: self.change_upload_files_label("missing_file") else: self.start_chat_button.disabled = True - self.document_uploader.perform_upload(event, self.did_finish_upload) + documents = self.api_wrapper._engine.register_documents( + user=self.api_wrapper._user, + document_registrations=[ + schemas.DocumentRegistration(name=name) + for name in self.document_uploader.filename + ], + ) + + if self.api_wrapper._engine.supports_store_documents: + + def make_content_stream(data: bytes) -> AsyncIterator[bytes]: + async def content_stream() -> AsyncIterator[bytes]: + yield data + + return content_stream() + + await self.api_wrapper._engine.store_documents( + user=self.api_wrapper._user, + ids_and_streams=[ + (document.id, make_content_stream(data)) + for document, data in zip( + documents, self.document_uploader.value + ) + ], + ) + + await self.did_finish_upload(documents) async def did_finish_upload(self, uploaded_documents): # at this point, the UI has uploaded the files to the API. diff --git a/ragna/deploy/_ui/resources/upload.js b/ragna/deploy/_ui/resources/upload.js index 1ecd54b2..905da833 100644 --- a/ragna/deploy/_ui/resources/upload.js +++ b/ragna/deploy/_ui/resources/upload.js @@ -1,8 +1,8 @@ -function upload(files, token, informationEndpoint, final_callback) { - uploadBatches(files, token, informationEndpoint).then(final_callback); +function upload(files, informationEndpoint, final_callback) { + uploadBatches(files, informationEndpoint).then(final_callback); } -async function uploadBatches(files, token, informationEndpoint) { +async function uploadBatches(files, informationEndpoint) { const batchSize = 500; const queue = Array.from(files); @@ -10,20 +10,20 @@ async function uploadBatches(files, token, informationEndpoint) { while (queue.length) { const batch = queue.splice(0, batchSize); - await Promise.all( - batch.map((file) => uploadFile(file, token, informationEndpoint)), - ).then((results) => { - uploaded.push(...results); - }); + await Promise.all(batch.map((file) => uploadFile(file, informationEndpoint))).then( + (results) => { + uploaded.push(...results); + }, + ); } return uploaded; } -async function uploadFile(file, token, informationEndpoint) { +async function uploadFile(file, informationEndpoint) { const response = await fetch(informationEndpoint, { method: "POST", - headers: { "Content-Type": "application/json", Authorization: `Bearer ${token}` }, + headers: { "Content-Type": "application/json" }, body: JSON.stringify({ name: file.name }), }); const documentUpload = await response.json(); diff --git a/ragna/deploy/_ui/styles.py b/ragna/deploy/_ui/styles.py index 213e6e1a..793b4ca6 100644 --- a/ragna/deploy/_ui/styles.py +++ b/ragna/deploy/_ui/styles.py @@ -34,7 +34,6 @@ pn.pane.Markdown, ], "chat_info": [pn.pane.Markdown, pn.widgets.Button], - "auth": [pn.widgets.TextInput, pn.pane.HTML, pn.widgets.Button, pn.Column], "central_view": [pn.Column, pn.Row, pn.pane.HTML], "chat_interface": [ pn.widgets.TextInput, diff --git a/ragna/deploy/_utils.py b/ragna/deploy/_utils.py new file mode 100644 index 00000000..4f369a52 --- /dev/null +++ b/ragna/deploy/_utils.py @@ -0,0 +1,57 @@ +from typing import Optional +from urllib.parse import SplitResult, urlsplit, urlunsplit + +from fastapi import status +from fastapi.responses import RedirectResponse + +from ragna.core import RagnaException + +_REDIRECT_ROOT_PATH: Optional[str] = None + + +def set_redirect_root_path(root_path: str) -> None: + global _REDIRECT_ROOT_PATH + _REDIRECT_ROOT_PATH = root_path + + +def redirect( + url: str, *, status_code: int = status.HTTP_303_SEE_OTHER +) -> RedirectResponse: + if _REDIRECT_ROOT_PATH is None: + raise RagnaException + + if url.startswith("/"): + url = _REDIRECT_ROOT_PATH + url + + return RedirectResponse(url, status_code=status_code) + + +def handle_localhost_origins(origins: list[str]) -> list[str]: + # Since localhost is an alias for 127.0.0.1, we allow both so users and developers + # don't need to worry about it. + localhost_origins = { + components.hostname: components + for url in origins + if (components := urlsplit(url)).hostname in {"127.0.0.1", "localhost"} + } + if "127.0.0.1" in localhost_origins and "localhost" not in localhost_origins: + origins.append( + urlunsplit(_replace_hostname(localhost_origins["127.0.0.1"], "localhost")) + ) + elif "localhost" in localhost_origins and "127.0.0.1" not in localhost_origins: + origins.append( + urlunsplit(_replace_hostname(localhost_origins["localhost"], "127.0.0.1")) + ) + + return origins + + +def _replace_hostname(split_result: SplitResult, hostname: str) -> SplitResult: + # This is a separate function, since hostname is not an element of the SplitResult + # namedtuple, but only a property. Thus, we need to replace the netloc item, from + # which the hostname is generated. + if split_result.port is None: + netloc = hostname + else: + netloc = f"{hostname}:{split_result.port}" + return split_result._replace(netloc=netloc) diff --git a/scripts/add_chats.py b/scripts/add_chats.py index b8c15194..5f550289 100644 --- a/scripts/add_chats.py +++ b/scripts/add_chats.py @@ -1,71 +1,70 @@ import datetime import json -import os import httpx -from ragna.core._utils import default_user - def main(): client = httpx.Client(base_url="http://127.0.0.1:31476") - client.get("/").raise_for_status() + client.get("/health").raise_for_status() + + # ## authentication + # + # username = default_user() + # token = ( + # client.post( + # "/token", + # data={ + # "username": username, + # "password": os.environ.get( + # "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username + # ), + # }, + # ) + # .raise_for_status() + # .json() + # ) + # client.headers["Authorization"] = f"Bearer {token}" + + print() - ## authentication + ## documents - username = default_user() - token = ( + documents = ( client.post( - "/token", - data={ - "username": username, - "password": os.environ.get( - "RAGNA_DEMO_AUTHENTICATION_PASSWORD", username - ), - }, + "/api/documents", json=[{"name": f"document{i}.txt"} for i in range(5)] ) .raise_for_status() .json() ) - client.headers["Authorization"] = f"Bearer {token}" - ## documents - - documents = [] - for i in range(5): - name = f"document{i}.txt" - document_upload = ( - client.post("/document", json={"name": name}).raise_for_status().json() - ) - parameters = document_upload["parameters"] - client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": f"Content of {name}".encode()}, - ).raise_for_status() - documents.append(document_upload["document"]) + client.put( + "/api/documents", + files=[ + ("documents", (document["id"], f"Content of {document['name']}".encode())) + for document in documents + ], + ).raise_for_status() ## chat 1 chat = ( client.post( - "/chats", + "/api/chats", json={ "name": "Test chat", - "documents": documents[:2], + "document_ids": [document["id"] for document in documents[:2]], "source_storage": "Ragna/DemoSourceStorage", "assistant": "Ragna/DemoAssistant", - "params": {}, }, ) .raise_for_status() .json() ) - client.post(f"/chats/{chat['id']}/prepare").raise_for_status() + client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "Hello!"}, ).raise_for_status() @@ -73,55 +72,53 @@ def main(): chat = ( client.post( - "/chats", + "/api/chats", json={ "name": f"Chat {datetime.datetime.now():%x %X}", - "documents": documents[2:4], + "document_ids": [document["id"] for document in documents[2:]], "source_storage": "Ragna/DemoSourceStorage", "assistant": "Ragna/DemoAssistant", - "params": {}, }, ) .raise_for_status() .json() ) - client.post(f"/chats/{chat['id']}/prepare").raise_for_status() + client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() for _ in range(3): client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "What is Ragna? Please, I need to know!"}, ).raise_for_status() - ## chat 3 + # ## chat 3 chat = ( client.post( - "/chats", + "/api/chats", json={ "name": ( "Really long chat name that likely needs to be truncated somehow. " "If you can read this, truncating failed :boom:" ), - "documents": [documents[i] for i in [0, 2, 4]], + "document_ids": [documents[i]["id"] for i in [0, 2, 4]], "source_storage": "Ragna/DemoSourceStorage", "assistant": "Ragna/DemoAssistant", - "params": {}, }, ) .raise_for_status() .json() ) - client.post(f"/chats/{chat['id']}/prepare").raise_for_status() + client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status() client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "Hello!"}, ).raise_for_status() client.post( - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": "Ok, in that case show me some pretty markdown!"}, ).raise_for_status() - chats = client.get("/chats").raise_for_status().json() + chats = client.get("/api/chats").raise_for_status().json() print(json.dumps(chats)) diff --git a/tests/deploy/api/conftest.py b/tests/deploy/api/conftest.py new file mode 100644 index 00000000..4bc8053c --- /dev/null +++ b/tests/deploy/api/conftest.py @@ -0,0 +1,41 @@ +import contextlib +import json + +import httpx +import pytest + + +@pytest.fixture(scope="package", autouse=True) +def enhance_raise_for_status(package_mocker): + raise_for_status = httpx.Response.raise_for_status + + def enhanced_raise_for_status(self): + __tracebackhide__ = True + + try: + return raise_for_status(self) + except httpx.HTTPStatusError as error: + content = None + with contextlib.suppress(Exception): + content = error.response.read() + content = content.decode() + content = "\n" + json.dumps(json.loads(content), indent=2) + + if content is None: + raise error + + message = f"{error}\nResponse content: {content}" + raise httpx.HTTPStatusError( + message, request=error.request, response=error.response + ) from None + + yield package_mocker.patch( + ".".join( + [ + httpx.Response.__module__, + httpx.Response.__name__, + raise_for_status.__name__, + ] + ), + new=enhanced_raise_for_status, + ) diff --git a/tests/deploy/api/test_batch_endpoints.py b/tests/deploy/api/test_batch_endpoints.py deleted file mode 100644 index 3c85c77c..00000000 --- a/tests/deploy/api/test_batch_endpoints.py +++ /dev/null @@ -1,85 +0,0 @@ -from fastapi import status -from fastapi.testclient import TestClient - -from ragna.deploy import Config -from ragna.deploy._api import app -from tests.deploy.utils import authenticate_with_api - - -def test_batch_sequential_upload_equivalence(tmp_local_root): - "Check that uploading documents sequentially and in batch gives the same result" - config = Config(local_root=tmp_local_root) - - document_root = config.local_root / "documents" - document_root.mkdir() - document_path1 = document_root / "test1.txt" - with open(document_path1, "w") as file: - file.write("!\n") - document_path2 = document_root / "test2.txt" - with open(document_path2, "w") as file: - file.write("?\n") - - with TestClient( - app(config=Config(), ignore_unavailable_components=False) - ) as client: - authenticate_with_api(client) - - document1_upload = ( - client.post("/document", json={"name": document_path1.name}) - .raise_for_status() - .json() - ) - document2_upload = ( - client.post("/document", json={"name": document_path2.name}) - .raise_for_status() - .json() - ) - - documents_upload = ( - client.post( - "/documents", json={"names": [document_path1.name, document_path2.name]} - ) - .raise_for_status() - .json() - ) - - assert ( - document1_upload["parameters"]["url"] - == documents_upload[0]["parameters"]["url"] - ) - assert ( - document2_upload["parameters"]["url"] - == documents_upload[1]["parameters"]["url"] - ) - - assert ( - document1_upload["document"]["name"] - == documents_upload[0]["document"]["name"] - ) - assert ( - document2_upload["document"]["name"] - == documents_upload[1]["document"]["name"] - ) - - # assuming that if test passes for first document it will also pass for the other - with open(document_path1, "rb") as file: - response_sequential_upload1 = client.request( - document1_upload["parameters"]["method"], - document1_upload["parameters"]["url"], - data=document1_upload["parameters"]["data"], - files={"file": file}, - ) - response_batch_upload1 = client.request( - documents_upload[0]["parameters"]["method"], - documents_upload[0]["parameters"]["url"], - data=documents_upload[0]["parameters"]["data"], - files={"file": file}, - ) - - assert response_sequential_upload1.status_code == status.HTTP_200_OK - assert response_batch_upload1.status_code == status.HTTP_200_OK - - assert ( - response_sequential_upload1.json()["name"] - == response_batch_upload1.json()["name"] - ) diff --git a/tests/deploy/api/test_components.py b/tests/deploy/api/test_components.py index b7fe464c..8e23cdeb 100644 --- a/tests/deploy/api/test_components.py +++ b/tests/deploy/api/test_components.py @@ -5,8 +5,7 @@ from ragna import assistants from ragna.core import RagnaException from ragna.deploy import Config -from ragna.deploy._api import app -from tests.deploy.utils import authenticate_with_api +from tests.deploy.utils import authenticate_with_api, make_api_app @pytest.mark.parametrize("ignore_unavailable_components", [True, False]) @@ -21,20 +20,20 @@ def test_ignore_unavailable_components(ignore_unavailable_components): if ignore_unavailable_components: with TestClient( - app( + make_api_app( config=config, ignore_unavailable_components=ignore_unavailable_components, ) ) as client: authenticate_with_api(client) - components = client.get("/components").raise_for_status().json() + components = client.get("/api/components").raise_for_status().json() assert [assistant["title"] for assistant in components["assistants"]] == [ available_assistant.display_name() ] else: with pytest.raises(RagnaException, match="not available"): - app( + make_api_app( config=config, ignore_unavailable_components=ignore_unavailable_components, ) @@ -47,7 +46,7 @@ def test_ignore_unavailable_components_at_least_one(): config = Config(assistants=[unavailable_assistant]) with pytest.raises(RagnaException, match="No component available"): - app( + make_api_app( config=config, ignore_unavailable_components=True, ) @@ -63,35 +62,26 @@ def test_unknown_component(tmp_local_root): file.write("!\n") with TestClient( - app(config=Config(), ignore_unavailable_components=False) + make_api_app(config=Config(), ignore_unavailable_components=False) ) as client: authenticate_with_api(client) - document_upload = ( - client.post("/document", json={"name": document_path.name}) + document = ( + client.post("/api/documents", json=[{"name": document_path.name}]) .raise_for_status() - .json() + .json()[0] ) - document = document_upload["document"] - assert document["name"] == document_path.name - parameters = document_upload["parameters"] with open(document_path, "rb") as file: - client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": file}, - ) + client.put("/api/documents", files={"documents": (document["id"], file)}) response = client.post( - "/chats", + "/api/chats", json={ "name": "test-chat", + "document_ids": [document["id"]], "source_storage": "unknown_source_storage", "assistant": "unknown_assistant", - "params": {}, - "documents": [document], }, ) diff --git a/tests/deploy/api/test_e2e.py b/tests/deploy/api/test_e2e.py index 4abbf7cf..e023c0ee 100644 --- a/tests/deploy/api/test_e2e.py +++ b/tests/deploy/api/test_e2e.py @@ -4,8 +4,7 @@ from fastapi.testclient import TestClient from ragna.deploy import Config -from ragna.deploy._api import app -from tests.deploy.utils import TestAssistant, authenticate_with_api +from tests.deploy.utils import TestAssistant, authenticate_with_api, make_api_app from tests.utils import skip_on_windows @@ -21,31 +20,28 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): with open(document_path, "w") as file: file.write("!\n") - with TestClient(app(config=config, ignore_unavailable_components=False)) as client: + with TestClient( + make_api_app(config=config, ignore_unavailable_components=False) + ) as client: authenticate_with_api(client) - assert client.get("/chats").raise_for_status().json() == [] + assert client.get("/api/chats").raise_for_status().json() == [] - document_upload = ( - client.post("/document", json={"name": document_path.name}) + documents = ( + client.post("/api/documents", json=[{"name": document_path.name}]) .raise_for_status() .json() ) - document = document_upload["document"] + assert len(documents) == 1 + document = documents[0] assert document["name"] == document_path.name - parameters = document_upload["parameters"] with open(document_path, "rb") as file: - client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": file}, - ) + client.put("/api/documents", files={"documents": (document["id"], file)}) - components = client.get("/components").raise_for_status().json() - documents = components["documents"] - assert set(documents) == config.document.supported_suffixes() + components = client.get("/api/components").raise_for_status().json() + supported_documents = components["documents"] + assert set(supported_documents) == config.document.supported_suffixes() source_storages = [ json_schema["title"] for json_schema in components["source_storages"] ] @@ -60,26 +56,32 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): source_storage = source_storages[0] assistant = assistants[0] - chat_metadata = { + chat_creation = { "name": "test-chat", + "document_ids": [document["id"]], "source_storage": source_storage, "assistant": assistant, "params": {"multiple_answer_chunks": multiple_answer_chunks}, - "documents": [document], } - chat = client.post("/chats", json=chat_metadata).raise_for_status().json() - assert chat["metadata"] == chat_metadata + chat = client.post("/api/chats", json=chat_creation).raise_for_status().json() + for field in ["name", "source_storage", "assistant", "params"]: + assert chat[field] == chat_creation[field] + assert [document["id"] for document in chat["documents"]] == chat_creation[ + "document_ids" + ] assert not chat["prepared"] assert chat["messages"] == [] - assert client.get("/chats").raise_for_status().json() == [chat] - assert client.get(f"/chats/{chat['id']}").raise_for_status().json() == chat + assert client.get("/api/chats").raise_for_status().json() == [chat] + assert client.get(f"/api/chats/{chat['id']}").raise_for_status().json() == chat - message = client.post(f"/chats/{chat['id']}/prepare").raise_for_status().json() + message = ( + client.post(f"/api/chats/{chat['id']}/prepare").raise_for_status().json() + ) assert message["role"] == "system" assert message["sources"] == [] - chat = client.get(f"/chats/{chat['id']}").raise_for_status().json() + chat = client.get(f"/api/chats/{chat['id']}").raise_for_status().json() assert chat["prepared"] assert len(chat["messages"]) == 1 assert chat["messages"][-1] == message @@ -88,7 +90,7 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): if stream_answer: with client.stream( "POST", - f"/chats/{chat['id']}/answer", + f"/api/chats/{chat['id']}/answer", json={"prompt": prompt, "stream": True}, ) as response: chunks = [json.loads(chunk) for chunk in response.iter_lines()] @@ -97,7 +99,7 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): message["content"] = "".join(chunk["content"] for chunk in chunks) else: message = ( - client.post(f"/chats/{chat['id']}/answer", json={"prompt": prompt}) + client.post(f"/api/chats/{chat['id']}/answer", json={"prompt": prompt}) .raise_for_status() .json() ) @@ -107,7 +109,7 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): document_path.name } - chat = client.get(f"/chats/{chat['id']}").raise_for_status().json() + chat = client.get(f"/api/chats/{chat['id']}").raise_for_status().json() assert len(chat["messages"]) == 3 assert chat["messages"][-1] == message assert ( @@ -116,5 +118,5 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer): and chat["messages"][-2]["content"] == prompt ) - client.delete(f"/chats/{chat['id']}").raise_for_status() - assert client.get("/chats").raise_for_status().json() == [] + client.delete(f"/api/chats/{chat['id']}").raise_for_status() + assert client.get("/api/chats").raise_for_status().json() == [] diff --git a/tests/deploy/test_config.py b/tests/deploy/test_config.py index a8250acd..be403cd7 100644 --- a/tests/deploy/test_config.py +++ b/tests/deploy/test_config.py @@ -17,24 +17,6 @@ def test_env_var_prefix(mocker, tmp_path): assert config.local_root == env_var -def test_env_var_api_prefix(mocker): - env_var = "hostname" - mocker.patch.dict(os.environ, values={"RAGNA_API_HOSTNAME": env_var}) - - config = Config() - - assert config.api.hostname == env_var - - -def test_env_var_ui_prefix(mocker): - env_var = "hostname" - mocker.patch.dict(os.environ, values={"RAGNA_UI_HOSTNAME": env_var}) - - config = Config() - - assert config.ui.hostname == env_var - - @pytest.mark.xfail() def test_explicit_gt_env_var(mocker, tmp_path): explicit = tmp_path / "explicit" @@ -65,15 +47,14 @@ def test_env_var_gt_config_file(mocker, tmp_path): def test_api_database_url_default_path(tmp_path): config = Config(local_root=tmp_path) - assert Path(urlsplit(config.api.database_url).path[1:]).parent == tmp_path + assert Path(urlsplit(config.database_url).path[1:]).parent == tmp_path -@pytest.mark.parametrize("config_subsection", ["api", "ui"]) -def test_origins_default(config_subsection): +def test_origins_default(): hostname, port = "0.0.0.0", "80" - config = Config(ui=dict(hostname=hostname, port=port)) + config = Config(hostname=hostname, port=port) - assert getattr(config, config_subsection).origins == [f"http://{hostname}:{port}"] + assert config.origins == [f"http://{hostname}:{port}"] def test_from_file_path_not_exists(tmp_path): diff --git a/tests/deploy/ui/test_ui.py b/tests/deploy/ui/test_ui.py index 6c012123..e8fd4842 100644 --- a/tests/deploy/ui/test_ui.py +++ b/tests/deploy/ui/test_ui.py @@ -1,151 +1,129 @@ -import subprocess -import sys +import contextlib +import multiprocessing import time import httpx -import panel as pn import pytest -from playwright.sync_api import Page, expect +from playwright.sync_api import expect -from ragna._utils import timeout_after +from ragna._cli.core import deploy as _deploy from ragna.deploy import Config from tests.deploy.utils import TestAssistant from tests.utils import get_available_port +@contextlib.contextmanager +def deploy(config): + process = multiprocessing.Process( + target=_deploy, + kwargs=dict( + config=config, + api=False, + ui=True, + ignore_unavailable_components=False, + open_browser=False, + ), + ) + try: + process.start() + + client = httpx.Client(base_url=config._url) + + # FIXME: create a generic utility for this + def server_available() -> bool: + try: + return client.get("/health").is_success + except httpx.ConnectError: + return False + + while not server_available(): + time.sleep(0.1) + + yield process + finally: + process.terminate() + process.join() + process.close() + + @pytest.fixture -def config( - tmp_local_root, -): - config = Config( +def default_config(tmp_local_root): + return Config( local_root=tmp_local_root, assistants=[TestAssistant], - ui=dict(port=get_available_port()), - api=dict(port=get_available_port()), + port=get_available_port(), ) - path = tmp_local_root / "ragna.toml" - config.to_file(path) - return config - - -class Server: - def __init__(self, config): - self.config = config - self.base_url = f"http://{config.ui.hostname}:{config.ui.port}" - - def server_up(self): - try: - return httpx.get(self.base_url).is_success - except httpx.ConnectError: - return False - - @timeout_after(60) - def start(self): - self.proc = subprocess.Popen( - [ - sys.executable, - "-m", - "ragna", - "ui", - "--config", - self.config.local_root / "ragna.toml", - "--start-api", - "--ignore-unavailable-components", - "--no-open-browser", - ], - stdout=sys.stdout, - stderr=sys.stderr, - ) - - while not self.server_up(): - time.sleep(1) - - def stop(self): - self.proc.kill() - pn.state.kill_all_servers() - - def __enter__(self): - self.start() - return self - - def __exit__(self, *args): - self.stop() - - -def test_health(config, page: Page) -> None: - with Server(config) as server: - health_url = f"{server.base_url}/health" - response = page.goto(health_url) - assert response.ok - - -def test_start_chat(config, page: Page) -> None: - with Server(config) as server: - # Index page, no auth - index_url = server.base_url - page.goto(index_url) - expect(page.get_by_role("button", name="Sign In")).to_be_visible() - - # Authorize with no credentials - page.get_by_role("button", name="Sign In").click() - expect(page.get_by_role("button", name=" New Chat")).to_be_visible() - - # expect auth token to be set - cookies = page.context.cookies() - assert len(cookies) == 1 - cookie = cookies[0] - assert cookie.get("name") == "auth_token" - auth_token = cookie.get("value") - assert auth_token is not None - - # New page button - new_chat_button = page.get_by_role("button", name=" New Chat") - expect(new_chat_button).to_be_visible() - new_chat_button.click() - - document_root = config.local_root / "documents" - document_root.mkdir() - document_name = "test.txt" - document_path = document_root / document_name - with open(document_path, "w") as file: - file.write("!\n") - - # File upload selector - with page.expect_file_chooser() as fc_info: - page.locator(".fileUpload").click() - file_chooser = fc_info.value - file_chooser.set_files(document_path) - - # Upload document and expect to see it listed - file_list = page.locator(".fileListContainer") - expect(file_list.first).to_have_text(str(document_name)) - - chat_dialog = page.get_by_role("dialog") - expect(chat_dialog).to_be_visible() - start_chat_button = page.get_by_role("button", name="Start Conversation") - expect(start_chat_button).to_be_visible() - time.sleep(0.5) # hack while waiting for button to be fully clickable - start_chat_button.click(delay=5) - - chat_box_row = page.locator(".chat-interface-input-row") - expect(chat_box_row).to_be_visible() - - chat_box = chat_box_row.get_by_role("textbox") - expect(chat_box).to_be_visible() - - # Document should be in the database - chats_url = f"http://{config.api.hostname}:{config.api.port}/chats" - chats = httpx.get( - chats_url, headers={"Authorization": f"Bearer {auth_token}"} - ).json() - assert len(chats) == 1 - chat = chats[0] - chat_documents = chat["metadata"]["documents"] - assert len(chat_documents) == 1 - assert chat_documents[0]["name"] == document_name - - chat_box.fill("Tell me about the documents") - - chat_button = chat_box_row.get_by_role("button") - expect(chat_button).to_be_visible() - chat_button.click() + + +@pytest.fixture +def index_page(default_config, page): + config = default_config + with deploy(default_config): + page.goto(f"http://{config.hostname}:{config.port}/ui") + yield page + + +def test_start_chat(index_page, tmp_path) -> None: + # expect(page.get_by_role("button", name="Sign In")).to_be_visible() + + # # Authorize with no credentials + # page.get_by_role("button", name="Sign In").click() + # expect(page.get_by_role("button", name=" New Chat")).to_be_visible() + # + # # expect auth token to be set + # cookies = page.context.cookies() + # assert len(cookies) == 1 + # cookie = cookies[0] + # assert cookie.get("name") == "auth_token" + # auth_token = cookie.get("value") + # assert auth_token is not None + + # New page button + new_chat_button = index_page.get_by_role("button", name=" New Chat") + expect(new_chat_button).to_be_visible() + new_chat_button.click() + + # document_name = "test.txt" + # document_path = tmp_path / document_name + # with open(document_path, "w") as file: + # file.write("!\n") + + # # File upload selector + # with index_page.expect_file_chooser() as fc_info: + # index_page.locator(".fileUpload").click() + # file_chooser = fc_info.value + # file_chooser.set_files(document_path) + + # # Upload document and expect to see it listed + # file_list = page.locator(".fileListContainer") + # expect(file_list.first).to_have_text(str(document_name)) + # + # chat_dialog = page.get_by_role("dialog") + # expect(chat_dialog).to_be_visible() + # start_chat_button = page.get_by_role("button", name="Start Conversation") + # expect(start_chat_button).to_be_visible() + # time.sleep(0.5) # hack while waiting for button to be fully clickable + # start_chat_button.click(delay=5) + # + # chat_box_row = page.locator(".chat-interface-input-row") + # expect(chat_box_row).to_be_visible() + # + # chat_box = chat_box_row.get_by_role("textbox") + # expect(chat_box).to_be_visible() + # + # # Document should be in the database + # chats_url = f"http://{config.api.hostname}:{config.api.port}/chats" + # chats = httpx.get( + # chats_url, headers={"Authorization": f"Bearer {auth_token}"} + # ).json() + # assert len(chats) == 1 + # chat = chats[0] + # chat_documents = chat["metadata"]["documents"] + # assert len(chat_documents) == 1 + # assert chat_documents[0]["name"] == document_name + # + # chat_box.fill("Tell me about the documents") + # + # chat_button = chat_box_row.get_by_role("button") + # expect(chat_button).to_be_visible() + # chat_button.click() diff --git a/tests/deploy/utils.py b/tests/deploy/utils.py index 48b8c2ae..f8d1277a 100644 --- a/tests/deploy/utils.py +++ b/tests/deploy/utils.py @@ -5,6 +5,7 @@ from ragna.assistants import RagnaDemoAssistant from ragna.core._utils import default_user +from ragna.deploy._core import make_app class TestAssistant(RagnaDemoAssistant): @@ -26,7 +27,18 @@ def answer(self, messages, *, multiple_answer_chunks: bool = True): yield content +def make_api_app(*, config, ignore_unavailable_components): + return make_app( + config, + api=True, + ui=False, + ignore_unavailable_components=ignore_unavailable_components, + open_browser=False, + ) + + def authenticate_with_api(client: TestClient) -> None: + return username = default_user() token = ( client.post( diff --git a/tests/test_js_utils.py b/tests/test_js_utils.py deleted file mode 100644 index 202af9e7..00000000 --- a/tests/test_js_utils.py +++ /dev/null @@ -1,47 +0,0 @@ -from textwrap import dedent - -from ragna.deploy._ui.js_utils import preformat, redirect_script - - -def test_preformat_basic(): - output = preformat("{ This is awesome {{var}} }") - assert output == "{{ This is awesome {var} }}" - - -def test_preformat_basic_fmt(): - output = preformat("{ This is awesome {{var}} }").format(var="test") - assert output == "{ This is awesome test }" - - -def test_preformat_multivars(): - output = preformat("{ {{var1}} This is awesome {{var2}} }").format( - var1="test1", var2="test2" - ) - assert output == "{ test1 This is awesome test2 }" - - -def test_preformat_unsubs(): - output = preformat("{ This is {Hello} awesome {{var}} }").format(var="test") - assert output == "{ This is {Hello} awesome test }" - - -def test_redirect_script(): - output = redirect_script(remove="foo", append="bar") - expected = dedent( - r""" - - """ - ) - assert dedent(output.object) == expected