Skip to content

Commit

Permalink
feat/enum and forward ref support (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
rbiseck3 authored Jul 2, 2024
1 parent 64a8f9f commit b7fb254
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 17 deletions.
113 changes: 107 additions & 6 deletions test/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,70 @@
import inspect
from dataclasses import dataclass
from typing import Any, Optional, TypedDict, Union
from enum import Enum
from typing import Any, Optional, TypedDict, Union, get_type_hints

import pytest
from pydantic import BaseModel

import unstructured_platform_plugins.schema.json_schema as js
from unstructured_platform_plugins.etl_uvicorn.utils import get_input_schema
from unstructured_platform_plugins.schema.model import is_valid_input_dict, is_valid_response_dict
from unstructured_platform_plugins.schema.utils import get_types_parameters


def test_string_enum_fn():
class StringEnum(Enum):
FIRST = "first"
SECOND = "second"
THIRD = "third"

def fn(input: StringEnum) -> None:
pass

sig = inspect.signature(fn)
input_schema = js.parameters_to_json_schema(parameters=list(sig.parameters.values()))
expected_schema = {
"type": "object",
"required": ["input"],
"properties": {"input": {"type": "string", "enum": ["first", "second", "third"]}},
}

assert input_schema == expected_schema
assert is_valid_response_dict(input_schema)


def test_int_enum_fn():
class IntEnum(Enum):
FIRST = 1
SECOND = 2
THIRD = 3

def fn(input: IntEnum) -> None:
pass

sig = inspect.signature(fn)
input_schema = js.parameters_to_json_schema(parameters=list(sig.parameters.values()))
expected_schema = {
"type": "object",
"required": ["input"],
"properties": {"input": {"type": "integer", "enum": [1, 2, 3]}},
}
assert input_schema == expected_schema
assert is_valid_response_dict(input_schema)


def test_mixed_enum_fn():
class MixedEnum(Enum):
FIRST = 1
SECOND = "second"
THIRD = 3

def fn(input: MixedEnum) -> None:
pass

sig = inspect.signature(fn)
with pytest.raises(ValueError):
js.parameters_to_json_schema(parameters=list(sig.parameters.values()))


def test_blank_fn():
Expand Down Expand Up @@ -172,7 +229,7 @@ def fn(g: Input) -> Optional[Output]:

def test_pydantic_base_model():
class Input(BaseModel):
x: int
x: "int"
y: Optional[str] = None

class Output(BaseModel):
Expand Down Expand Up @@ -223,7 +280,7 @@ def fn(g: Input) -> Optional[Output]:

def test_typed_dict():
class Input(TypedDict):
x: int
x: "int"
y: Optional[str]

class Output(TypedDict):
Expand Down Expand Up @@ -335,13 +392,19 @@ def fn(q: InputC) -> None:


def test_schema_to_base_model():
class g_enum(Enum):
FIRST = "first"
SECOND = "second"
THIRD = "third"

def fn(
a: int,
b: float | int = 4,
c: str | None = "my_string",
d: bool = False,
e: Optional[dict[str, Any]] = None,
f: list[float] = None,
g: Optional[g_enum] = None,
) -> None:
pass

Expand All @@ -352,13 +415,51 @@ class ExpectedInputModel(BaseModel):
d: bool = False
e: Optional[dict[str, Any]] = None
f: list[float] = None
g: Optional[g_enum] = None

input_schema = get_input_schema(fn)
input_model = js.schema_to_base_model(schema=input_schema)
input_model_schema = input_model.model_json_schema()
expected_model_schema = ExpectedInputModel.model_json_schema()
expected_model_schema["title"] = "reconstructed_model"
print()
print(input_model_schema)
print(expected_model_schema)
assert input_model_schema == expected_model_schema


# These need to be defined outside the code of test_forward_reference_typing
# for references to resolve:
@dataclass
class MockInputClass:
a: str


@dataclass
class MockOutputClass:
b: bool


def test_forward_reference_typing():

def fn(a: "MockInputClass") -> "MockOutputClass":
pass

parameters = get_types_parameters(fn)
input_schema = js.parameters_to_json_schema(parameters=parameters)
expected_input_schema = {
"type": "object",
"required": ["a"],
"properties": {
"a": {"type": "object", "properties": {"a": {"type": "string"}}, "required": ["a"]}
},
}
assert input_schema == expected_input_schema
assert is_valid_input_dict(input_schema)

return_annotation = get_type_hints(fn)["return"]
output_schema = js.response_to_json_schema(return_annotation=return_annotation)
expected_output_schema = {
"type": "object",
"properties": {"b": {"type": "boolean"}},
"required": ["b"],
}
assert output_schema == expected_output_schema
assert is_valid_response_dict(output_schema)
5 changes: 5 additions & 0 deletions unstructured_platform_plugins/etl_uvicorn/api_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from fastapi import FastAPI
from pydantic import BaseModel
from starlette.responses import RedirectResponse
from uvicorn.importer import import_from_string

from unstructured_platform_plugins.etl_uvicorn.utils import (
Expand Down Expand Up @@ -75,6 +76,10 @@ class SchemaOutputResponse(BaseModel):
inputs: dict[str, Any]
outputs: dict[str, Any]

@fastapi_app.get("/", include_in_schema=False)
async def docs_redirect():
return RedirectResponse("/docs")

@fastapi_app.get("/schema")
async def get_schema() -> SchemaOutputResponse:
schema = get_schema_dict(func)
Expand Down
15 changes: 8 additions & 7 deletions unstructured_platform_plugins/etl_uvicorn/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import inspect
from dataclasses import is_dataclass
from typing import Any, Callable, Optional
from types import NoneType
from typing import Any, Callable, Optional, get_type_hints

from pydantic import BaseModel

from unstructured_platform_plugins.schema.json_schema import (
parameters_to_json_schema,
response_to_json_schema,
)
from unstructured_platform_plugins.schema.utils import get_types_parameters


def get_func(instance: Any, method_name: Optional[str] = None) -> Callable:
Expand Down Expand Up @@ -49,16 +51,15 @@ def get_plugin_id(instance: Any, method_name: Optional[str] = None) -> str:


def get_input_schema(func: Callable) -> dict:
sig = inspect.signature(func)
parameters = list(sig.parameters.values())
parameters = get_types_parameters(func)
return parameters_to_json_schema(parameters)


def get_output_sig(func: Callable) -> Optional[Any]:
sig = inspect.signature(func)
outputs = (
sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None
)
inspect.signature(func)
type_hints = get_type_hints(func)
return_typing = type_hints["return"]
outputs = return_typing if return_typing is not NoneType else None
return outputs


Expand Down
46 changes: 42 additions & 4 deletions unstructured_platform_plugins/schema/json_schema.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import inspect
from dataclasses import MISSING, fields, is_dataclass
from enum import Enum, EnumMeta, EnumType
from inspect import Parameter
from pathlib import Path
from types import GenericAlias, NoneType, UnionType
from typing import Any, Optional, Type, Union
from typing import Any, Optional, Type, Union, get_type_hints

from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo, PydanticUndefined

from unstructured_platform_plugins.schema.utils import TypedParameter

# https://json-schema.org/understanding-json-schema/reference/type
types_map: dict[Type, str] = {
str: "string",
Expand Down Expand Up @@ -48,6 +51,21 @@ def path_to_json_schema(path: Path) -> dict:
return {"type": "string", "is_path": True}


def enum_to_json_schema(e: EnumMeta) -> dict:
values = [i.value for i in e]
value_types = [type(value) for value in values]
unique_value_types = list(set(value_types))
if len(unique_value_types) > 1:
raise ValueError(
"enum must have consistent types, found mixes: {}".format(
", ".join([e.__name__ for e in unique_value_types])
)
)
value_types = unique_value_types[0]
type_string = types_map[value_types]
return {"type": type_string, "enum": values}


def generic_alias_to_json_schema(t: GenericAlias) -> dict:
origin = t.__origin__
if origin is Union:
Expand All @@ -74,8 +92,9 @@ def dataclass_to_json_schema(class_or_instance) -> dict:
return resp
properties = {}
required = []
type_hints = get_type_hints(class_or_instance)
for f in fs:
t = f.type
t = type_hints[f.name]
f_resp = to_json_schema(t)
if f.default is not MISSING:
f_resp["default"] = f.default
Expand Down Expand Up @@ -114,7 +133,9 @@ def typed_dict_to_json_schem(typed_dict_class) -> dict:
return resp
properties = {}
required = []
for name, t in fs.items():
type_hints = get_type_hints(typed_dict_class)
for name in fs:
t = type_hints[name]
f_resp = to_json_schema(t)
properties[name] = f_resp
required.append(name)
Expand All @@ -131,15 +152,27 @@ def parameter_to_json_schema(parameter: Parameter) -> dict:
return resp


def typed_parameter_to_json_schema(parameter: TypedParameter) -> dict:
param_type = parameter.param_type
resp = to_json_schema(param_type)
if parameter.default != Parameter.empty:
resp["default"] = parameter.default
return resp


def to_json_schema(val: Any) -> dict:
if val in [None, NoneType]:
return {"type": "null"}
if val is Any:
return {}
if isinstance(val, TypedParameter):
return typed_parameter_to_json_schema(parameter=val)
if isinstance(val, Parameter):
return parameter_to_json_schema(parameter=val)
if isinstance(val, UnionType):
return union_type_to_json_schema(t=val)
if isinstance(val, EnumType):
return enum_to_json_schema(e=val)
if is_generic_alias(val=val):
return generic_alias_to_json_schema(t=val)
if val is Type:
Expand Down Expand Up @@ -175,7 +208,9 @@ def run_input_checks(parameters: list[Parameter]):
)


def parameters_to_json_schema(parameters: list[Parameter]) -> dict:
def parameters_to_json_schema(
parameters: list[Parameter], type_hints: Optional[dict[str, Type]] = None
) -> dict:
run_input_checks(parameters=parameters)
resp = {"type": "object"}
properties = {}
Expand Down Expand Up @@ -247,6 +282,9 @@ def schema_to_base_model_type(json_type_name, name: str, type_info: dict) -> Typ
json_type_name=item_type_name, name=f"{name}_type", type_info=items
)
t = list[subtype]
if "enum" in type_info and isinstance(type_info["enum"], list):
enum_content = type_info["enum"]
t = Enum(f"{name}_enum", {v: v for v in enum_content})
return t


Expand Down
26 changes: 26 additions & 0 deletions unstructured_platform_plugins/schema/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import inspect
from inspect import Parameter, _empty
from typing import Callable, get_type_hints


class TypedParameter(Parameter):
def __init__(self, *args, param_type=_empty, **kwargs):
super().__init__(*args, **kwargs)
self.param_type = param_type

@classmethod
def from_paramaeter(cls, param: Parameter) -> "TypedParameter":
return cls(
name=param.name, default=param.default, annotation=param.annotation, kind=param.kind
)


def get_types_parameters(fn: Callable) -> list[TypedParameter]:
type_hints = get_type_hints(fn)
parameters = list(inspect.signature(fn).parameters.values())
typed_params = []
for p in parameters:
typed_param = TypedParameter.from_paramaeter(param=p)
typed_param.param_type = type_hints[typed_param.name]
typed_params.append(typed_param)
return typed_params

0 comments on commit b7fb254

Please sign in to comment.