diff --git a/CHANGELOG.md b/CHANGELOG.md index 0342fcd..98cc5d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.0.6 + +* **Support streaming response types for /invoke if callable is async generator** + ## 0.0.5 * **Improve logging to hide body in case of sensitive data unless TRACE level** diff --git a/unstructured_platform_plugins/__version__.py b/unstructured_platform_plugins/__version__.py index cec386a..62c73e6 100644 --- a/unstructured_platform_plugins/__version__.py +++ b/unstructured_platform_plugins/__version__.py @@ -1 +1 @@ -__version__ = "0.0.5" # pragma: no cover +__version__ = "0.0.6" # pragma: no cover diff --git a/unstructured_platform_plugins/etl_uvicorn/api_generator.py b/unstructured_platform_plugins/etl_uvicorn/api_generator.py index 9df5f4d..c19cbb2 100644 --- a/unstructured_platform_plugins/etl_uvicorn/api_generator.py +++ b/unstructured_platform_plugins/etl_uvicorn/api_generator.py @@ -7,6 +7,7 @@ from typing import Any, Callable, Optional from fastapi import FastAPI, status +from fastapi.responses import StreamingResponse from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from pydantic import BaseModel from starlette.responses import RedirectResponse @@ -110,7 +111,9 @@ class InvokeResponse(BaseModel): logging.getLogger("etl_uvicorn.fastapi") - async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> InvokeResponse: + ResponseType = StreamingResponse if inspect.isasyncgenfunction(func) else InvokeResponse + + async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> ResponseType: usage: list[UsageData] = [] request_dict = kwargs if kwargs else {} if "usage" in inspect.signature(func).parameters: @@ -118,8 +121,19 @@ async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> In else: logger.warning("usage data not an expected parameter, omitting") try: - output = await invoke_func(func=func, kwargs=request_dict) - return InvokeResponse(usage=usage, status_code=status.HTTP_200_OK, output=output) + if inspect.isasyncgenfunction(func): + # Stream response if function is an async generator + + async def _stream_response(): + async for output in func(**(request_dict or {})): + yield InvokeResponse( + usage=usage, status_code=status.HTTP_200_OK, output=output + ).model_dump_json() + "\n" + + return StreamingResponse(_stream_response(), media_type="application/x-ndjson") + else: + output = await invoke_func(func=func, kwargs=request_dict) + return InvokeResponse(usage=usage, status_code=status.HTTP_200_OK, output=output) except Exception as invoke_error: logger.error(f"failed to invoke plugin: {invoke_error}", exc_info=True) return InvokeResponse( @@ -132,7 +146,7 @@ async def wrap_fn(func: Callable, kwargs: Optional[dict[str, Any]] = None) -> In if input_schema_model.model_fields: @fastapi_app.post("/invoke", response_model=InvokeResponse) - async def run_job(request: input_schema_model) -> InvokeResponse: + async def run_job(request: input_schema_model) -> ResponseType: log_func_and_body(func=func, body=request.json()) # Create dictionary from pydantic model while preserving underlying types request_dict = {f: getattr(request, f) for f in request.model_fields} @@ -144,7 +158,7 @@ async def run_job(request: input_schema_model) -> InvokeResponse: else: @fastapi_app.post("/invoke", response_model=InvokeResponse) - async def run_job() -> InvokeResponse: + async def run_job() -> ResponseType: log_func_and_body(func=func) return await wrap_fn( func=func,