Skip to content

Commit

Permalink
Make literals nullable in generated python (#1334)
Browse files Browse the repository at this point in the history
Fixes an issue with streaming for classes containing literal types.
During streaming there may not be enough tokens to satisfy a field's
literal value. In that case, the field should be `null`, and this should
be acceptable for the partial type.

However, the old generated `partial_types.py` does not allow the field to
be null, so pydantic validation fails when creating a
`partial_types.ClassForNullLiteral`:
```
class ClassForNullLiteral(BaseModel):
    a: Literal["hi"] = None
```

The new generated `partial_types.py` after this PR allows literal fields
to be null:
```
class ClassForNullLiteral(BaseModel):
    a: Optional[Literal["hi"]] = None
```

PR adds one parser test and one integ test.
<!-- ELLIPSIS_HIDDEN -->

----

> [!IMPORTANT]
> This PR allows literal fields in generated Python code to be nullable,
updating code generation and adding tests to ensure correct handling of
null literals.
> 
>   - **Behavior**:
>     - Allows literal fields to be nullable in `partial_types.py` by changing `a: Literal["hi"] = None` to `a: Optional[Literal["hi"]] = None`.
>     - Updates `BamlSyncClient` and `BamlAsyncClient` to handle nullable literals in `sync_client.py` and `async_client.py`.
>   - **Code Generation**:
>     - Modifies `generate_types.rs` to wrap literals with `Optional[]`.
>     - Updates `mod.rs` to reflect changes in literal handling.
>   - **Testing**:
>     - Adds `test_partial_class_with_null_literal` in `test_literals.rs`.
>     - Introduces `literal-or-null.baml` for integration testing.
>     - Adds `test_null_literal_class_hello` in `test_functions.py`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for 910ab4b. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
imalsogreg authored Jan 15, 2025
1 parent 98b7783 commit 68745d0
Show file tree
Hide file tree
Showing 21 changed files with 367 additions and 37 deletions.
12 changes: 12 additions & 0 deletions engine/baml-lib/jsonish/src/tests/test_literals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,15 @@ test_failing_deserializer!(
FieldType::Literal(LiteralValue::String("THREE".into())),
])
);

// Partial-deserialization case for a class whose only field is a string
// literal. The input object `{}` supplies no value for `bar`, and the
// expected result lists `bar` as `null`.
// NOTE(review): the argument roles below (test name, BAML schema, raw input,
// target type, expected value) are inferred from the values themselves;
// confirm against the `test_partial_deserializer!` macro definition.
test_partial_deserializer!(
test_partial_class_with_null_literal,
r#"
class Foo {
bar "hello"
}
"#,
r#"{}"#,
FieldType::class("Foo"),
{ "bar": null }
);
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ impl ToTypeReferenceInTypeDefinition for FieldType {
format!("Optional[\"{name}\"]")
}
}
FieldType::Literal(value) => to_python_literal(value),
FieldType::Literal(value) => format!("Optional[{}]", to_python_literal(value)),
FieldType::List(inner) => format!("List[{}]", inner.to_partial_type_ref(ir, true)),
FieldType::Map(key, value) => {
format!(
Expand Down
2 changes: 1 addition & 1 deletion engine/language_client_codegen/src/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ impl ToTypeReferenceInClientDefinition for FieldType {
}
FieldType::Class(name) => format!("partial_types.{name}"),
FieldType::RecursiveTypeAlias(name) => format!("types.{name}"),
FieldType::Literal(value) => to_python_literal(value),
FieldType::Literal(value) => format!("Optional[{}]", to_python_literal(value)),
FieldType::List(inner) => {
format!("List[{}]", inner.to_partial_type_ref(ir, with_checked))
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
class ClassForNullLiteral {
a "hi"
}

function NullLiteralClassHello(s: string) -> ClassForNullLiteral {
client GPT35
prompt #"
Return the empty object: {}.
"#
}

test NullLiteralClassHello {
functions [NullLiteralClassHello]
args { s "unused" }
}
77 changes: 65 additions & 12 deletions integ-tests/python/baml_client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,6 +1660,29 @@ async def NestedAlias(
)
return cast(Union[Union[int, str, bool, float], List[str], Dict[str, List[str]]], raw.cast_to(types, types))

async def NullLiteralClassHello(
    self,
    s: str,
    baml_options: BamlCallOptions = {},
) -> types.ClassForNullLiteral:
    """Call the BAML function ``NullLiteralClassHello`` and await its result.

    Args:
        s: Forwarded to the runtime as the function's ``s`` argument.
        baml_options: Per-call options. Only the ``"tb"`` and
            ``"client_registry"`` entries are read here; the mapping is
            never mutated, so the shared default is not written to.

    Returns:
        The runtime result cast to ``types.ClassForNullLiteral``.
    """
    type_builder = baml_options.get("tb", None)
    # The runtime is handed the wrapped ``_tb`` object rather than the wrapper.
    tb = None if type_builder is None else type_builder._tb  # type: ignore (we know how to use this private attribute)
    client_registry = baml_options.get("client_registry", None)

    call_args = {"s": s}
    raw = await self.__runtime.call_function(
        "NullLiteralClassHello",
        call_args,
        self.__ctx_manager.get(),
        tb,
        client_registry,
    )
    return cast(types.ClassForNullLiteral, raw.cast_to(types, types))

async def OptionalTest_Function(
self,
input: str,
Expand Down Expand Up @@ -4691,7 +4714,7 @@ def FnOutputLiteralBool(
self,
input: str,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[Literal[False], Literal[False]]:
) -> baml_py.BamlStream[Optional[Literal[False]], Literal[False]]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
Expand All @@ -4710,9 +4733,9 @@ def FnOutputLiteralBool(
__cr__,
)

return baml_py.BamlStream[Literal[False], Literal[False]](
return baml_py.BamlStream[Optional[Literal[False]], Literal[False]](
raw,
lambda x: cast(Literal[False], x.cast_to(types, partial_types)),
lambda x: cast(Optional[Literal[False]], x.cast_to(types, partial_types)),
lambda x: cast(Literal[False], x.cast_to(types, types)),
self.__ctx_manager.get(),
)
Expand All @@ -4721,7 +4744,7 @@ def FnOutputLiteralInt(
self,
input: str,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[Literal[5], Literal[5]]:
) -> baml_py.BamlStream[Optional[Literal[5]], Literal[5]]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
Expand All @@ -4740,9 +4763,9 @@ def FnOutputLiteralInt(
__cr__,
)

return baml_py.BamlStream[Literal[5], Literal[5]](
return baml_py.BamlStream[Optional[Literal[5]], Literal[5]](
raw,
lambda x: cast(Literal[5], x.cast_to(types, partial_types)),
lambda x: cast(Optional[Literal[5]], x.cast_to(types, partial_types)),
lambda x: cast(Literal[5], x.cast_to(types, types)),
self.__ctx_manager.get(),
)
Expand All @@ -4751,7 +4774,7 @@ def FnOutputLiteralString(
self,
input: str,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[Literal["example output"], Literal["example output"]]:
) -> baml_py.BamlStream[Optional[Literal["example output"]], Literal["example output"]]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
Expand All @@ -4770,9 +4793,9 @@ def FnOutputLiteralString(
__cr__,
)

return baml_py.BamlStream[Literal["example output"], Literal["example output"]](
return baml_py.BamlStream[Optional[Literal["example output"]], Literal["example output"]](
raw,
lambda x: cast(Literal["example output"], x.cast_to(types, partial_types)),
lambda x: cast(Optional[Literal["example output"]], x.cast_to(types, partial_types)),
lambda x: cast(Literal["example output"], x.cast_to(types, types)),
self.__ctx_manager.get(),
)
Expand Down Expand Up @@ -5113,7 +5136,7 @@ def LiteralUnionsTest(
self,
input: str,
baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[Optional[Union[Literal[1], Literal[True], Literal["string output"]]], Union[Literal[1], Literal[True], Literal["string output"]]]:
) -> baml_py.BamlStream[Optional[Union[Optional[Literal[1]], Optional[Literal[True]], Optional[Literal["string output"]]]], Union[Literal[1], Literal[True], Literal["string output"]]]:
__tb__ = baml_options.get("tb", None)
if __tb__ is not None:
tb = __tb__._tb # type: ignore (we know how to use this private attribute)
Expand All @@ -5132,9 +5155,9 @@ def LiteralUnionsTest(
__cr__,
)

return baml_py.BamlStream[Optional[Union[Literal[1], Literal[True], Literal["string output"]]], Union[Literal[1], Literal[True], Literal["string output"]]](
return baml_py.BamlStream[Optional[Union[Optional[Literal[1]], Optional[Literal[True]], Optional[Literal["string output"]]]], Union[Literal[1], Literal[True], Literal["string output"]]](
raw,
lambda x: cast(Optional[Union[Literal[1], Literal[True], Literal["string output"]]], x.cast_to(types, partial_types)),
lambda x: cast(Optional[Union[Optional[Literal[1]], Optional[Literal[True]], Optional[Literal["string output"]]]], x.cast_to(types, partial_types)),
lambda x: cast(Union[Literal[1], Literal[True], Literal["string output"]], x.cast_to(types, types)),
self.__ctx_manager.get(),
)
Expand Down Expand Up @@ -5317,6 +5340,36 @@ def NestedAlias(
self.__ctx_manager.get(),
)

def NullLiteralClassHello(
    self,
    s: str,
    baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[partial_types.ClassForNullLiteral, types.ClassForNullLiteral]:
    """Start a streaming call of the BAML function ``NullLiteralClassHello``.

    Args:
        s: Forwarded to the runtime as the function's ``s`` argument.
        baml_options: Per-call options. Only the ``"tb"`` and
            ``"client_registry"`` entries are read here; the mapping is
            never mutated.

    Returns:
        A ``baml_py.BamlStream`` wrapping the raw runtime stream, built with
        one converter that casts against ``partial_types`` and one that casts
        against ``types``.
    """
    type_builder = baml_options.get("tb", None)
    # The runtime is handed the wrapped ``_tb`` object rather than the wrapper.
    tb = None if type_builder is None else type_builder._tb  # type: ignore (we know how to use this private attribute)
    client_registry = baml_options.get("client_registry", None)

    call_args = {"s": s}
    # The third positional argument is always None here; its meaning is
    # defined by the runtime and is not visible in this method.
    raw = self.__runtime.stream_function(
        "NullLiteralClassHello",
        call_args,
        None,
        self.__ctx_manager.get(),
        tb,
        client_registry,
    )

    def to_partial(x):
        # Converter that resolves classes from ``partial_types``.
        return cast(partial_types.ClassForNullLiteral, x.cast_to(types, partial_types))

    def to_final(x):
        # Converter that resolves classes from ``types``.
        return cast(types.ClassForNullLiteral, x.cast_to(types, types))

    stream_type = baml_py.BamlStream[partial_types.ClassForNullLiteral, types.ClassForNullLiteral]
    return stream_type(raw, to_partial, to_final, self.__ctx_manager.get())

def OptionalTest_Function(
self,
input: str,
Expand Down
1 change: 1 addition & 0 deletions integ-tests/python/baml_client/inlinedbaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
"test-files/functions/output/int.baml": "function FnOutputInt(input: string) -> int {\n client GPT35\n prompt #\"\n Return the integer 5 with no additional context.\n \"#\n}\n\ntest FnOutputInt {\n functions [FnOutputInt]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/output/literal-boolean.baml": "function FnOutputLiteralBool(input: string) -> false {\n client GPT35\n prompt #\"\n Return a false: {{ ctx.output_format}}\n \"#\n}\n\ntest FnOutputLiteralBool {\n functions [FnOutputLiteralBool]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/output/literal-int.baml": "function FnOutputLiteralInt(input: string) -> 5 {\n client GPT35\n prompt #\"\n Return an integer: {{ ctx.output_format}}\n \"#\n}\n\ntest FnOutputLiteralInt {\n functions [FnOutputLiteralInt]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/output/literal-or-null.baml": "class ClassForNullLiteral {\n a \"hi\"\n}\n\nfunction NullLiteralClassHello(s: string) -> ClassForNullLiteral {\n client GPT35\n prompt #\"\n Return the empty object: {}.\n \"#\n}\n\ntest NullLiteralClassHello {\n functions [NullLiteralClassHello]\n args { s \"unused\" }\n}",
"test-files/functions/output/literal-string.baml": "function FnOutputLiteralString(input: string) -> \"example output\" {\n client GPT35\n prompt #\"\n Return a string: {{ ctx.output_format}}\n \"#\n}\n\ntest FnOutputLiteralString {\n functions [FnOutputLiteralString]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/output/literal-unions.baml": "function LiteralUnionsTest(input: string) -> 1 | true | \"string output\" {\n client GPT35\n prompt #\"\n Return one of these values without any additional context: \n {{ctx.output_format}}\n \"#\n}\n\ntest LiteralUnionsTest {\n functions [LiteralUnionsTest]\n args {\n input \"example input\"\n }\n}\n",
"test-files/functions/output/map-enum-key.baml": "enum MapKey {\n A\n B\n C\n}\n\nfunction InOutEnumMapKey(i1: map<MapKey, string>, i2: map<MapKey, string>) -> map<MapKey, string> {\n client \"openai/gpt-4o\"\n prompt #\"\n Merge these: {{i1}} {{i2}}\n\n {{ ctx.output_format }}\n \"#\n}\n",
Expand Down
13 changes: 8 additions & 5 deletions integ-tests/python/baml_client/partial_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class BookOrder(BaseModel):
quantity: Optional[int] = None
price: Optional[float] = None

# Partial-types model for the BAML class `ClassForNullLiteral`.
# The literal field is wrapped in Optional and defaults to None, so an
# object that has no value for `a` yet still validates.
class ClassForNullLiteral(BaseModel):
# Either the literal "hi" or None.
a: Optional[Literal["hi"]] = None

class ClassOptionalOutput(BaseModel):
prop1: Optional[str] = None
prop2: Optional[str] = None
Expand Down Expand Up @@ -197,13 +200,13 @@ class LinkedListAliasNode(BaseModel):
next: Optional["LinkedListAliasNode"] = None

class LiteralClassHello(BaseModel):
prop: Literal["hello"]
prop: Optional[Literal["hello"]] = None

class LiteralClassOne(BaseModel):
prop: Literal["one"]
prop: Optional[Literal["one"]] = None

class LiteralClassTwo(BaseModel):
prop: Literal["two"]
prop: Optional[Literal["two"]] = None

class MalformedConstraints(BaseModel):
foo: Checked[Optional[int],Literal["foo_check"]]
Expand Down Expand Up @@ -293,7 +296,7 @@ class RaysData(BaseModel):
class ReceiptInfo(BaseModel):
items: List["ReceiptItem"]
total_cost: Optional[float] = None
venue: Optional[Union[Literal["barisa"], Literal["ox_burger"]]] = None
venue: Optional[Union[Optional[Literal["barisa"]], Optional[Literal["ox_burger"]]]] = None

class ReceiptItem(BaseModel):
name: Optional[str] = None
Expand All @@ -303,7 +306,7 @@ class ReceiptItem(BaseModel):

class Recipe(BaseModel):
ingredients: Dict[str, Optional["Quantity"]]
recipe_type: Optional[Union[Literal["breakfast"], Literal["dinner"]]] = None
recipe_type: Optional[Union[Optional[Literal["breakfast"]], Optional[Literal["dinner"]]]] = None

class Resume(BaseModel):
name: Optional[str] = None
Expand Down
Loading

0 comments on commit 68745d0

Please sign in to comment.