Skip to content

Commit

Permalink
Fix Google AI system prompt JSON (#1374)
Browse files Browse the repository at this point in the history
<!-- ELLIPSIS_HIDDEN -->



> [!IMPORTANT]
> Fix JSON structure in Google AI client, add GeminiOpenAiGeneric
client, and introduce new test functions for Gemini in multiple files.
> 
>   - **Behavior**:
> - Fix JSON structure in `googleai_client.rs` by renaming
`system_instructions` to `system_instruction` and wrapping `parts` in a
JSON object.
>     - Rename `messages` to `contents` in `googleai_client.rs`.
>   - **Clients**:
> - Add `GeminiOpenAiGeneric` client in `clients.baml` with
`openai-generic` provider.
>   - **Functions**:
> - Add `TestGeminiSystemAsChat` and `TestGeminiOpenAiGeneric` functions
in `test-files/providers/gemini.baml`.
> - Add corresponding async and sync functions in `async_client.py`,
`sync_client.py`, and `client.rb`.
>   - **Tests**:
> - Add tests for `TestGeminiSystemAsChat` and `TestGeminiOpenAiGeneric`
in `test_functions.py`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for b2e08a4. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
antoniosarosi authored Jan 26, 2025
1 parent 1addb34 commit fe366fe
Show file tree
Hide file tree
Showing 12 changed files with 566 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -403,11 +403,13 @@ impl ToProviderMessageExt for GoogleAIClient {
if let Some(content) = first.first() {
if content.role == "system" {
res.insert(
"system_instructions".into(),
json!(self.parts_to_message(&content.parts)?),
"system_instruction".into(),
json!({
"parts": self.parts_to_message(&content.parts)?
}),
);
res.insert(
"messages".into(),
"contents".into(),
others
.iter()
.map(|c| self.role_to_message(c))
Expand Down
9 changes: 9 additions & 0 deletions integ-tests/baml_src/clients.baml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ client<llm> Gemini {
}
}

// Gemini reached through its OpenAI-compatible endpoint rather than the
// native google-ai provider; exercises the generic openai-generic client path.
// NOTE(review): base_url ends at /v1beta/ — confirm whether the OpenAI-compat
// path (/v1beta/openai/) is appended elsewhere or should be part of this URL.
client<llm> GeminiOpenAiGeneric {
provider "openai-generic"
options {
base_url "https://generativelanguage.googleapis.com/v1beta/"
model "gemini-1.5-flash"
api_key env.GOOGLE_API_KEY
}
}

client<llm> Vertex {
provider vertex-ai
options {
Expand Down
16 changes: 16 additions & 0 deletions integ-tests/baml_src/test-files/providers/gemini.baml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ function TestGeminiSystem(input: string) -> string {
"#
}

// Sends the system text as an explicit chat-role message so the google-ai
// provider's system-message handling (the "system_instruction" JSON field,
// see the googleai_client.rs change in this commit) is exercised.
function TestGeminiSystemAsChat(input: string) -> string {
client Gemini
prompt #"
{{ _.role('system') }} You are a helpful assistant

{{_.role("user")}} Write a nice short story about {{ input }}
"#
}

// No-argument smoke test for the GeminiOpenAiGeneric client (Gemini via the
// openai-generic provider). Uses system + user roles in a single prompt.
function TestGeminiOpenAiGeneric() -> string {
client GeminiOpenAiGeneric
prompt #"{{_.role("system")}} You are a helpful assistant
{{_.role("user")}} Write a poem about llamas
"#
}

test TestName {
functions [TestGeminiSystem]
args {
Expand Down
105 changes: 105 additions & 0 deletions integ-tests/python/baml_client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2810,6 +2810,29 @@ async def TestGemini(
)
return cast(str, raw.cast_to(types, types))

async def TestGeminiOpenAiGeneric(
    self,
    baml_options: BamlCallOptions = {},
) -> str:
    """Call the TestGeminiOpenAiGeneric BAML function and await its result.

    Args:
        baml_options: Optional per-call overrides; recognized keys are
            "tb" (a TypeBuilder wrapper) and "client_registry".

    Returns:
        The function output cast to ``str``.
    """
    # Unwrap the TypeBuilder to its private runtime handle, if one was given.
    tb_option = baml_options.get("tb", None)
    tb = tb_option._tb if tb_option is not None else None  # type: ignore (we know how to use this private attribute)
    client_registry = baml_options.get("client_registry", None)

    # This BAML function takes no arguments, so the args payload is empty.
    result = await self.__runtime.call_function(
        "TestGeminiOpenAiGeneric",
        {},
        self.__ctx_manager.get(),
        tb,
        client_registry,
    )
    return cast(str, result.cast_to(types, types))

async def TestGeminiSystem(
self,
input: str,
Expand All @@ -2833,6 +2856,29 @@ async def TestGeminiSystem(
)
return cast(str, raw.cast_to(types, types))

async def TestGeminiSystemAsChat(
    self,
    input: str,
    baml_options: BamlCallOptions = {},
) -> str:
    """Call the TestGeminiSystemAsChat BAML function and await its result.

    Args:
        input: Value bound to the BAML function's ``input`` parameter.
        baml_options: Optional per-call overrides; recognized keys are
            "tb" (a TypeBuilder wrapper) and "client_registry".

    Returns:
        The function output cast to ``str``.
    """
    # Unwrap the TypeBuilder to its private runtime handle, if one was given.
    tb_option = baml_options.get("tb", None)
    tb = tb_option._tb if tb_option is not None else None  # type: ignore (we know how to use this private attribute)
    client_registry = baml_options.get("client_registry", None)

    result = await self.__runtime.call_function(
        "TestGeminiSystemAsChat",
        {"input": input},
        self.__ctx_manager.get(),
        tb,
        client_registry,
    )
    return cast(str, result.cast_to(types, types))

async def TestImageInput(
self,
img: baml_py.Image,
Expand Down Expand Up @@ -6888,6 +6934,35 @@ def TestGemini(
self.__ctx_manager.get(),
)

def TestGeminiOpenAiGeneric(
    self,
    baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[Optional[str], str]:
    """Start a streaming call of the TestGeminiOpenAiGeneric BAML function.

    Args:
        baml_options: Optional per-call overrides; recognized keys are
            "tb" (a TypeBuilder wrapper) and "client_registry".

    Returns:
        A ``BamlStream`` yielding ``Optional[str]`` partials and a final ``str``.
    """
    # Unwrap the TypeBuilder to its private runtime handle, if one was given.
    tb_option = baml_options.get("tb", None)
    tb = tb_option._tb if tb_option is not None else None  # type: ignore (we know how to use this private attribute)
    client_registry = baml_options.get("client_registry", None)

    # No arguments for this BAML function, hence the empty args payload.
    raw = self.__runtime.stream_function(
        "TestGeminiOpenAiGeneric",
        {},
        None,
        self.__ctx_manager.get(),
        tb,
        client_registry,
    )

    # Partial events are cast against partial_types; the final value against types.
    return baml_py.BamlStream[Optional[str], str](
        raw,
        lambda event: cast(Optional[str], event.cast_to(types, partial_types)),
        lambda final: cast(str, final.cast_to(types, types)),
        self.__ctx_manager.get(),
    )

def TestGeminiSystem(
self,
input: str,
Expand Down Expand Up @@ -6918,6 +6993,36 @@ def TestGeminiSystem(
self.__ctx_manager.get(),
)

def TestGeminiSystemAsChat(
    self,
    input: str,
    baml_options: BamlCallOptions = {},
) -> baml_py.BamlStream[Optional[str], str]:
    """Start a streaming call of the TestGeminiSystemAsChat BAML function.

    Args:
        input: Value bound to the BAML function's ``input`` parameter.
        baml_options: Optional per-call overrides; recognized keys are
            "tb" (a TypeBuilder wrapper) and "client_registry".

    Returns:
        A ``BamlStream`` yielding ``Optional[str]`` partials and a final ``str``.
    """
    # Unwrap the TypeBuilder to its private runtime handle, if one was given.
    tb_option = baml_options.get("tb", None)
    tb = tb_option._tb if tb_option is not None else None  # type: ignore (we know how to use this private attribute)
    client_registry = baml_options.get("client_registry", None)

    raw = self.__runtime.stream_function(
        "TestGeminiSystemAsChat",
        {"input": input},
        None,
        self.__ctx_manager.get(),
        tb,
        client_registry,
    )

    # Partial events are cast against partial_types; the final value against types.
    return baml_py.BamlStream[Optional[str], str](
        raw,
        lambda event: cast(Optional[str], event.cast_to(types, partial_types)),
        lambda final: cast(str, final.cast_to(types, types)),
        self.__ctx_manager.get(),
    )

def TestImageInput(
self,
img: baml_py.Image,
Expand Down
Loading

0 comments on commit fe366fe

Please sign in to comment.