enable lazy caching for chatinterface (#10015)
* lazy chat

* add changeset

* lazy caching

* lazy caching

* revert

* fix this

* changes

* changes

* format

* changes

* add env variable

* revert

* add changeset

* lazy

* fix

* chat interface

* fix test

---------

Co-authored-by: gradio-pr-bot <[email protected]>
abidlabs and gradio-pr-bot authored Nov 23, 2024
1 parent 369a44e commit db162bf
Showing 9 changed files with 87 additions and 26 deletions.
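For orientation, the net effect of this commit is that `gr.ChatInterface` now accepts `cache_mode="lazy"` together with `cache_examples=True`. Below is a minimal sketch of the usage this enables, adapted from the demo added in this commit; the streaming `echo_stream` function is illustrative, not part of the commit.

```python
import gradio as gr

def echo_stream(message: str, chat_history: list[dict]):
    # Stream the reply back one character at a time.
    partial = ""
    for character in message:
        partial += character
        yield partial

demo = gr.ChatInterface(
    fn=echo_stream,
    examples=[["Hey"], ["What is the Python programming language?"]],
    cache_examples=True,
    cache_mode="lazy",  # previously only "eager" caching worked for ChatInterface
    type="messages",
)

if __name__ == "__main__":
    demo.launch()
```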
5 changes: 5 additions & 0 deletions .changeset/many-horses-judge.md
@@ -0,0 +1,5 @@
---
"gradio": patch
---

fix:enable lazy caching for chatinterface
3 changes: 2 additions & 1 deletion .config/playwright-setup.js
@@ -84,7 +84,8 @@ function spawn_gradio_app(app, port, verbose) {
...process.env,
PYTHONUNBUFFERED: "true",
GRADIO_ANALYTICS_ENABLED: "False",
GRADIO_IS_E2E_TEST: "1"
GRADIO_IS_E2E_TEST: "1",
GRADIO_RESET_EXAMPLES_CACHE: "True"
}
});
_process.stdout.setEncoding("utf8");
27 changes: 27 additions & 0 deletions demo/test_chatinterface_examples/lazy_caching_examples_testcase.py
@@ -0,0 +1,27 @@
import gradio as gr

def generate(
    message: str,
    chat_history: list[dict],
):

    output = ""
    for character in message:
        output += character
        yield output


demo = gr.ChatInterface(
    fn=generate,
    examples=[
        ["Hey"],
        ["Can you explain briefly to me what is the Python programming language?"],
    ],
    cache_examples=True,
    cache_mode="lazy",
    type="messages",
)


if __name__ == "__main__":
    demo.launch()
2 changes: 1 addition & 1 deletion demo/test_chatinterface_examples/run.ipynb
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_examples"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/eager_caching_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/multimodal_messages_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/multimodal_tuples_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/tuples_examples_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "def generate(\n", " message: str,\n", " chat_history: list[dict],\n", "):\n", "\n", " output = \"\"\n", " for character in message:\n", " output += character\n", " yield output\n", "\n", "\n", "demo = gr.ChatInterface(\n", " fn=generate,\n", " examples=[\n", " [\"Hey\"],\n", " [\"Can you explain briefly to me what is the Python programming language?\"],\n", " ],\n", " cache_examples=False,\n", " type=\"messages\",\n", ")\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: test_chatinterface_examples"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["# Downloading files from the demo repo\n", "import os\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/eager_caching_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/lazy_caching_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/multimodal_messages_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/multimodal_tuples_examples_testcase.py\n", "!wget -q https://github.com/gradio-app/gradio/raw/main/demo/test_chatinterface_examples/tuples_examples_testcase.py"]}, {"cell_type": "code", "execution_count": null, "id": "44380577570523278879349135829904343037", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "def generate(\n", " message: str,\n", " chat_history: list[dict],\n", "):\n", "\n", " output = \"\"\n", " for character in message:\n", " output += character\n", " yield output\n", "\n", "\n", "demo = gr.ChatInterface(\n", " fn=generate,\n", " examples=[\n", " [\"Hey\"],\n", " [\"Can you explain briefly to me what is the Python programming language?\"],\n", " ],\n", " cache_examples=False,\n", " type=\"messages\",\n", ")\n", "\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
25 changes: 14 additions & 11 deletions gradio/chat_interface.py
@@ -8,7 +8,7 @@
import functools
import inspect
import warnings
from collections.abc import AsyncGenerator, Callable, Sequence
from collections.abc import AsyncGenerator, Callable, Generator, Sequence
from pathlib import Path
from typing import Literal, Union, cast

@@ -111,7 +111,7 @@ def __init__(
example_labels: labels for the examples, to be displayed instead of the examples themselves. If provided, should be a list of strings with the same length as the examples list. Only applies when examples are displayed within the chatbot (i.e. when `additional_inputs` is not provided).
example_icons: icons for the examples, to be displayed above the examples. If provided, should be a list of string URLs or local paths with the same length as the examples list. Only applies when examples are displayed within the chatbot (i.e. when `additional_inputs` is not provided).
cache_examples: if True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
cache_mode: if "eager", all examples are cached at app launch. The "lazy" option is not yet supported. If None, will use the GRADIO_CACHE_MODE environment variable if defined, or default to "eager".
cache_mode: if "eager", all examples are cached at app launch. If "lazy", examples are cached for all users after the first use by any user of the app. If None, will use the GRADIO_CACHE_MODE environment variable if defined, or default to "eager".
title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
description: a description for the interface; if provided, appears above the chatbot and beneath the title in regular font. Accepts Markdown and HTML content.
theme: a Theme object or a string representing a theme. If a string, will look for a built-in theme with that name (e.g. "soft" or "default"), or will attempt to load a theme from the Hugging Face Hub (e.g. "gradio/monochrome"). If None, will use the Default theme.
@@ -369,7 +369,7 @@ def _setup_events(self) -> None:
and self.examples
and not self._additional_inputs_in_examples
):
if self.cache_examples and self.cache_mode == "eager":
if self.cache_examples:
self.chatbot.example_select(
self.example_clicked,
None,
@@ -718,15 +718,15 @@ def option_clicked(

def example_clicked(
self, example: SelectData
) -> tuple[TupleFormat | list[MessageDict], str | MultimodalPostprocess]:
) -> Generator[
tuple[TupleFormat | list[MessageDict], str | MultimodalPostprocess], None, None
]:
"""
When an example is clicked, the chat history is set to the complete example value
(including files). The saved input value is also set to complete example value
if multimodal is True, otherwise it is set to the text of the example.
When an example is clicked, the chat history (and saved input) is initially set only
to the example message. Then, if example caching is enabled, the cached response is loaded
and added to the chat history as well.
"""
if self.cache_examples and self.cache_mode == "eager":
history = self.examples_handler.load_from_cache(example.index)[0].root
elif self.type == "tuples":
if self.type == "tuples":
history = [(example.value["text"], None)]
for file in example.value.get("files", []):
history.append(((file["path"]), None))
@@ -735,7 +735,10 @@ def example_clicked(
for file in example.value.get("files", []):
history.append(MessageDict(role="user", content=file))
message = example.value if self.multimodal else example.value["text"]
return history, message
yield history, message
if self.cache_examples:
history = self.examples_handler.load_from_cache(example.index)[0].root
yield history, message

def _process_example(
self, message: ExampleMessage | str, response: MessageDict | str | None
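The key change above is that `example_clicked` is now a generator: it first yields a history containing only the example message so the UI updates immediately, then, if example caching is enabled, yields again once the cached response has been loaded. A minimal sketch of that two-step pattern, using hypothetical names (`load_cached_history` stands in for `examples_handler.load_from_cache`):

```python
from collections.abc import Callable, Generator

def example_clicked_sketch(
    example_text: str,
    cache_enabled: bool,
    load_cached_history: Callable[[], list[dict]],  # hypothetical cache loader
) -> Generator[tuple[list[dict], str], None, None]:
    # First yield: show only the user's example message right away.
    history = [{"role": "user", "content": example_text}]
    yield history, example_text
    # Second yield: if caching is on, replace the history with the cached
    # user message + bot response once it has been loaded (or generated lazily).
    if cache_enabled:
        history = load_cached_history()
        yield history, example_text
```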
24 changes: 22 additions & 2 deletions gradio/helpers.py
@@ -9,6 +9,7 @@
import csv
import inspect
import os
import shutil
import warnings
from collections.abc import Callable, Iterable, Sequence
from functools import partial
@@ -276,6 +277,11 @@ def __init__(
simplify_file_data=False, verbose=False, dataset_file_name="log.csv"
)
self.cached_folder = utils.get_cache_folder() / str(self.dataset._id)
if (
os.environ.get("GRADIO_RESET_EXAMPLES_CACHE") == "True"
and self.cached_folder.exists()
):
shutil.rmtree(self.cached_folder)
self.cached_file = Path(self.cached_folder) / "log.csv"
self.cached_indices_file = Path(self.cached_folder) / "indices.csv"
self.run_on_click = run_on_click
@@ -495,13 +501,15 @@ def sync_lazy_cache(self, example_value: tuple[int, list[Any]], *input_values):
with open(self.cached_indices_file, "a") as f:
f.write(f"{example_index}\n")

async def cache(self) -> None:
async def cache(self, example_id: int | None = None) -> None:
"""
Caches examples so that their predictions can be shown immediately.
Parameters:
example_id: The id of the example to process (zero-indexed). If None, all examples are cached.
"""
if self.root_block is None:
raise Error("Cannot cache examples if not in a Blocks context.")
if Path(self.cached_file).exists():
if Path(self.cached_file).exists() and example_id is None:
print(
f"Using cache from '{utils.abspath(self.cached_folder)}' directory. If method or examples have changed since last caching, delete this folder to clear cache.\n"
)
@@ -548,6 +556,8 @@ async def get_final_item(*args):
if self.outputs is None:
raise ValueError("self.outputs is missing")
for i, example in enumerate(self.examples):
if example_id is not None and i != example_id:
continue
print(f"Caching example {i + 1}/{len(self.examples)}")
processed_input = self._get_processed_example(example)
if self.batch:
@@ -574,6 +584,16 @@ def load_from_cache(self, example_id: int) -> list[Any]:
Parameters:
example_id: The id of the example to process (zero-indexed).
"""
if self.cache_examples == "lazy":
if (cached_index := self._get_cached_index_if_cached(example_id)) is None:
client_utils.synchronize_async(self.cache, example_id)
with open(self.cached_indices_file, "a") as f:
f.write(f"{example_id}\n")
with open(self.cached_indices_file) as f:
example_id = len(f.readlines()) - 1
else:
example_id = cached_index

with open(self.cached_file, encoding="utf-8") as cache:
examples = list(csv.reader(cache))
example = examples[example_id + 1] # +1 to adjust for header
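With lazy caching, `indices.csv` records which dataset example produced each row of `log.csv`, in the order the examples were first used, so `load_from_cache` has to translate a dataset index into a cache-row index (and cache the example on the fly when it is missing). A rough, self-contained sketch of that lookup, under assumed file names that mirror the diff:

```python
from pathlib import Path

def resolve_cached_row(example_id: int, indices_file: Path) -> int | None:
    """Return the log.csv row for `example_id`, or None if it was never cached."""
    if not indices_file.exists():
        return None
    cached_ids = [
        int(line) for line in indices_file.read_text().splitlines() if line.strip()
    ]
    # Rows in log.csv appear in the same order as the ids in indices.csv.
    return cached_ids.index(example_id) if example_id in cached_ids else None
```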
10 changes: 10 additions & 0 deletions guides/04_additional-features/09_environment-variables.md
@@ -163,6 +163,16 @@ Environment variables in Gradio provide a way to customize your applications and
export GRADIO_NODE_NUM_PORTS=200
```

### 18. `GRADIO_RESET_EXAMPLES_CACHE`

- **Description**: If set to "True", Gradio will delete and recreate the examples cache directory when the app starts, instead of reusing cached examples if they already exist.
- **Default**: `"False"`
- **Options**: `"True"`, `"False"`
- **Example**:
```sh
export GRADIO_RESET_EXAMPLES_CACHE="True"
```

## How to Set Environment Variables

To set environment variables in your terminal, use the `export` command followed by the variable name and its value. For example:
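If it is easier to control this from Python (for example in a test harness, much as the Playwright setup above passes it through the process environment), the variable can presumably be set before the demo and its examples are constructed, since the cache directory is checked when the examples handler is created. An illustrative sketch:

```python
import os

# Must be set before the ChatInterface (and its examples cache) is constructed.
os.environ["GRADIO_RESET_EXAMPLES_CACHE"] = "True"

import gradio as gr

demo = gr.ChatInterface(
    fn=lambda message, history: message,
    examples=[["Hey"]],
    cache_examples=True,
    cache_mode="lazy",
    type="messages",
)

if __name__ == "__main__":
    demo.launch()
```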
3 changes: 2 additions & 1 deletion js/spa/test/test_chatinterface_examples.spec.ts
@@ -5,7 +5,8 @@ const cases = [
"tuples_examples",
"multimodal_tuples_examples",
"multimodal_messages_examples",
"eager_caching_examples"
"eager_caching_examples",
"lazy_caching_examples"
];

for (const test_case of cases) {
14 changes: 4 additions & 10 deletions test/test_chat_interface.py
@@ -95,7 +95,7 @@ def test_example_caching(self, connect):
assert prediction_hi[0].root[0] == ("hi", "hi hi")

@pytest.mark.asyncio
async def test_example_caching_lazy(self, connect):
async def test_example_caching_lazy(self):
with patch(
"gradio.utils.get_cache_folder", return_value=Path(tempfile.mkdtemp())
):
@@ -105,16 +105,10 @@ async def test_example_caching_lazy(self, connect):
cache_examples=True,
cache_mode="lazy",
)
async for _ in chatbot.examples_handler.async_lazy_cache(
(0, ["hello"]), "hello"
):
pass
with connect(chatbot):
prediction_hello = chatbot.examples_handler.load_from_cache(0)
prediction_hello = chatbot.examples_handler.load_from_cache(0)
assert prediction_hello[0].root[0] == ("hello", "hello hello")
with pytest.raises(IndexError):
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hi[0].root[0] == ("hi", "hi hi")
prediction_hi = chatbot.examples_handler.load_from_cache(1)
assert prediction_hi[0].root[0] == ("hi", "hi hi")

def test_example_caching_async(self, connect):
with patch(
