Skip to content

Commit

Permalink
Create endpoint to publish the plugin id
Browse files Browse the repository at this point in the history
  • Loading branch information
rbiseck3 committed Jun 11, 2024
1 parent 5219271 commit 14f66cb
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 26 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ Any plugin must be published in a dedicated docker image with all required depen
on port 8000 with the required endpoints to interact with the Unstructured Platform product:
* `/invoke`: A `POST` endpoint which gets all data to run the underlying logic in the request body and expects a json serializable response.
* `/schema`: A `GET` endpoint which publishes a json schema formatted response with the schema of the input and output expected by the plugin.
* `/id`: A `GET` endpoint which publishes a string unique identifier for this instance of the plugin. Will default to a hash of the schema
response if one is not set explicitly.


## Utility CLI
Expand Down
8 changes: 8 additions & 0 deletions unstructured_platform_plugins/etl_uvicorn/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,11 @@ def sample_method(self, content: dict[str, Any]) -> SampleClassMethodResponse:


sample_class = SampleClass()

hash_value = "plugin_id_123"

hash_lambda_fn = lambda: "plugin_id_hash_123"


def get_hash() -> str:
return "plugin_id_fn_123"
121 changes: 95 additions & 26 deletions unstructured_platform_plugins/etl_uvicorn/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import hashlib
import inspect
import json
from dataclasses import is_dataclass
from pathlib import Path
from typing import Any, Callable, Optional
Expand All @@ -8,7 +10,7 @@
from fastapi import FastAPI
from pydantic import BaseModel
from uvicorn.importer import import_from_string
from uvicorn.main import LOGGING_CONFIG, logger, main, run
from uvicorn.main import LOGGING_CONFIG, main, run

from unstructured_platform_plugins.etl_uvicorn.json_schema import (
parameters_to_json_schema,
Expand All @@ -18,23 +20,43 @@


def get_func(instance: Any, method_name: Optional[str] = None) -> Callable:
method_name = method_name or "__call__"
if inspect.isfunction(instance):
return instance
elif inspect.isclass(instance):
i = instance()
method_name = method_name or "__call__"
return getattr(i, method_name)
elif isinstance(instance, object):
try:
method_name = method_name or "__call__"
func = getattr(instance, method_name)
if inspect.ismethod(func):
return func
except Exception as e:
logger.debug(f"attempt to call instantiated class failed: {e}")
elif isinstance(instance, object) and hasattr(instance, method_name):
func = getattr(instance, method_name)
if inspect.ismethod(func):
return func
raise ValueError(f"type of instance not recognized: {type(instance)}")


def get_plugin_id(instance: Any, method_name: Optional[str] = None) -> str:
method_name = method_name or "__call__"
ref_id = None
if inspect.isfunction(instance):
ref_id = instance()
elif inspect.isclass(instance):
i = instance()
method_name = method_name or "__call__"
fn = getattr(i, method_name)
ref_id = fn()
elif isinstance(instance, object) and hasattr(instance, method_name):
func = getattr(instance, method_name)
if inspect.ismethod(func):
ref_id = func()
else:
ref_id = instance
if not ref_id:
raise ValueError(f"id could not be parsed from instance {instance}")
ref_id = str(ref_id)
if not ref_id.isidentifier():
raise ValueError(f"'{ref_id}' is not a valid identifier")
return ref_id


def get_input_schema(func: Callable) -> dict:
sig = inspect.signature(func)
parameters = list(sig.parameters.values())
Expand All @@ -53,6 +75,13 @@ def get_output_schema(func: Callable) -> dict:
return response_to_json_schema(get_output_sig(func))


def get_schema_dict(func) -> dict:
return {
"inputs": get_input_schema(func),
"outputs": get_output_schema(func),
}


def map_inputs(func: Callable, raw_inputs: dict[str, Any]) -> dict[str, Any]:
input_params = {p.name: p for p in inspect.signature(func).parameters.values()}
for k, v in input_params.items():
Expand All @@ -64,9 +93,22 @@ def map_inputs(func: Callable, raw_inputs: dict[str, Any]) -> dict[str, Any]:
return raw_inputs


def generate_fast_api(app: str, method_name: Optional[str] = None) -> FastAPI:
def generate_fast_api(
app: str,
method_name: Optional[str] = None,
id_str: Optional[str] = None,
id_method: Optional[str] = None,
) -> FastAPI:
instance = import_from_string(app)
func = get_func(instance, method_name)
if id_str:
id_ref = import_from_string(id_str)
plugin_id = get_plugin_id(instance=id_ref, method_name=id_method)
else:
plugin_id = hashlib.sha256(
json.dumps(get_schema_dict(func), sort_keys=True).encode()
).hexdigest()[:32]

fastapi_app = FastAPI()

response_type = get_output_sig(func)
Expand All @@ -92,14 +134,20 @@ async def run_job(request: InputSchema) -> response_type:
else:
return func(**request_dict)

class SchemaOutputResponse(BaseModel):
inputs: dict[str, Any]
outputs: dict[str, Any]

@fastapi_app.get("/schema")
async def get_schema() -> dict[str, Any]:
resp = {
"inputs": get_input_schema(func),
"outputs": get_output_schema(func),
}
async def get_schema() -> SchemaOutputResponse:
schema = get_schema_dict(func)
resp = SchemaOutputResponse(inputs=schema["inputs"], outputs=schema["outputs"])
return resp

@fastapi_app.get("/id")
async def get_id() -> str:
return plugin_id

# Run initial schema validation
try:
asyncio.run(get_schema())
Expand All @@ -119,9 +167,13 @@ def api_wrapper(
reload_excludes: list[str],
headers: list[str],
method_name: Optional[str] = None,
plugin_id: Optional[str] = None,
plugin_id_method: Optional[str] = None,
**kwargs,
):
fastapi_app = generate_fast_api(app, method_name)
fastapi_app = generate_fast_api(
app, method_name, id_str=plugin_id, id_method=plugin_id_method
)
# Explicitly map values that are manipulated in the original
# call to run(), preventing **kwargs reference
run(
Expand All @@ -136,15 +188,32 @@ def api_wrapper(

cmd = api_wrapper
cmd.params = main.params
cmd.params.append(
click.Option(
["--method-name"],
required=False,
type=str,
default=None,
help="If passed in instance is a class, what method to wrap. "
"Will fall back to __call__ if none is provided.",
)
cmd.params.extend(
[
click.Option(
["--method-name"],
required=False,
type=str,
default=None,
help="If passed in instance is a class, what method to wrap. "
"Will fall back to __call__ if none is provided.",
),
click.Option(
["--plugin-id"],
required=False,
type=str,
default=None,
help="Reference to either a value or function to get the plugin id once instantiated",
),
click.Option(
["--plugin-id-method"],
required=False,
type=str,
default=None,
help="If plugin id reference is a class, what method to wrap. "
"Will fall back to __call__ if none is provided.",
),
]
)
return cmd

Expand Down

0 comments on commit 14f66cb

Please sign in to comment.