Add client response type validation and support for LLM clients (#1473)
- Introduce `UnresolvedResponseType` and `ResponseType` enums in `clientspec.rs`
- Add `ensure_client_response_type()` method in `helpers.rs` to validate client response types
- Update OpenAI client to support custom response type configuration (a usage sketch follows below)
- Add integration test for client response type validation
- Add test file for validating invalid response format in BAML configuration
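
For orientation, here is a minimal sketch of how the new option is intended to appear in a BAML client definition. The client name and model below are illustrative placeholders, not part of this commit; only `client_response_type` is the new option.

```baml
// Hypothetical client definition for illustration.
client<llm> SketchClient {
  provider openai
  options {
    model "gpt-4o"
    api_key env.OPENAI_API_KEY
    // New in this commit: which provider's response shape to parse.
    // Accepted values: "openai" (default), "anthropic", "google", "vertex".
    client_response_type "openai"
  }
}
```

Any other string is rejected at validation time, as exercised by `bad_response_format.baml` below.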

<!-- ELLIPSIS_HIDDEN -->

----

> [!IMPORTANT]
> Add client response type validation and support for LLM clients, including enums, validation methods, client updates, tests, and documentation.
>
> - **Behavior**:
>   - Introduce `UnresolvedResponseType` and `ResponseType` enums in `clientspec.rs`.
>   - Add `ensure_client_response_type()` method in `helpers.rs` to validate client response types.
>   - Update OpenAI client in `openai.rs` to support custom response type configuration.
> - **Tests**:
>   - Add integration test `test_client_response_type` in `test_functions.py` for client response type validation.
>   - Add `bad_response_format.baml` test file for invalid response format validation.
> - **Documentation**:
>   - Add `client-response-type.mdx` snippet and include it in `azure.mdx`, `openai-generic.mdx`, and `openai.mdx`.
>
> <sup>This description was created by </sup>[<img alt="Ellipsis" src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup> for 5dc801d. It will automatically update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
hellovai authored Feb 18, 2025
1 parent 5c6b213 commit 2987d59
Showing 11 changed files with 129 additions and 12 deletions.
New file: client/bad_response_format.baml (30 additions)
@@ -0,0 +1,30 @@
client<llm> MyClient {
provider openai
options {
model "gpt-4o"
client_response_type "invalid"
}
}

client<llm> MyClient2 {
provider openai
options {
model "gpt-4o"
client_response_type "openai"
}
}

client<llm> MyClient3 {
provider openai
options {
model "gpt-4o"
client_response_type "anthropic"
}
}

// error: client_response_type must be one of "openai", "anthropic", "google", or "vertex". Got: invalid
// --> client/bad_response_format.baml:5
// |
// 4 | model "gpt-4o"
// 5 | client_response_type "invalid"
// |
33 changes: 31 additions & 2 deletions engine/baml-lib/llm-client/src/clients/helpers.rs
@@ -4,8 +4,7 @@ use baml_types::{GetEnvVar, StringOr, UnresolvedValue};
use indexmap::IndexMap;

use crate::{
- SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter,
- UnresolvedRolesSelection,
+ SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedResponseType, UnresolvedRolesSelection
};

#[derive(Debug, Clone)]
@@ -326,6 +325,36 @@ impl<Meta: Clone> PropertyHandler<Meta> {
UnresolvedAllowedRoleMetadata::None
}

pub fn ensure_client_response_type(&mut self) -> Option<UnresolvedResponseType> {
self.ensure_string("client_response_type", false)
.and_then(|(key_span, value, _)| {
if let StringOr::Value(value) = value {
return Some(match value.as_str() {
"openai" => UnresolvedResponseType::OpenAI,
"anthropic" => UnresolvedResponseType::Anthropic,
"google" => UnresolvedResponseType::Google,
"vertex" => UnresolvedResponseType::Vertex,
other => {
self.push_error(
format!(
"client_response_type must be one of \"openai\", \"anthropic\", \"google\", or \"vertex\". Got: {}",
other
),
key_span,
);
return None;
}
})
} else {
self.push_error(
"client_response_type must be one of \"openai\", \"anthropic\", \"google\", or \"vertex\" and not an environment variable",
key_span,
);
None
}
})
}

pub fn ensure_query_params(&mut self) -> Option<IndexMap<String, StringOr>> {
self.ensure_map("query_params", false).map(|(_, value, _)| {
value
9 changes: 7 additions & 2 deletions engine/baml-lib/llm-client/src/clients/openai.rs
@@ -1,8 +1,7 @@
use std::collections::HashSet;

use crate::{
- AllowedRoleMetadata, FinishReasonFilter, RolesSelection, SupportedRequestModes,
- UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection,
+ AllowedRoleMetadata, FinishReasonFilter, ResponseType, RolesSelection, SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedResponseType, UnresolvedRolesSelection
};
use anyhow::Result;

@@ -22,6 +21,7 @@ pub struct UnresolvedOpenAI<Meta> {
properties: IndexMap<String, (Meta, UnresolvedValue<Meta>)>,
query_params: IndexMap<String, StringOr>,
finish_reason_filter: UnresolvedFinishReasonFilter,
client_response_type: Option<UnresolvedResponseType>,
}

impl<Meta> UnresolvedOpenAI<Meta> {
@@ -48,6 +48,7 @@ impl<Meta> UnresolvedOpenAI<Meta> {
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
finish_reason_filter: self.finish_reason_filter.clone(),
client_response_type: self.client_response_type.clone(),
}
}
}
@@ -63,6 +64,7 @@ pub struct ResolvedOpenAI {
pub query_params: IndexMap<String, String>,
pub proxy_url: Option<String>,
pub finish_reason_filter: FinishReasonFilter,
pub client_response_type: ResponseType,
}

impl ResolvedOpenAI {
@@ -221,6 +223,7 @@ impl<Meta: Clone> UnresolvedOpenAI<Meta> {
query_params,
proxy_url: super::helpers::get_proxy_url(ctx),
finish_reason_filter: self.finish_reason_filter.resolve(ctx)?,
client_response_type: self.client_response_type.as_ref().map_or(Ok(ResponseType::OpenAI), |v| v.resolve(ctx))?,
})
}

@@ -350,6 +353,7 @@ impl<Meta: Clone> UnresolvedOpenAI<Meta> {
let headers = properties.ensure_headers().unwrap_or_default();
let finish_reason_filter = properties.ensure_finish_reason_filter();
let query_params = properties.ensure_query_params().unwrap_or_default();
let client_response_type = properties.ensure_client_response_type();
let (properties, errors) = properties.finalize();

if !errors.is_empty() {
@@ -366,6 +370,7 @@ impl<Meta: Clone> UnresolvedOpenAI<Meta> {
properties,
query_params,
finish_reason_filter,
client_response_type,
})
}
}
32 changes: 32 additions & 0 deletions engine/baml-lib/llm-client/src/clientspec.rs
@@ -406,3 +406,35 @@ impl AllowedRoleMetadata {
}
}
}


#[derive(Clone, Debug)]
pub enum UnresolvedResponseType {
OpenAI,
Anthropic,
Google,
Vertex,
}

#[derive(Clone, Debug)]
pub enum ResponseType {
OpenAI,
Anthropic,
Google,
Vertex,
}

impl UnresolvedResponseType {
pub fn required_env_vars(&self) -> HashSet<String> {
HashSet::new()
}

pub fn resolve(&self, _: &impl GetEnvVar) -> Result<ResponseType> {
match self {
Self::OpenAI => Ok(ResponseType::OpenAI),
Self::Anthropic => Ok(ResponseType::Anthropic),
Self::Google => Ok(ResponseType::Google),
Self::Vertex => Ok(ResponseType::Vertex),
}
}
}
@@ -99,7 +99,7 @@ impl WithChat for OpenAIClient {
model_name,
either::Either::Right(prompt),
false,
- ResponseType::OpenAI,
+ self.properties.client_response_type.clone(),
)
.await
}
@@ -3,6 +3,7 @@ use std::collections::HashMap;
use anyhow::{Context, Result};
use baml_types::BamlMap;
use internal_baml_jinja::RenderedChatMessage;
pub use internal_llm_client::ResponseType;
use reqwest::Response;
use serde::de::DeserializeOwned;

@@ -119,13 +120,6 @@ pub async fn make_request(
Ok((response, system_now, instant_now))
}

- pub enum ResponseType {
- OpenAI,
- Anthropic,
- Google,
- Vertex,
- }

pub async fn make_parsed_request(
client: &(impl WithClient + RequestBuilder),
model_name: Option<String>,
2 changes: 2 additions & 0 deletions fern/03-reference/baml/clients/providers/azure.mdx
@@ -99,6 +99,8 @@ client<llm> MyClient {

<Markdown src="/snippets/finish-reason.mdx" />

<Markdown src="/snippets/client-response-type.mdx" />

## Provider request parameters
These are other `options` that are passed through to the provider, without modification by BAML. For example if the request has a `temperature` field, you can define it in the client here so every call has that set.

2 changes: 2 additions & 0 deletions fern/03-reference/baml/clients/providers/openai-generic.mdx
@@ -66,6 +66,8 @@ client<llm> MyClient {

<Markdown src="/snippets/finish-reason.mdx" />

<Markdown src="/snippets/client-response-type.mdx" />

## Provider request parameters
These are other parameters that are passed through to the provider, without modification by BAML. For example if the request has a `temperature` field, you can define it in the client here so every call has that set.

2 changes: 2 additions & 0 deletions fern/03-reference/baml/clients/providers/openai.mdx
@@ -72,6 +72,8 @@ client<llm> MyClient {

<Markdown src="/snippets/finish-reason.mdx" />

<Markdown src="/snippets/client-response-type.mdx" />

## Provider request parameters
These are other parameters that are passed through to the provider, without modification by BAML. For example if the request has a `temperature` field, you can define it in the client here so every call has that set.

13 changes: 13 additions & 0 deletions fern/snippets/client-response-type.mdx
@@ -0,0 +1,13 @@
<ParamField path="client_response_type" type="openai | anthropic | google | vertex" default="openai">
<Warning>
Please let [us know on Discord](https://www.boundaryml.com/discord) if you have this use case! This is in alpha and we'd like to make sure we continue to cover your use cases.
</Warning>

The type of response to return from the client.

Sometimes you may expect a different response format than the provider default.
For example, using Azure you may be proxying to an endpoint that returns a different format than the OpenAI default.

**Default: `openai`**
</ParamField>
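
To make the proxying scenario above concrete, here is a sketch (not part of the committed docs) of a client pointed at a gateway that speaks the OpenAI request format but returns Anthropic-shaped response bodies; the `base_url` and model are hypothetical:

```baml
client<llm> GatewayClient {
  provider openai-generic
  options {
    // Hypothetical gateway endpoint returning Anthropic-style responses.
    base_url "https://llm-gateway.internal.example/v1"
    model "gpt-4o"
    // Override the default "openai" parser for this provider.
    client_response_type "anthropic"
  }
}
```

The integration test below exercises the inverse mismatch: it forces `anthropic` parsing on a real OpenAI response and expects a client error.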

8 changes: 8 additions & 0 deletions integ-tests/python/tests/test_functions.py
@@ -1903,3 +1903,11 @@ async def test_semantic_streaming():

final = await stream.get_final_response()
print(final)

@pytest.mark.asyncio
async def test_client_response_type():
cr = baml_py.ClientRegistry()
cr.add_llm_client("temp_client", "openai", { "client_response_type": "anthropic", "model": "gpt-4o" })
cr.set_primary("temp_client")
with pytest.raises(errors.BamlClientError):
_ = await b.TestOpenAI("test", baml_options={ "client_registry": cr })
