From c52734480b9eb1cc3b61728a0d802828e0877cb1 Mon Sep 17 00:00:00 2001 From: William Black <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 16 May 2024 03:39:29 -0700 Subject: [PATCH 1/5] [DOC] Add tutorial for adding your own objects (#368) Co-authored-by: Philip Meier Co-authored-by: Pavithra Eswaramoorthy --- .../images/ragna-tutorial-components.png | 3 + docs/examples/gallery_streaming.py | 23 +- docs/references/config.md | 17 +- docs/tutorials/gallery_custom_components.py | 413 ++++++++++++++++++ docs/tutorials/gallery_rest_api.py | 2 +- ragna/_docs.py | 69 ++- 6 files changed, 488 insertions(+), 39 deletions(-) create mode 100644 docs/assets/images/ragna-tutorial-components.png create mode 100644 docs/tutorials/gallery_custom_components.py diff --git a/docs/assets/images/ragna-tutorial-components.png b/docs/assets/images/ragna-tutorial-components.png new file mode 100644 index 00000000..6ac80dcc --- /dev/null +++ b/docs/assets/images/ragna-tutorial-components.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66c1f4c3a148be4273390406804dcb2df4cef609c43f2127ac76c60bc6067712 +size 45798 diff --git a/docs/examples/gallery_streaming.py b/docs/examples/gallery_streaming.py index ca46d884..6ffda30a 100644 --- a/docs/examples/gallery_streaming.py +++ b/docs/examples/gallery_streaming.py @@ -99,26 +99,7 @@ def answer(self, prompt, sources): rest_api = ragna_docs.RestApi() -client = rest_api.start(config, authenticate=True) - -# %% -# Upload the document. - -document_upload = ( - client.post("/document", json={"name": document_path.name}) - .raise_for_status() - .json() -) - -document = document_upload["document"] - -parameters = document_upload["parameters"] -client.request( - parameters["method"], - parameters["url"], - data=parameters["data"], - files={"file": open(document_path, "rb")}, -).raise_for_status() +client, document = rest_api.start(config, authenticate=True, upload_document=True) # %% # Start and prepare the chat @@ -173,6 +154,6 @@ def answer(self, prompt, sources): # %% # Before we close the example, let's stop the REST API and have a look at what would -# have printed in the terminal if we had started it the regular way. +# have printed in the terminal if we had started it with the `ragna api` command. rest_api.stop() diff --git a/docs/references/config.md b/docs/references/config.md index 5223e71f..c6ed18b9 100644 --- a/docs/references/config.md +++ b/docs/references/config.md @@ -33,16 +33,21 @@ There are two main ways to generate a configuration file: ## Referencing Python objects Some configuration options reference Python objects, e.g. -`document = ragna.core.LocalDocument`. You can inject your own objects here and do not -need to rely on the defaults by Ragna. To do so, make sure that the module the object is -defined in is on the -[`PYTHONPATH`](https://docs.python.org/3/using/cmdline.html#envvar-PYTHONPATH). The -`document` configuration mentioned before internally is roughly treated as +`document = ragna.core.LocalDocument`. Internally, this is roughly treated as `from ragna.core import LocalDocument`. +You can inject your own objects here and do not need to rely on the defaults by Ragna. +To do so, make sure that the module the object is defined in is on +[Python's search path](https://docs.python.org/3/library/sys.html#sys.path). There are +multiple ways to achieve this, e.g.: + +- Install your module as part of a package in your current environment. 
+- Set the [`PYTHONPATH`](https://docs.python.org/3/using/cmdline.html#envvar-PYTHONPATH) + environment variable to include the directory your module is located in. + ## Environment variables -All configuration options can be set or overritten by environment variables by using the +All configuration options can be set or overridden by environment variables by using the `RAGNA_` prefix. For example, `document = ragna.core.LocalDocument` in the configuration file is equivalent to setting `RAGNA_DOCUMENT=ragna.core.LocalDocument`. diff --git a/docs/tutorials/gallery_custom_components.py b/docs/tutorials/gallery_custom_components.py new file mode 100644 index 00000000..49b5fa26 --- /dev/null +++ b/docs/tutorials/gallery_custom_components.py @@ -0,0 +1,413 @@ +""" +# Custom Components + +While Ragna has builtin support for a few [source storages][ragna.source_storages] +and [assistants][ragna.assistants], its real strength lies in allowing users +to incorporate custom components. This tutorial covers how to do that. +""" + +# %% +# ## Components +# +# ### Source Storage +# +# [ragna.core.SourceStorage][]s are objects that take a number of documents and +# [ragna.core.SourceStorage.store][] their content in way such that relevant parts for a +# given user prompt can be [ragna.core.SourceStorage.retrieve][]d in the form of +# [ragna.core.Source][]s. Usually, source storages are vector databases. +# +# In this tutorial, we define a minimal `TutorialSourceStorage` that is similar to +# [ragna.source_storages.RagnaDemoSourceStorage][]. In `.store()` we create the +# `Source`s with the first 100 characters of each document and store them in memory +# based on the unique `chat_id`. In `retrieve()` we return all the stored sources +# for the chat and regardless of the user `prompt`. +# +# !!! note +# +# The `chat_id` used in both methods will not be passed by default, but rather has +# to be requested explicitly. How this works in detail will be explained later in +# this tutorial in the [Custom Parameters](#custom-parameters) section. + +import uuid + +from ragna.core import Document, Source, SourceStorage + + +class TutorialSourceStorage(SourceStorage): + def __init__(self): + self._storage: dict[uuid.UUID, list[Source]] = {} + + def store(self, documents: list[Document], *, chat_id: uuid.UUID) -> None: + print(f"Running {type(self).__name__}().store()") + + self._storage[chat_id] = [ + Source( + id=str(uuid.uuid4()), + document=document, + location="N/A", + content=(content := next(document.extract_pages()).text[:100]), + num_tokens=len(content.split()), + ) + for document in documents + ] + + def retrieve( + self, documents: list[Document], prompt: str, *, chat_id: uuid.UUID + ) -> list[Source]: + print(f"Running {type(self).__name__}().retrieve()") + return self._storage[chat_id] + + +# %% +# ### Assistant +# +# [ragna.core.Assistant][]s are objects that take a user prompt and relevant +# [ragna.core.Source][]s and generate a response form that. Usually, assistants are +# LLMs. +# +# In this tutorial, we define a minimal `TutorialAssistant` that is similar to +# [ragna.assistants.RagnaDemoAssistant][]. In `.answer()` we mirror back the user +# prompt and also the number of sources we were given. +# +# !!! note +# +# The answer needs to be `yield`ed instead of `return`ed. By yielding multiple +# times, the answer can be streamed back. See the +# [streaming example](../../generated/examples/gallery_streaming.md) for more +# information. 
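#
# For illustration only, a streaming variant could yield multiple chunks instead of a
# single string. The following sketch is hypothetical and not used in the rest of this
# tutorial:
#
# ```python
# class StreamingTutorialAssistant(Assistant):
#     def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
#         # Each yielded chunk is sent to the user as soon as it is produced.
#         for word in f"You asked: '{prompt}'".split():
#             yield word + " "
# ```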
+ +from typing import Iterator + +from ragna.core import Assistant, Source + + +class TutorialAssistant(Assistant): + def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]: + print(f"Running {type(self).__name__}().answer()") + yield ( + f"To answer the user prompt '{prompt}', " + f"I was given {len(sources)} source(s)." + ) + + +# %% +# ## Usage +# +# Now that we have defined a custom [ragna.core.SourceStorage][] and +# [ragna.core.Assistant][], let's have a look on how to use them with Ragna. Let's start +# with the Python API. +# +# ### Python API +# +# We first create a sample document. + +from pathlib import Path + +import ragna._docs as ragna_docs + +print(ragna_docs.SAMPLE_CONTENT) + +document_path = Path.cwd() / "ragna.txt" + +with open(document_path, "w") as file: + file.write(ragna_docs.SAMPLE_CONTENT) + +# %% +# Next, we create a new [ragna.core.Chat][] with our custom components. + +from ragna import Rag + +chat = Rag().chat( + documents=[document_path], + source_storage=TutorialSourceStorage, + assistant=TutorialAssistant, +) + +# %% +# From here on, you can use your custom components exactly like the builtin ones. For +# more details, have a look at the +# [Python API tutorial](../../generated/tutorials/gallery_python_api.md). + +_ = await chat.prepare() + +# %% + +answer = await chat.answer("What is Ragna?") + +# %% + +print(answer) + +# %% +for idx, source in enumerate(answer.sources, 1): + print(f"{idx}. {source.content}") + +# %% +# ### REST API +# +# To use custom components in Ragna's REST API, we need to include the components in +# the corresponding arrays in the configuration file. If you don't have a `ragna.toml` +# configuration file yet, see the [config reference](../../references/config.md) on how +# to create one. +# +# For example, if we put the `TutorialSourceStorage` and `TutorialAssistant` classes in +# a `tutorial.py` module, we can expand the `source_storages` and `assistants` arrays +# like this +# +# ```toml +# source_storages = [ +# ... +# "tutorial.TutorialSourceStorage" +# ] +# assistants = [ +# ... +# "tutorial.TutorialAssistant" +# ] +# ``` +# +# !!! note +# +# Make sure the `tutorial.py` module is on Python's search path. See the +# [config reference](../../references/config.md#referencing-python-objects) for +# details. +# +# With the configuration set up, we now start the REST API. For the purpose of this +# tutorial we replicate the same programmatically. For general information on how to use +# the REST API, have a look at the +# [REST API tutorial](../../generated/tutorials/gallery_rest_api.md). + +from ragna.deploy import Config + +config = Config( + source_storages=[TutorialSourceStorage], + assistants=[TutorialAssistant], +) + +rest_api = ragna_docs.RestApi() + +client, document = rest_api.start(config, authenticate=True, upload_document=True) + +# %% +# To select our custom components, we pass their display names to the chat creation. +# +# !!! tip +# +# By default [ragna.core.Component.display_name][] returns the name of the class, +# but can be overridden, e.g. to format the name better. 
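#
# For example, a hypothetical component could override it roughly like this (sketch
# only; the components in this tutorial keep the default):
#
# ```python
# class MyAssistant(Assistant):
#     @classmethod
#     def display_name(cls) -> str:
#         # This is the name used to select the component in the REST API and web UI.
#         return "Tutorial/MyAssistant"
# ```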
+ +import json + +response = client.post( + "/chats", + json={ + "name": "Tutorial REST API", + "documents": [document], + "source_storage": TutorialSourceStorage.display_name(), + "assistant": TutorialAssistant.display_name(), + "params": {}, + }, +).raise_for_status() +chat = response.json() + +client.post(f"/chats/{chat['id']}/prepare").raise_for_status() + +response = client.post( + f"/chats/{chat['id']}/answer", + json={"prompt": "What is Ragna?"}, +).raise_for_status() +answer = response.json() +print(json.dumps(answer, indent=2)) + +# %% +# Let's stop the REST API and have a look at what would have printed in the terminal if +# we had started it with the `ragna api` command. + +rest_api.stop() + +# %% +# ### Web UI +# +# The setup for the web UI is exactly the same as for the [REST API](#rest-api). If you +# have your configuration set up, you can start the web UI and use the custom +# components. See the [web UI tutorial](../../generated/tutorials/gallery_rest_api.md) +# for details. +# +# +# ![Create-chat-modal of Ragna's web UI with the TutorialSourceStorage and TutorialAssistant as selected options](../../assets/images/ragna-tutorial-components.png) +# +# !!! warning +# +# See the section about [custom parameters for the web UI](#web-ui_1) for some +# limitations using custom components in the web UI. + +# %% +# ## Custom parameters +# +# Ragna supports passing parameters to the components on a per-chat level. The extra +# parameters are defined by adding them to the signature of the method. For example, +# let's define a more elaborate assistant that accepts the following arguments: +# +# - `my_required_parameter` must be passed and has to be an `int` +# - `my_optional_parameter` may be passed and has to be a `str` if passed + + +class ElaborateTutorialAssistant(Assistant): + def answer( + self, + prompt: str, + sources: list[Source], + *, + my_required_parameter: int, + my_optional_parameter: str = "foo", + ) -> Iterator[str]: + print(f"Running {type(self).__name__}().answer()") + yield ( + f"I was given {my_required_parameter=} and {my_optional_parameter=}." + ) + + +# %% +# ### Python API +# +# To pass custom parameters to the components, pass them as keyword arguments when +# creating a chat. + +chat = Rag().chat( + documents=[document_path], + source_storage=TutorialSourceStorage, + assistant=ElaborateTutorialAssistant, + my_required_parameter=3, + my_optional_parameter="bar", +) + +_ = await chat.prepare() +print(await chat.answer("Hello!")) + +# %% +# The chat creation will fail if a required parameter is not passed or a wrong type is +# passed for any parameter. + +try: + Rag().chat( + documents=[document_path], + source_storage=TutorialSourceStorage, + assistant=ElaborateTutorialAssistant, + ) +except Exception as exc: + print(exc) + +# %% + +try: + Rag().chat( + documents=[document_path], + source_storage=TutorialSourceStorage, + assistant=ElaborateTutorialAssistant, + my_required_parameter="bar", + my_optional_parameter=3, + ) +except Exception as exc: + print(exc) + +# %% +# ### REST API + +config = Config( + source_storages=[TutorialSourceStorage], + assistants=[ElaborateTutorialAssistant], +) + +rest_api = ragna_docs.RestApi() + +client, document = rest_api.start(config, authenticate=True, upload_document=True) + +# %% +# To pass custom parameters, define them in the `params` mapping when creating a new +# chat. 
+ +response = client.post( + "/chats", + json={ + "name": "Tutorial REST API", + "documents": [document], + "source_storage": TutorialSourceStorage.display_name(), + "assistant": ElaborateTutorialAssistant.display_name(), + "params": { + "my_required_parameter": 3, + "my_optional_parameter": "bar", + }, + }, +).raise_for_status() +chat = response.json() + +# %% + +client.post(f"/chats/{chat['id']}/prepare").raise_for_status() + +response = client.post( + f"/chats/{chat['id']}/answer", + json={"prompt": "What is Ragna?"}, +).raise_for_status() +answer = response.json() +print(json.dumps(answer, indent=2)) + +# %% +# Let's stop the REST API and have a look at what would have printed in the terminal if +# we had started it with the `ragna api` command. + +rest_api.stop() + +# %% +# ### Web UI +# +# !!! warning +# Unfortunately, Ragna's web UI currently does **not** support arbitrary custom +# parameters. This is actively worked on and progress is tracked in +# [#217](https://github.com/Quansight/ragna/issues/217). Until this is resolved, +# the source storage and the assistant together take the following four +# parameters: +# +# 1. `chunk_size: int` +# 2. `chunk_overlap: int` +# 3. `num_tokens: int` +# 4. `max_new_tokens: int` +# +# When using Ragna's builtin components, all source storages cover parameters 1. to +# 3. and all assistant cover parameter 4. + +# %% +# ## Sync or `async`? +# +# So far we have `def`ined all of the abstract methods of the components as regular, +# i.e. synchronous methods. However, this is *not* a requirement. If you want to run +# `async` code inside the methods, define them with `async def` and Ragna will +# handle the rest for you. Internally, Ragna runs synchronous methods on a thread to +# avoid blocking the main thread. Asynchronous methods are scheduled on the main event +# loop. +# +# Let's have a look at a minimal example. + +import asyncio +import time +from typing import AsyncIterator + + +class AsyncAssistant(Assistant): + async def answer( + self, prompt: str, sources: list[Source] + ) -> AsyncIterator[str]: + print(f"Running {type(self).__name__}().answer()") + start = time.perf_counter() + await asyncio.sleep(0.3) + stop = time.perf_counter() + yield f"I've waited for {stop - start} seconds!" + + +chat = Rag().chat( + documents=[document_path], + source_storage=TutorialSourceStorage, + assistant=AsyncAssistant, +) + +_ = await chat.prepare() +print(await chat.answer("Hello!")) diff --git a/docs/tutorials/gallery_rest_api.py b/docs/tutorials/gallery_rest_api.py index 9d57d0b4..befcbfb3 100644 --- a/docs/tutorials/gallery_rest_api.py +++ b/docs/tutorials/gallery_rest_api.py @@ -189,6 +189,6 @@ # %% # Before we close the tutorial, let's stop the REST API and have a look at what would -# have printed in the terminal if we had started it the regular way. +# have printed in the terminal if we had started it with the `ragna api` command. 
rest_api.stop() diff --git a/ragna/_docs.py b/ragna/_docs.py index 53f3993f..20ac122f 100644 --- a/ragna/_docs.py +++ b/ragna/_docs.py @@ -1,14 +1,17 @@ import inspect +import itertools import os import subprocess import sys import tempfile +import textwrap from pathlib import Path -from typing import Optional +from typing import Any, Optional, cast import httpx from ragna._utils import timeout_after +from ragna.core import RagnaException from ragna.deploy import Config __all__ = ["SAMPLE_CONTENT", "RestApi"] @@ -29,7 +32,18 @@ class RestApi: def __init__(self) -> None: self._process: Optional[subprocess.Popen] = None - def start(self, config: Config, *, authenticate: bool = False) -> httpx.Client: + def start( + self, + config: Config, + *, + authenticate: bool = False, + upload_document: bool = False, + ) -> tuple[httpx.Client, Optional[dict]]: + if upload_document and not authenticate: + raise RagnaException( + "Cannot upload a document without authenticating first. " + "Set authenticate=True when using upload_document=True." + ) python_path, config_path = self._prepare_config(config) client = httpx.Client(base_url=config.api.url) @@ -39,7 +53,12 @@ def start(self, config: Config, *, authenticate: bool = False) -> httpx.Client: if authenticate: self._authenticate(client) - return client + if upload_document: + document = self._upload_document(client) + else: + document = None + + return client, document def _prepare_config(self, config: Config) -> tuple[str, str]: deploy_directory = Path(tempfile.mkdtemp()) @@ -50,22 +69,32 @@ def _prepare_config(self, config: Config) -> tuple[str, str]: config_path = str(deploy_directory / "ragna.toml") config.local_root = deploy_directory + config.api.database_url = f"sqlite:///{deploy_directory / 'ragna.db'}" sys.modules["__main__"].__file__ = inspect.getouterframes( inspect.currentframe() )[2].filename custom_module = deploy_directory.name + custom_components = set() with open(deploy_directory / f"{custom_module}.py", "w") as file: - # TODO: this currently only handles assistants. When needed, we can extend - # to source storages. 
- file.write("from ragna import assistants\n\n") - - for assistant in config.assistants: - if assistant.__module__ == "__main__": - file.write(f"{inspect.getsource(assistant)}\n\n") - assistant.__module__ = custom_module + # FIXME Find a way to automatically detect necessary imports + file.write("import uuid; from uuid import *\n") + file.write("import textwrap; from textwrap import*\n") + file.write("from typing import *\n") + file.write("from ragna import *\n") + file.write("from ragna.core import *\n") + + for component in itertools.chain(config.source_storages, config.assistants): + if component.__module__ == "__main__": + custom_components.add(component) + file.write(f"{textwrap.dedent(inspect.getsource(component))}\n\n") + component.__module__ = custom_module config.to_file(config_path) + + for component in custom_components: + component.__module__ = "__main__" + return python_path, config_path def _start_api( @@ -123,6 +152,24 @@ def _authenticate(self, client: httpx.Client) -> None: client.headers["Authorization"] = f"Bearer {token}" + def _upload_document(self, client: httpx.Client) -> dict[str, Any]: + name, content = "ragna.txt", SAMPLE_CONTENT + + response = client.post("/document", json={"name": name}).raise_for_status() + document_upload = response.json() + + document = cast(dict[str, Any], document_upload["document"]) + + parameters = document_upload["parameters"] + client.request( + parameters["method"], + parameters["url"], + data=parameters["data"], + files={"file": content}, + ).raise_for_status() + + return document + def stop(self, *, quiet: bool = False) -> None: if self._process is None: return From 84cf4f627ebf52b061681a7a8b106daef3e79a1d Mon Sep 17 00:00:00 2001 From: Pierre-Olivier Simonard Date: Fri, 17 May 2024 09:04:48 +0200 Subject: [PATCH 2/5] Fix #415 : chat bubble (#419) --- ragna/deploy/_ui/css/left_sidebar/button.css | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ragna/deploy/_ui/css/left_sidebar/button.css b/ragna/deploy/_ui/css/left_sidebar/button.css index 19577d4e..bc06efd2 100644 --- a/ragna/deploy/_ui/css/left_sidebar/button.css +++ b/ragna/deploy/_ui/css/left_sidebar/button.css @@ -24,7 +24,7 @@ } :host(.chat_button) div button:before { - content: url("imgs/chat_bubble.svg"); + content: url("../../imgs/chat_bubble.svg"); margin-right: 10px; display: inline-block; } From a45bd9063f106a7728c5fb65c9077387b25c9706 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 28 May 2024 09:32:36 +0200 Subject: [PATCH 3/5] refactor assistant streaming and create OpenAI compliant base class (#425) --- docs/examples/gallery_streaming.py | 2 + docs/tutorials/gallery_python_api.py | 4 +- ragna/assistants/__init__.py | 2 + ragna/assistants/_ai21labs.py | 8 +- ragna/assistants/_anthropic.py | 42 ++++----- ragna/assistants/_api.py | 59 ------------ ragna/assistants/_cohere.py | 33 +++---- ragna/assistants/_google.py | 12 +-- ragna/assistants/_http_api.py | 80 ++++++++++++++++ ragna/assistants/_llamafile.py | 25 +++++ ragna/assistants/_openai.py | 134 ++++++++++++++++----------- tests/assistants/test_api.py | 13 ++- 12 files changed, 241 insertions(+), 173 deletions(-) delete mode 100644 ragna/assistants/_api.py create mode 100644 ragna/assistants/_http_api.py create mode 100644 ragna/assistants/_llamafile.py diff --git a/docs/examples/gallery_streaming.py b/docs/examples/gallery_streaming.py index 6ffda30a..84d92d08 100644 --- a/docs/examples/gallery_streaming.py +++ b/docs/examples/gallery_streaming.py @@ -29,6 +29,8 @@ # - 
[OpenAI](https://openai.com/) # - [ragna.assistants.Gpt35Turbo16k][] # - [ragna.assistants.Gpt4][] +# - [llamafile](https://github.com/Mozilla-Ocho/llamafile) +# - [ragna.assistants.LlamafileAssistant][] from ragna import assistants diff --git a/docs/tutorials/gallery_python_api.py b/docs/tutorials/gallery_python_api.py index d1919bd3..a7d0ef63 100644 --- a/docs/tutorials/gallery_python_api.py +++ b/docs/tutorials/gallery_python_api.py @@ -85,10 +85,12 @@ # - [ragna.assistants.Gpt4][] # - [AI21 Labs](https://www.ai21.com/) # - [ragna.assistants.Jurassic2Ultra][] +# - [llamafile](https://github.com/Mozilla-Ocho/llamafile) +# - [ragna.assistants.LlamafileAssistant][] # # !!! note # -# To use any of builtin assistants, you need to +# To use some of the builtin assistants, you need to # [procure API keys](../../references/faq.md#where-do-i-get-api-keys-for-the-builtin-assistants) # first and set the corresponding environment variables. diff --git a/ragna/assistants/__init__.py b/ragna/assistants/__init__.py index 823d87ac..d583e7a0 100644 --- a/ragna/assistants/__init__.py +++ b/ragna/assistants/__init__.py @@ -9,6 +9,7 @@ "Gpt35Turbo16k", "Gpt4", "Jurassic2Ultra", + "LlamafileAssistant", "RagnaDemoAssistant", ] @@ -17,6 +18,7 @@ from ._cohere import Command, CommandLight from ._demo import RagnaDemoAssistant from ._google import GeminiPro, GeminiUltra +from ._llamafile import LlamafileAssistant from ._openai import Gpt4, Gpt35Turbo16k # isort: split diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index 19cfd59b..1c61a213 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -2,10 +2,10 @@ from ragna.core import Source -from ._api import ApiAssistant +from ._http_api import HttpApiAssistant -class Ai21LabsAssistant(ApiAssistant): +class Ai21LabsAssistant(HttpApiAssistant): _API_KEY_ENV_VAR = "AI21_API_KEY" _MODEL_TYPE: str @@ -21,8 +21,8 @@ def _make_system_content(self, sources: list[Source]) -> str: ) return instruction + "\n\n".join(source.content for source in sources) - async def _call_api( - self, prompt: str, sources: list[Source], *, max_new_tokens: int + async def answer( + self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: # See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters # See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py index fa8922fe..37f132b5 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -1,12 +1,11 @@ -import json from typing import AsyncIterator, cast from ragna.core import PackageRequirement, RagnaException, Requirement, Source -from ._api import ApiAssistant +from ._http_api import HttpApiAssistant -class AnthropicApiAssistant(ApiAssistant): +class AnthropicAssistant(HttpApiAssistant): _API_KEY_ENV_VAR = "ANTHROPIC_API_KEY" _MODEL: str @@ -36,15 +35,12 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str: + "" ) - async def _call_api( - self, prompt: str, sources: list[Source], *, max_new_tokens: int + async def answer( + self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: - import httpx_sse - # See https://docs.anthropic.com/claude/reference/messages_post # See https://docs.anthropic.com/claude/reference/streaming - async with httpx_sse.aconnect_sse( - self._client, + async for data in self._stream_sse( "POST", "https://api.anthropic.com/v1/messages", headers={ 
@@ -61,23 +57,19 @@ async def _call_api( "temperature": 0.0, "stream": True, }, - ) as event_source: - await self._assert_api_call_is_success(event_source.response) - - async for sse in event_source.aiter_sse(): - data = json.loads(sse.data) - # See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response - if "error" in data: - raise RagnaException(data["error"].pop("message"), **data["error"]) - elif data["type"] == "message_stop": - break - elif data["type"] != "content_block_delta": - continue + ): + # See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response + if "error" in data: + raise RagnaException(data["error"].pop("message"), **data["error"]) + elif data["type"] == "message_stop": + break + elif data["type"] != "content_block_delta": + continue - yield cast(str, data["delta"].pop("text")) + yield cast(str, data["delta"].pop("text")) -class ClaudeOpus(AnthropicApiAssistant): +class ClaudeOpus(AnthropicAssistant): """[Claude 3 Opus](https://docs.anthropic.com/claude/docs/models-overview) !!! info "Required environment variables" @@ -92,7 +84,7 @@ class ClaudeOpus(AnthropicApiAssistant): _MODEL = "claude-3-opus-20240229" -class ClaudeSonnet(AnthropicApiAssistant): +class ClaudeSonnet(AnthropicAssistant): """[Claude 3 Sonnet](https://docs.anthropic.com/claude/docs/models-overview) !!! info "Required environment variables" @@ -107,7 +99,7 @@ class ClaudeSonnet(AnthropicApiAssistant): _MODEL = "claude-3-sonnet-20240229" -class ClaudeHaiku(AnthropicApiAssistant): +class ClaudeHaiku(AnthropicAssistant): """[Claude 3 Haiku](https://docs.anthropic.com/claude/docs/models-overview) !!! info "Required environment variables" diff --git a/ragna/assistants/_api.py b/ragna/assistants/_api.py deleted file mode 100644 index 446d9545..00000000 --- a/ragna/assistants/_api.py +++ /dev/null @@ -1,59 +0,0 @@ -import abc -import contextlib -import json -import os -from typing import AsyncIterator - -import httpx -from httpx import Response - -import ragna -from ragna.core import Assistant, EnvVarRequirement, RagnaException, Requirement, Source - - -class ApiAssistant(Assistant): - _API_KEY_ENV_VAR: str - - @classmethod - def requirements(cls) -> list[Requirement]: - return [EnvVarRequirement(cls._API_KEY_ENV_VAR), *cls._extra_requirements()] - - @classmethod - def _extra_requirements(cls) -> list[Requirement]: - return [] - - def __init__(self) -> None: - self._client = httpx.AsyncClient( - headers={"User-Agent": f"{ragna.__version__}/{self}"}, - timeout=60, - ) - self._api_key = os.environ[self._API_KEY_ENV_VAR] - - async def answer( - self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 - ) -> AsyncIterator[str]: - async for chunk in self._call_api( - prompt, sources, max_new_tokens=max_new_tokens - ): - yield chunk - - @abc.abstractmethod - def _call_api( - self, prompt: str, sources: list[Source], *, max_new_tokens: int - ) -> AsyncIterator[str]: ... 
- - async def _assert_api_call_is_success(self, response: Response) -> None: - if response.is_success: - return - - content = await response.aread() - with contextlib.suppress(Exception): - content = json.loads(content) - - raise RagnaException( - "API call failed", - request_method=response.request.method, - request_url=str(response.request.url), - response_status_code=response.status_code, - response_content=content, - ) diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index a93a264f..b47737f8 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -1,12 +1,11 @@ -import json from typing import AsyncIterator, cast from ragna.core import RagnaException, Source -from ._api import ApiAssistant +from ._http_api import HttpApiAssistant -class CohereApiAssistant(ApiAssistant): +class CohereAssistant(HttpApiAssistant): _API_KEY_ENV_VAR = "COHERE_API_KEY" _MODEL: str @@ -24,13 +23,13 @@ def _make_preamble(self) -> str: def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]: return [{"title": source.id, "snippet": source.content} for source in sources] - async def _call_api( - self, prompt: str, sources: list[Source], *, max_new_tokens: int + async def answer( + self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: # See https://docs.cohere.com/docs/cochat-beta # See https://docs.cohere.com/reference/chat # See https://docs.cohere.com/docs/retrieval-augmented-generation-rag - async with self._client.stream( + async for event in self._stream_jsonl( "POST", "https://api.cohere.ai/v1/chat", headers={ @@ -47,21 +46,17 @@ async def _call_api( "max_tokens": max_new_tokens, "documents": self._make_source_documents(sources), }, - ) as response: - await self._assert_api_call_is_success(response) + ): + if event["event_type"] == "stream-end": + if event["event_type"] == "COMPLETE": + break - async for chunk in response.aiter_lines(): - event = json.loads(chunk) - if event["event_type"] == "stream-end": - if event["event_type"] == "COMPLETE": - break + raise RagnaException(event["error_message"]) + if "text" in event: + yield cast(str, event["text"]) - raise RagnaException(event["error_message"]) - if "text" in event: - yield cast(str, event["text"]) - -class Command(CohereApiAssistant): +class Command(CohereAssistant): """ [Cohere Command](https://docs.cohere.com/docs/models#command) @@ -73,7 +68,7 @@ class Command(CohereApiAssistant): _MODEL = "command" -class CommandLight(CohereApiAssistant): +class CommandLight(CohereAssistant): """ [Cohere Command-Light](https://docs.cohere.com/docs/models#command) diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index afbb829a..8e1caf1e 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -3,7 +3,7 @@ from ragna._compat import anext from ragna.core import PackageRequirement, Requirement, Source -from ._api import ApiAssistant +from ._http_api import HttpApiAssistant # ijson does not support reading from an (async) iterator, but only from file-like @@ -25,7 +25,7 @@ async def read(self, n: int) -> bytes: return await anext(self._ait, b"") # type: ignore[call-arg] -class GoogleApiAssistant(ApiAssistant): +class GoogleAssistant(HttpApiAssistant): _API_KEY_ENV_VAR = "GOOGLE_API_KEY" _MODEL: str @@ -48,8 +48,8 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str: ] ) - async def _call_api( - self, prompt: str, sources: list[Source], *, max_new_tokens: int + async def answer( + self, prompt: str, 
sources: list[Source], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: import ijson @@ -88,7 +88,7 @@ async def _call_api( yield chunk -class GeminiPro(GoogleApiAssistant): +class GeminiPro(GoogleAssistant): """[Google Gemini Pro](https://ai.google.dev/models/gemini) !!! info "Required environment variables" @@ -103,7 +103,7 @@ class GeminiPro(GoogleApiAssistant): _MODEL = "gemini-pro" -class GeminiUltra(GoogleApiAssistant): +class GeminiUltra(GoogleAssistant): """[Google Gemini Ultra](https://ai.google.dev/models/gemini) !!! info "Required environment variables" diff --git a/ragna/assistants/_http_api.py b/ragna/assistants/_http_api.py new file mode 100644 index 00000000..1151a62a --- /dev/null +++ b/ragna/assistants/_http_api.py @@ -0,0 +1,80 @@ +import contextlib +import json +import os +from typing import Any, AsyncIterator, Optional + +import httpx +from httpx import Response + +import ragna +from ragna.core import Assistant, EnvVarRequirement, RagnaException, Requirement + + +class HttpApiAssistant(Assistant): + _API_KEY_ENV_VAR: Optional[str] + + @classmethod + def requirements(cls) -> list[Requirement]: + requirements: list[Requirement] = ( + [EnvVarRequirement(cls._API_KEY_ENV_VAR)] + if cls._API_KEY_ENV_VAR is not None + else [] + ) + requirements.extend(cls._extra_requirements()) + return requirements + + @classmethod + def _extra_requirements(cls) -> list[Requirement]: + return [] + + def __init__(self) -> None: + self._client = httpx.AsyncClient( + headers={"User-Agent": f"{ragna.__version__}/{self}"}, + timeout=60, + ) + self._api_key: Optional[str] = ( + os.environ[self._API_KEY_ENV_VAR] + if self._API_KEY_ENV_VAR is not None + else None + ) + + async def _assert_api_call_is_success(self, response: Response) -> None: + if response.is_success: + return + + content = await response.aread() + with contextlib.suppress(Exception): + content = json.loads(content) + + raise RagnaException( + "API call failed", + request_method=response.request.method, + request_url=str(response.request.url), + response_status_code=response.status_code, + response_content=content, + ) + + async def _stream_sse( + self, + method: str, + url: str, + **kwargs: Any, + ) -> AsyncIterator[dict[str, Any]]: + import httpx_sse + + async with httpx_sse.aconnect_sse( + self._client, method, url, **kwargs + ) as event_source: + await self._assert_api_call_is_success(event_source.response) + + async for sse in event_source.aiter_sse(): + yield json.loads(sse.data) + + async def _stream_jsonl( + self, method: str, url: str, **kwargs: Any + ) -> AsyncIterator[dict[str, Any]]: + async with self._client.stream(method, url, **kwargs) as response: + await self._assert_api_call_is_success(response) + + async for chunk in response.aiter_lines(): + yield json.loads(chunk) diff --git a/ragna/assistants/_llamafile.py b/ragna/assistants/_llamafile.py new file mode 100644 index 00000000..3e78a625 --- /dev/null +++ b/ragna/assistants/_llamafile.py @@ -0,0 +1,25 @@ +import os + +from ._openai import OpenaiCompliantHttpApiAssistant + + +class LlamafileAssistant(OpenaiCompliantHttpApiAssistant): + """[llamafile](https://github.com/Mozilla-Ocho/llamafile) + + To use this assistant, start the llamafile server manually. By default, the server + is expected at `http://localhost:8080`. This can be changed with the + `RAGNA_LLAMAFILE_BASE_URL` environment variable. + + !!! 
info "Required packages" + + - `httpx_sse` + """ + + _API_KEY_ENV_VAR = None + _STREAMING_METHOD = "sse" + _MODEL = None + + @property + def _url(self) -> str: + base_url = os.environ.get("RAGNA_LLAMAFILE_BASE_URL", "http://localhost:8080") + return f"{base_url}/v1/chat/completions" diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index a9dad3ee..37957be2 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -1,22 +1,26 @@ -import json -from typing import AsyncIterator, cast +import abc +from typing import Any, AsyncIterator, Literal, Optional, cast -from ragna.core import PackageRequirement, Requirement, Source +from ragna.core import PackageRequirement, RagnaException, Requirement, Source -from ._api import ApiAssistant +from ._http_api import HttpApiAssistant -class OpenaiApiAssistant(ApiAssistant): - _API_KEY_ENV_VAR = "OPENAI_API_KEY" - _MODEL: str +class OpenaiCompliantHttpApiAssistant(HttpApiAssistant): + _STREAMING_METHOD: Literal["sse", "jsonl"] + _MODEL: Optional[str] @classmethod - def _extra_requirements(cls) -> list[Requirement]: - return [PackageRequirement("httpx_sse")] + def requirements(cls) -> list[Requirement]: + requirements = super().requirements() + requirements.extend( + {"sse": [PackageRequirement("httpx_sse")]}.get(cls._STREAMING_METHOD, []) + ) + return requirements - @classmethod - def display_name(cls) -> str: - return f"OpenAI/{cls._MODEL}" + @property + @abc.abstractmethod + def _url(self) -> str: ... def _make_system_content(self, sources: list[Source]) -> str: # See https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb @@ -27,50 +31,72 @@ def _make_system_content(self, sources: list[Source]) -> str: ) return instruction + "\n\n".join(source.content for source in sources) - async def _call_api( - self, prompt: str, sources: list[Source], *, max_new_tokens: int + def _stream( + self, + method: str, + url: str, + **kwargs: Any, + ) -> AsyncIterator[dict[str, Any]]: + stream = { + "sse": self._stream_sse, + "jsonl": self._stream_jsonl, + }.get(self._STREAMING_METHOD) + if stream is None: + raise RagnaException + + return stream(method, url, **kwargs) + + async def answer( + self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: - import httpx_sse - # See https://platform.openai.com/docs/api-reference/chat/create # and https://platform.openai.com/docs/api-reference/chat/streaming - async with httpx_sse.aconnect_sse( - self._client, - "POST", - "https://api.openai.com/v1/chat/completions", - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {self._api_key}", - }, - json={ - "messages": [ - { - "role": "system", - "content": self._make_system_content(sources), - }, - { - "role": "user", - "content": prompt, - }, - ], - "model": self._MODEL, - "temperature": 0.0, - "max_tokens": max_new_tokens, - "stream": True, - }, - ) as event_source: - await self._assert_api_call_is_success(event_source.response) - - async for sse in event_source.aiter_sse(): - data = json.loads(sse.data) - choice = data["choices"][0] - if choice["finish_reason"] is not None: - break - - yield cast(str, choice["delta"]["content"]) - - -class Gpt35Turbo16k(OpenaiApiAssistant): + headers = { + "Content-Type": "application/json", + } + if self._api_key is not None: + headers["Authorization"] = f"Bearer {self._api_key}" + + json_ = { + "messages": [ + { + "role": "system", + "content": self._make_system_content(sources), + }, + { + 
"role": "user", + "content": prompt, + }, + ], + "temperature": 0.0, + "max_tokens": max_new_tokens, + "stream": True, + } + if self._MODEL is not None: + json_["model"] = self._MODEL + + async for data in self._stream("POST", self._url, headers=headers, json=json_): + choice = data["choices"][0] + if choice["finish_reason"] is not None: + break + + yield cast(str, choice["delta"]["content"]) + + +class OpenaiAssistant(OpenaiCompliantHttpApiAssistant): + _API_KEY_ENV_VAR = "OPENAI_API_KEY" + _STREAMING_METHOD = "sse" + + @classmethod + def display_name(cls) -> str: + return f"OpenAI/{cls._MODEL}" + + @property + def _url(self) -> str: + return "https://api.openai.com/v1/chat/completions" + + +class Gpt35Turbo16k(OpenaiAssistant): """[OpenAI GPT-3.5](https://platform.openai.com/docs/models/gpt-3-5) !!! info "Required environment variables" @@ -85,7 +111,7 @@ class Gpt35Turbo16k(OpenaiApiAssistant): _MODEL = "gpt-3.5-turbo-16k" -class Gpt4(OpenaiApiAssistant): +class Gpt4(OpenaiAssistant): """[OpenAI GPT-4](https://platform.openai.com/docs/models/gpt-4) !!! info "Required environment variables" diff --git a/tests/assistants/test_api.py b/tests/assistants/test_api.py index 97961456..02b964b5 100644 --- a/tests/assistants/test_api.py +++ b/tests/assistants/test_api.py @@ -4,21 +4,24 @@ from ragna import assistants from ragna._compat import anext -from ragna.assistants._api import ApiAssistant +from ragna.assistants._http_api import HttpApiAssistant from ragna.core import RagnaException from tests.utils import skip_on_windows -API_ASSISTANTS = [ +HTTP_API_ASSISTANTS = [ assistant for assistant in assistants.__dict__.values() if isinstance(assistant, type) - and issubclass(assistant, ApiAssistant) - and assistant is not ApiAssistant + and issubclass(assistant, HttpApiAssistant) + and assistant is not HttpApiAssistant ] @skip_on_windows -@pytest.mark.parametrize("assistant", API_ASSISTANTS) +@pytest.mark.parametrize( + "assistant", + [assistant for assistant in HTTP_API_ASSISTANTS if assistant._API_KEY_ENV_VAR], +) async def test_api_call_error_smoke(mocker, assistant): mocker.patch.dict(os.environ, {assistant._API_KEY_ENV_VAR: "SENTINEL"}) From da1bcc2026cdf5679d4afa226b746edb63fc3b7d Mon Sep 17 00:00:00 2001 From: William Black <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 10 Jun 2024 02:25:16 -0700 Subject: [PATCH 4/5] [ENH] Add support for Ollama assistants (#376) Co-authored-by: Philip Meier --- docs/examples/gallery_streaming.py | 8 ++ docs/tutorials/gallery_python_api.py | 8 ++ ragna/assistants/__init__.py | 16 +++ ragna/assistants/_ai21labs.py | 10 +- ragna/assistants/_anthropic.py | 5 +- ragna/assistants/_cohere.py | 5 +- ragna/assistants/_google.py | 49 ++----- ragna/assistants/_http_api.py | 195 +++++++++++++++++++++------ ragna/assistants/_llamafile.py | 14 +- ragna/assistants/_ollama.py | 83 ++++++++++++ ragna/assistants/_openai.py | 48 +++---- 11 files changed, 315 insertions(+), 126 deletions(-) create mode 100644 ragna/assistants/_ollama.py diff --git a/docs/examples/gallery_streaming.py b/docs/examples/gallery_streaming.py index 84d92d08..9ccdecb4 100644 --- a/docs/examples/gallery_streaming.py +++ b/docs/examples/gallery_streaming.py @@ -31,6 +31,14 @@ # - [ragna.assistants.Gpt4][] # - [llamafile](https://github.com/Mozilla-Ocho/llamafile) # - [ragna.assistants.LlamafileAssistant][] +# - [Ollama](https://ollama.com/) +# - [ragna.assistants.OllamaGemma2B][] +# - [ragna.assistants.OllamaLlama2][] +# - [ragna.assistants.OllamaLlava][] +# - 
[ragna.assistants.OllamaMistral][] +# - [ragna.assistants.OllamaMixtral][] +# - [ragna.assistants.OllamaOrcaMini][] +# - [ragna.assistants.OllamaPhi2][] from ragna import assistants diff --git a/docs/tutorials/gallery_python_api.py b/docs/tutorials/gallery_python_api.py index a7d0ef63..23667a71 100644 --- a/docs/tutorials/gallery_python_api.py +++ b/docs/tutorials/gallery_python_api.py @@ -87,6 +87,14 @@ # - [ragna.assistants.Jurassic2Ultra][] # - [llamafile](https://github.com/Mozilla-Ocho/llamafile) # - [ragna.assistants.LlamafileAssistant][] +# - [Ollama](https://ollama.com/) +# - [ragna.assistants.OllamaGemma2B][] +# - [ragna.assistants.OllamaLlama2][] +# - [ragna.assistants.OllamaLlava][] +# - [ragna.assistants.OllamaMistral][] +# - [ragna.assistants.OllamaMixtral][] +# - [ragna.assistants.OllamaOrcaMini][] +# - [ragna.assistants.OllamaPhi2][] # # !!! note # diff --git a/ragna/assistants/__init__.py b/ragna/assistants/__init__.py index d583e7a0..bcf5ead6 100644 --- a/ragna/assistants/__init__.py +++ b/ragna/assistants/__init__.py @@ -6,6 +6,13 @@ "CommandLight", "GeminiPro", "GeminiUltra", + "OllamaGemma2B", + "OllamaPhi2", + "OllamaLlama2", + "OllamaLlava", + "OllamaMistral", + "OllamaMixtral", + "OllamaOrcaMini", "Gpt35Turbo16k", "Gpt4", "Jurassic2Ultra", @@ -19,6 +26,15 @@ from ._demo import RagnaDemoAssistant from ._google import GeminiPro, GeminiUltra from ._llamafile import LlamafileAssistant +from ._ollama import ( + OllamaGemma2B, + OllamaLlama2, + OllamaLlava, + OllamaMistral, + OllamaMixtral, + OllamaOrcaMini, + OllamaPhi2, +) from ._openai import Gpt4, Gpt35Turbo16k # isort: split diff --git a/ragna/assistants/_ai21labs.py b/ragna/assistants/_ai21labs.py index 1c61a213..3e0c56b5 100644 --- a/ragna/assistants/_ai21labs.py +++ b/ragna/assistants/_ai21labs.py @@ -7,6 +7,7 @@ class Ai21LabsAssistant(HttpApiAssistant): _API_KEY_ENV_VAR = "AI21_API_KEY" + _STREAMING_PROTOCOL = None _MODEL_TYPE: str @classmethod @@ -27,7 +28,8 @@ async def answer( # See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters # See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters # See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response - response = await self._client.post( + async for data in self._call_api( + "POST", f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat", headers={ "accept": "application/json", @@ -46,10 +48,8 @@ async def answer( ], "system": self._make_system_content(sources), }, - ) - await self._assert_api_call_is_success(response) - - yield cast(str, response.json()["outputs"][0]["text"]) + ): + yield cast(str, data["outputs"][0]["text"]) # The Jurassic2Mid assistant receives a 500 internal service error from the remote diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py index 37f132b5..d74fc840 100644 --- a/ragna/assistants/_anthropic.py +++ b/ragna/assistants/_anthropic.py @@ -2,11 +2,12 @@ from ragna.core import PackageRequirement, RagnaException, Requirement, Source -from ._http_api import HttpApiAssistant +from ._http_api import HttpApiAssistant, HttpStreamingProtocol class AnthropicAssistant(HttpApiAssistant): _API_KEY_ENV_VAR = "ANTHROPIC_API_KEY" + _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE _MODEL: str @classmethod @@ -40,7 +41,7 @@ async def answer( ) -> AsyncIterator[str]: # See https://docs.anthropic.com/claude/reference/messages_post # See https://docs.anthropic.com/claude/reference/streaming - async for data in self._stream_sse( + async for data in self._call_api( "POST", 
"https://api.anthropic.com/v1/messages", headers={ diff --git a/ragna/assistants/_cohere.py b/ragna/assistants/_cohere.py index b47737f8..4108d31b 100644 --- a/ragna/assistants/_cohere.py +++ b/ragna/assistants/_cohere.py @@ -2,11 +2,12 @@ from ragna.core import RagnaException, Source -from ._http_api import HttpApiAssistant +from ._http_api import HttpApiAssistant, HttpStreamingProtocol class CohereAssistant(HttpApiAssistant): _API_KEY_ENV_VAR = "COHERE_API_KEY" + _STREAMING_PROTOCOL = HttpStreamingProtocol.JSONL _MODEL: str @classmethod @@ -29,7 +30,7 @@ async def answer( # See https://docs.cohere.com/docs/cochat-beta # See https://docs.cohere.com/reference/chat # See https://docs.cohere.com/docs/retrieval-augmented-generation-rag - async for event in self._stream_jsonl( + async for event in self._call_api( "POST", "https://api.cohere.ai/v1/chat", headers={ diff --git a/ragna/assistants/_google.py b/ragna/assistants/_google.py index 8e1caf1e..70c82936 100644 --- a/ragna/assistants/_google.py +++ b/ragna/assistants/_google.py @@ -1,38 +1,15 @@ from typing import AsyncIterator -from ragna._compat import anext -from ragna.core import PackageRequirement, Requirement, Source +from ragna.core import Source -from ._http_api import HttpApiAssistant - - -# ijson does not support reading from an (async) iterator, but only from file-like -# objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects. -# See https://github.com/ICRAR/ijson/issues/44 for details. -# ijson actually doesn't care about most of the file interface and only requires the -# read() method to be present. -class AsyncIteratorReader: - def __init__(self, ait: AsyncIterator[bytes]) -> None: - self._ait = ait - - async def read(self, n: int) -> bytes: - # n is usually used to indicate how many bytes to read, but since we want to - # return a chunk as soon as it is available, we ignore the value of n. The only - # exception is n == 0, which is used by ijson to probe the return type and - # set up decoding. 
- if n == 0: - return b"" - return await anext(self._ait, b"") # type: ignore[call-arg] +from ._http_api import HttpApiAssistant, HttpStreamingProtocol class GoogleAssistant(HttpApiAssistant): _API_KEY_ENV_VAR = "GOOGLE_API_KEY" + _STREAMING_PROTOCOL = HttpStreamingProtocol.JSON _MODEL: str - @classmethod - def _extra_requirements(cls) -> list[Requirement]: - return [PackageRequirement("ijson")] - @classmethod def display_name(cls) -> str: return f"Google/{cls._MODEL}" @@ -51,9 +28,7 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str: async def answer( self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 ) -> AsyncIterator[str]: - import ijson - - async with self._client.stream( + async for chunk in self._call_api( "POST", f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent", params={"key": self._api_key}, @@ -64,7 +39,10 @@ async def answer( ], # https://ai.google.dev/docs/safety_setting_gemini "safetySettings": [ - {"category": f"HARM_CATEGORY_{category}", "threshold": "BLOCK_NONE"} + { + "category": f"HARM_CATEGORY_{category}", + "threshold": "BLOCK_NONE", + } for category in [ "HARASSMENT", "HATE_SPEECH", @@ -78,14 +56,9 @@ async def answer( "maxOutputTokens": max_new_tokens, }, }, - ) as response: - await self._assert_api_call_is_success(response) - - async for chunk in ijson.items( - AsyncIteratorReader(response.aiter_bytes(1024)), - "item.candidates.item.content.parts.item.text", - ): - yield chunk + parse_kwargs=dict(item="item.candidates.item.content.parts.item.text"), + ): + yield chunk class GeminiPro(GoogleAssistant): diff --git a/ragna/assistants/_http_api.py b/ragna/assistants/_http_api.py index 1151a62a..d6f48a26 100644 --- a/ragna/assistants/_http_api.py +++ b/ragna/assistants/_http_api.py @@ -1,65 +1,83 @@ import contextlib +import enum import json import os from typing import Any, AsyncIterator, Optional import httpx -from httpx import Response import ragna -from ragna.core import Assistant, EnvVarRequirement, RagnaException, Requirement +from ragna._compat import anext +from ragna.core import ( + Assistant, + EnvVarRequirement, + PackageRequirement, + RagnaException, + Requirement, +) -class HttpApiAssistant(Assistant): - _API_KEY_ENV_VAR: Optional[str] +class HttpStreamingProtocol(enum.Enum): + SSE = enum.auto() + JSONL = enum.auto() + JSON = enum.auto() - @classmethod - def requirements(cls) -> list[Requirement]: - requirements: list[Requirement] = ( - [EnvVarRequirement(cls._API_KEY_ENV_VAR)] - if cls._API_KEY_ENV_VAR is not None - else [] - ) - requirements.extend(cls._extra_requirements()) - return requirements +class HttpApiCaller: @classmethod - def _extra_requirements(cls) -> list[Requirement]: - return [] + def requirements(cls, protocol: HttpStreamingProtocol) -> list[Requirement]: + streaming_requirements: dict[HttpStreamingProtocol, list[Requirement]] = { + HttpStreamingProtocol.SSE: [PackageRequirement("httpx_sse")], + HttpStreamingProtocol.JSON: [PackageRequirement("ijson")], + } + return streaming_requirements.get(protocol, []) - def __init__(self) -> None: - self._client = httpx.AsyncClient( - headers={"User-Agent": f"{ragna.__version__}/{self}"}, - timeout=60, - ) - self._api_key: Optional[str] = ( - os.environ[self._API_KEY_ENV_VAR] - if self._API_KEY_ENV_VAR is not None - else None - ) - - async def _assert_api_call_is_success(self, response: Response) -> None: - if response.is_success: - return + def __init__( + self, + client: httpx.AsyncClient, + protocol: 
Optional[HttpStreamingProtocol] = None, + ) -> None: + self._client = client + self._protocol = protocol - content = await response.aread() - with contextlib.suppress(Exception): - content = json.loads(content) + def __call__( + self, + method: str, + url: str, + *, + parse_kwargs: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> AsyncIterator[Any]: + if self._protocol is None: + call_method = self._no_stream + else: + call_method = { + HttpStreamingProtocol.SSE: self._stream_sse, + HttpStreamingProtocol.JSONL: self._stream_jsonl, + HttpStreamingProtocol.JSON: self._stream_json, + }[self._protocol] + return call_method(method, url, parse_kwargs=parse_kwargs or {}, **kwargs) - raise RagnaException( - "API call failed", - request_method=response.request.method, - request_url=str(response.request.url), - response_status_code=response.status_code, - response_content=content, - ) + async def _no_stream( + self, + method: str, + url: str, + *, + parse_kwargs: dict[str, Any], + **kwargs: Any, + ) -> AsyncIterator[Any]: + response = await self._client.request(method, url, **kwargs) + await self._assert_api_call_is_success(response) + yield response.json() async def _stream_sse( self, method: str, url: str, + *, + parse_kwargs: dict[str, Any], **kwargs: Any, - ) -> AsyncIterator[dict[str, Any]]: + ) -> AsyncIterator[Any]: import httpx_sse async with httpx_sse.aconnect_sse( @@ -71,10 +89,103 @@ async def _stream_sse( yield json.loads(sse.data) async def _stream_jsonl( - self, method: str, url: str, **kwargs: Any - ) -> AsyncIterator[dict[str, Any]]: + self, + method: str, + url: str, + *, + parse_kwargs: dict[str, Any], + **kwargs: Any, + ) -> AsyncIterator[Any]: async with self._client.stream(method, url, **kwargs) as response: await self._assert_api_call_is_success(response) async for chunk in response.aiter_lines(): yield json.loads(chunk) + + # ijson does not support reading from an (async) iterator, but only from file-like + # objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects. + # See https://github.com/ICRAR/ijson/issues/44 for details. + # ijson actually doesn't care about most of the file interface and only requires the + # read() method to be present. + class _AsyncIteratorReader: + def __init__(self, ait: AsyncIterator[bytes]) -> None: + self._ait = ait + + async def read(self, n: int) -> bytes: + # n is usually used to indicate how many bytes to read, but since we want to + # return a chunk as soon as it is available, we ignore the value of n. The + # only exception is n == 0, which is used by ijson to probe the return type + # and set up decoding. 
+ if n == 0: + return b"" + return await anext(self._ait, b"") # type: ignore[call-arg] + + async def _stream_json( + self, + method: str, + url: str, + *, + parse_kwargs: dict[str, Any], + **kwargs: Any, + ) -> AsyncIterator[Any]: + import ijson + + item = parse_kwargs["item"] + chunk_size = parse_kwargs.get("chunk_size", 16) + + async with self._client.stream(method, url, **kwargs) as response: + await self._assert_api_call_is_success(response) + + async for chunk in ijson.items( + self._AsyncIteratorReader(response.aiter_bytes(chunk_size)), item + ): + yield chunk + + async def _assert_api_call_is_success(self, response: httpx.Response) -> None: + if response.is_success: + return + + content = await response.aread() + with contextlib.suppress(Exception): + content = json.loads(content) + + raise RagnaException( + "API call failed", + request_method=response.request.method, + request_url=str(response.request.url), + response_status_code=response.status_code, + response_content=content, + ) + + +class HttpApiAssistant(Assistant): + _API_KEY_ENV_VAR: Optional[str] + _STREAMING_PROTOCOL: Optional[HttpStreamingProtocol] + + @classmethod + def requirements(cls) -> list[Requirement]: + requirements: list[Requirement] = ( + [EnvVarRequirement(cls._API_KEY_ENV_VAR)] + if cls._API_KEY_ENV_VAR is not None + else [] + ) + if cls._STREAMING_PROTOCOL is not None: + requirements.extend(HttpApiCaller.requirements(cls._STREAMING_PROTOCOL)) + requirements.extend(cls._extra_requirements()) + return requirements + + @classmethod + def _extra_requirements(cls) -> list[Requirement]: + return [] + + def __init__(self) -> None: + self._client = httpx.AsyncClient( + headers={"User-Agent": f"{ragna.__version__}/{self}"}, + timeout=60, + ) + self._api_key: Optional[str] = ( + os.environ[self._API_KEY_ENV_VAR] + if self._API_KEY_ENV_VAR is not None + else None + ) + self._call_api = HttpApiCaller(self._client, self._STREAMING_PROTOCOL) diff --git a/ragna/assistants/_llamafile.py b/ragna/assistants/_llamafile.py index 3e78a625..5c7cc1da 100644 --- a/ragna/assistants/_llamafile.py +++ b/ragna/assistants/_llamafile.py @@ -1,9 +1,11 @@ import os +from functools import cached_property -from ._openai import OpenaiCompliantHttpApiAssistant +from ._http_api import HttpStreamingProtocol +from ._openai import OpenaiLikeHttpApiAssistant -class LlamafileAssistant(OpenaiCompliantHttpApiAssistant): +class LlamafileAssistant(OpenaiLikeHttpApiAssistant): """[llamafile](https://github.com/Mozilla-Ocho/llamafile) To use this assistant, start the llamafile server manually. 
By default, the server @@ -16,10 +18,14 @@ class LlamafileAssistant(OpenaiCompliantHttpApiAssistant): """ _API_KEY_ENV_VAR = None - _STREAMING_METHOD = "sse" + _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE _MODEL = None - @property + @classmethod + def display_name(cls) -> str: + return "llamafile" + + @cached_property def _url(self) -> str: base_url = os.environ.get("RAGNA_LLAMAFILE_BASE_URL", "http://localhost:8080") return f"{base_url}/v1/chat/completions" diff --git a/ragna/assistants/_ollama.py b/ragna/assistants/_ollama.py new file mode 100644 index 00000000..3bb23c9f --- /dev/null +++ b/ragna/assistants/_ollama.py @@ -0,0 +1,83 @@ +import os +from functools import cached_property +from typing import AsyncIterator, cast + +from ragna.core import RagnaException, Source + +from ._http_api import HttpStreamingProtocol +from ._openai import OpenaiLikeHttpApiAssistant + + +class OllamaAssistant(OpenaiLikeHttpApiAssistant): + """[Ollama](https://ollama.com/) + + To use this assistant, start the Ollama server manually. By default, the server + is expected at `http://localhost:11434`. This can be changed with the + `RAGNA_OLLAMA_BASE_URL` environment variable. + """ + + _API_KEY_ENV_VAR = None + _STREAMING_PROTOCOL = HttpStreamingProtocol.JSONL + _MODEL: str + + @classmethod + def display_name(cls) -> str: + return f"Ollama/{cls._MODEL}" + + @cached_property + def _url(self) -> str: + base_url = os.environ.get("RAGNA_OLLAMA_BASE_URL", "http://localhost:11434") + return f"{base_url}/api/chat" + + async def answer( + self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + ) -> AsyncIterator[str]: + async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens): + # Modeled after + # https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62 + if "error" in data: + raise RagnaException(data["error"]) + if not data["done"]: + yield cast(str, data["message"]["content"]) + + +class OllamaGemma2B(OllamaAssistant): + """[Gemma:2B](https://ollama.com/library/gemma)""" + + _MODEL = "gemma:2b" + + +class OllamaLlama2(OllamaAssistant): + """[Llama 2](https://ollama.com/library/llama2)""" + + _MODEL = "llama2" + + +class OllamaLlava(OllamaAssistant): + """[Llava](https://ollama.com/library/llava)""" + + _MODEL = "llava" + + +class OllamaMistral(OllamaAssistant): + """[Mistral](https://ollama.com/library/mistral)""" + + _MODEL = "mistral" + + +class OllamaMixtral(OllamaAssistant): + """[Mixtral](https://ollama.com/library/mixtral)""" + + _MODEL = "mixtral" + + +class OllamaOrcaMini(OllamaAssistant): + """[Orca Mini](https://ollama.com/library/orca-mini)""" + + _MODEL = "orca-mini" + + +class OllamaPhi2(OllamaAssistant): + """[Phi-2](https://ollama.com/library/phi)""" + + _MODEL = "phi" diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index 37957be2..0f51d6d9 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -1,23 +1,15 @@ import abc -from typing import Any, AsyncIterator, Literal, Optional, cast +from functools import cached_property +from typing import Any, AsyncIterator, Optional, cast -from ragna.core import PackageRequirement, RagnaException, Requirement, Source +from ragna.core import Source -from ._http_api import HttpApiAssistant +from ._http_api import HttpApiAssistant, HttpStreamingProtocol -class OpenaiCompliantHttpApiAssistant(HttpApiAssistant): - _STREAMING_METHOD: Literal["sse", "jsonl"] +class 
OpenaiLikeHttpApiAssistant(HttpApiAssistant): _MODEL: Optional[str] - @classmethod - def requirements(cls) -> list[Requirement]: - requirements = super().requirements() - requirements.extend( - {"sse": [PackageRequirement("httpx_sse")]}.get(cls._STREAMING_METHOD, []) - ) - return requirements - @property @abc.abstractmethod def _url(self) -> str: ... @@ -32,23 +24,8 @@ def _make_system_content(self, sources: list[Source]) -> str: return instruction + "\n\n".join(source.content for source in sources) def _stream( - self, - method: str, - url: str, - **kwargs: Any, + self, prompt: str, sources: list[Source], *, max_new_tokens: int ) -> AsyncIterator[dict[str, Any]]: - stream = { - "sse": self._stream_sse, - "jsonl": self._stream_jsonl, - }.get(self._STREAMING_METHOD) - if stream is None: - raise RagnaException - - return stream(method, url, **kwargs) - - async def answer( - self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 - ) -> AsyncIterator[str]: # See https://platform.openai.com/docs/api-reference/chat/create # and https://platform.openai.com/docs/api-reference/chat/streaming headers = { @@ -75,7 +52,12 @@ async def answer( if self._MODEL is not None: json_["model"] = self._MODEL - async for data in self._stream("POST", self._url, headers=headers, json=json_): + return self._call_api("POST", self._url, headers=headers, json=json_) + + async def answer( + self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 + ) -> AsyncIterator[str]: + async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens): choice = data["choices"][0] if choice["finish_reason"] is not None: break @@ -83,15 +65,15 @@ async def answer( yield cast(str, choice["delta"]["content"]) -class OpenaiAssistant(OpenaiCompliantHttpApiAssistant): +class OpenaiAssistant(OpenaiLikeHttpApiAssistant): _API_KEY_ENV_VAR = "OPENAI_API_KEY" - _STREAMING_METHOD = "sse" + _STREAMING_PROTOCOL = HttpStreamingProtocol.SSE @classmethod def display_name(cls) -> str: return f"OpenAI/{cls._MODEL}" - @property + @cached_property def _url(self) -> str: return "https://api.openai.com/v1/chat/completions" From 6ef051a7a277bbb0edb0003ca0609f73202993f7 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 10 Jun 2024 12:59:15 +0200 Subject: [PATCH 5/5] remove welcome modal (#427) --- ragna/deploy/_ui/css/modal_welcome/button.css | 4 -- ragna/deploy/_ui/main_page.py | 69 ++----------------- ragna/deploy/_ui/modal_welcome.py | 42 ----------- ragna/deploy/_ui/styles.py | 4 -- 4 files changed, 4 insertions(+), 115 deletions(-) delete mode 100644 ragna/deploy/_ui/css/modal_welcome/button.css delete mode 100644 ragna/deploy/_ui/modal_welcome.py diff --git a/ragna/deploy/_ui/css/modal_welcome/button.css b/ragna/deploy/_ui/css/modal_welcome/button.css deleted file mode 100644 index 5c98c041..00000000 --- a/ragna/deploy/_ui/css/modal_welcome/button.css +++ /dev/null @@ -1,4 +0,0 @@ -:host(.modal_welcome_close_button) { - width: 35%; - margin-left: 60%; -} diff --git a/ragna/deploy/_ui/main_page.py b/ragna/deploy/_ui/main_page.py index c8610e7b..4ba5ba94 100644 --- a/ragna/deploy/_ui/main_page.py +++ b/ragna/deploy/_ui/main_page.py @@ -3,12 +3,9 @@ import panel as pn import param -from . import js -from . 
import styles as ui from .central_view import CentralView from .left_sidebar import LeftSidebar from .modal_configuration import ModalConfiguration -from .modal_welcome import ModalWelcome from .right_sidebar import RightSidebar @@ -71,14 +68,6 @@ def open_modal(self): self.template.modal.objects[0].objects = [self.modal] self.template.open_modal() - def open_welcome_modal(self, event): - self.modal = ModalWelcome( - close_button_callback=lambda: self.template.close_modal(), - ) - - self.template.modal.objects[0].objects = [self.modal] - self.template.open_modal() - async def open_new_chat(self, new_chat_id): # called after creating a new chat. self.current_chat_id = new_chat_id @@ -111,59 +100,9 @@ def update_subviews_current_chat_id(self, avoid_senders=[]): def __panel__(self): asyncio.ensure_future(self.refresh_data()) - objects = [self.left_sidebar, self.central_view, self.right_sidebar] - - if self.chats is not None and len(self.chats) == 0: - """I haven't found a better way to open the modal when the pages load, - than simulating a click on the "New chat" button. - - calling self.template.open_modal() doesn't work - - calling self.on_click_new_chat doesn't work either - - trying to schedule a call to on_click_new_chat with pn.state.schedule_task - could have worked but my tests were yielding an unstable result. - """ - - new_chat_button_name = "open welcome modal" - open_welcome_modal = pn.widgets.Button( - name=new_chat_button_name, - button_type="primary", - ) - open_welcome_modal.on_click(self.open_welcome_modal) - - hack_open_modal = pn.pane.HTML( - """ - - """.replace( - "{new_chat_btn_name}", new_chat_button_name - ).strip(), - # This is not really styling per say, it's just a way to hide from the page the HTML item of this hack. - # It's not worth moving this to a separate file. - stylesheets=[ - ui.css( - ":host", - {"position": "absolute", "z-index": "-999"}, - ) - ], - ) - - objects.append( - pn.Row( - open_welcome_modal, - pn.pane.HTML(js.SHADOWROOT_INDEXING), - hack_open_modal, - visible=False, - ) - ) - - main_page = pn.Row( - *objects, + return pn.Row( + self.left_sidebar, + self.central_view, + self.right_sidebar, css_classes=["main_page_main_row"], ) - - return main_page diff --git a/ragna/deploy/_ui/modal_welcome.py b/ragna/deploy/_ui/modal_welcome.py deleted file mode 100644 index 71b6ad7f..00000000 --- a/ragna/deploy/_ui/modal_welcome.py +++ /dev/null @@ -1,42 +0,0 @@ -import panel as pn -import param - -from . import js -from . import styles as ui - - -class ModalWelcome(pn.viewable.Viewer): - close_button_callback = param.Callable() - - def __init__(self, **params): - super().__init__(**params) - - def did_click_on_close_button(self, event): - if self.close_button_callback is not None: - self.close_button_callback() - - def __panel__(self): - close_button = pn.widgets.Button( - name="Okay, let's go", - button_type="primary", - css_classes=["modal_welcome_close_button"], - ) - close_button.on_click(self.did_click_on_close_button) - - return pn.Column( - pn.pane.HTML( - f"""""" - + """

-                        Welcome !
-
-                        Ragna is a RAG Orchestration Framework.
-                        With its UI, select and configure LLMs, upload documents, and chat with the LLM.
-
-                        Use Ragna UI out-of-the-box, as a daily-life interface with your favorite AI,
-                        or as a reference to build custom web applications.
- """ - ), - close_button, - width=ui.WELCOME_MODAL_WIDTH, - height=ui.WELCOME_MODAL_HEIGHT, - sizing_mode="fixed", - ) diff --git a/ragna/deploy/_ui/styles.py b/ragna/deploy/_ui/styles.py index 7f994eeb..213e6e1a 100644 --- a/ragna/deploy/_ui/styles.py +++ b/ragna/deploy/_ui/styles.py @@ -46,7 +46,6 @@ "right_sidebar": [pn.widgets.Button, pn.Column, pn.pane.Markdown], "left_sidebar": [pn.widgets.Button, pn.pane.HTML, pn.Column], "main_page": [pn.Row], - "modal_welcome": [pn.widgets.Button], "modal_configuration": [ pn.widgets.IntSlider, pn.layout.Card, @@ -103,9 +102,6 @@ def css(selector: Union[str, Iterable[str]], declarations: dict[str, str]) -> st CONFIG_MODAL_MAX_HEIGHT = 850 CONFIG_MODAL_WIDTH = 800 -WELCOME_MODAL_HEIGHT = 275 -WELCOME_MODAL_WIDTH = 530 - CSS_VARS = css( ":root",