diff --git a/engine/baml-lib/baml/tests/validation_files/client/bad_response_format.baml b/engine/baml-lib/baml/tests/validation_files/client/bad_response_format.baml
new file mode 100644
index 000000000..0ab026876
--- /dev/null
+++ b/engine/baml-lib/baml/tests/validation_files/client/bad_response_format.baml
@@ -0,0 +1,30 @@
+client MyClient {
+  provider openai
+  options {
+    model "gpt-4o"
+    client_response_type "invalid"
+  }
+}
+
+client MyClient2 {
+  provider openai
+  options {
+    model "gpt-4o"
+    client_response_type "openai"
+  }
+}
+
+client MyClient3 {
+  provider openai
+  options {
+    model "gpt-4o"
+    client_response_type "anthropic"
+  }
+}
+
+// error: client_response_type must be one of "openai", "anthropic", "google", or "vertex". Got: invalid
+//   -->  client/bad_response_format.baml:5
+//    |
+//  4 |     model "gpt-4o"
+//  5 |     client_response_type "invalid"
+//    |
diff --git a/engine/baml-lib/llm-client/src/clients/helpers.rs b/engine/baml-lib/llm-client/src/clients/helpers.rs
index 610d9fe3c..412776eb8 100644
--- a/engine/baml-lib/llm-client/src/clients/helpers.rs
+++ b/engine/baml-lib/llm-client/src/clients/helpers.rs
@@ -4,8 +4,7 @@ use baml_types::{GetEnvVar, StringOr, UnresolvedValue};
 use indexmap::IndexMap;
 
 use crate::{
-    SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter,
-    UnresolvedRolesSelection,
+    SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedResponseType, UnresolvedRolesSelection
 };
 
 #[derive(Debug, Clone)]
@@ -326,6 +325,36 @@ impl PropertyHandler {
         UnresolvedAllowedRoleMetadata::None
     }
 
+    pub fn ensure_client_response_type(&mut self) -> Option<UnresolvedResponseType> {
+        self.ensure_string("client_response_type", false)
+            .and_then(|(key_span, value, _)| {
+                if let StringOr::Value(value) = value {
+                    return Some(match value.as_str() {
+                        "openai" => UnresolvedResponseType::OpenAI,
+                        "anthropic" => UnresolvedResponseType::Anthropic,
+                        "google" => UnresolvedResponseType::Google,
+                        "vertex" => UnresolvedResponseType::Vertex,
+                        other => {
+                            self.push_error(
+                                format!(
+                                    "client_response_type must be one of \"openai\", \"anthropic\", \"google\", or \"vertex\". Got: {}",
+                                    other
+                                ),
+                                key_span,
+                            );
+                            return None;
+                        }
+                    })
+                } else {
+                    self.push_error(
+                        "client_response_type must be one of \"openai\", \"anthropic\", \"google\", or \"vertex\" and not an environment variable",
+                        key_span,
+                    );
+                    None
+                }
+            })
+    }
+
     pub fn ensure_query_params(&mut self) -> Option<IndexMap<String, StringOr>> {
         self.ensure_map("query_params", false).map(|(_, value, _)| {
             value
diff --git a/engine/baml-lib/llm-client/src/clients/openai.rs b/engine/baml-lib/llm-client/src/clients/openai.rs
index e1b28b1e7..9508cf96b 100644
--- a/engine/baml-lib/llm-client/src/clients/openai.rs
+++ b/engine/baml-lib/llm-client/src/clients/openai.rs
@@ -1,8 +1,7 @@
 use std::collections::HashSet;
 
 use crate::{
-    AllowedRoleMetadata, FinishReasonFilter, RolesSelection, SupportedRequestModes,
-    UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedRolesSelection,
+    AllowedRoleMetadata, FinishReasonFilter, ResponseType, RolesSelection, SupportedRequestModes, UnresolvedAllowedRoleMetadata, UnresolvedFinishReasonFilter, UnresolvedResponseType, UnresolvedRolesSelection
 };
 
 use anyhow::Result;
@@ -22,6 +21,7 @@ pub struct UnresolvedOpenAI {
     properties: IndexMap<String, (Meta, UnresolvedValue<Meta>)>,
     query_params: IndexMap<String, StringOr>,
     finish_reason_filter: UnresolvedFinishReasonFilter,
+    client_response_type: Option<UnresolvedResponseType>,
 }
 
 impl<Meta> UnresolvedOpenAI<Meta> {
@@ -48,6 +48,7 @@ impl UnresolvedOpenAI {
                 .map(|(k, v)| (k.clone(), v.clone()))
                 .collect(),
             finish_reason_filter: self.finish_reason_filter.clone(),
+            client_response_type: self.client_response_type.clone(),
         }
     }
 }
@@ -63,6 +64,7 @@ pub struct ResolvedOpenAI {
     pub query_params: IndexMap<String, String>,
     pub proxy_url: Option<String>,
     pub finish_reason_filter: FinishReasonFilter,
+    pub client_response_type: ResponseType,
 }
 
 impl ResolvedOpenAI {
@@ -221,6 +223,7 @@ impl UnresolvedOpenAI {
             query_params,
             proxy_url: super::helpers::get_proxy_url(ctx),
             finish_reason_filter: self.finish_reason_filter.resolve(ctx)?,
+            client_response_type: self.client_response_type.as_ref().map_or(Ok(ResponseType::OpenAI), |v| v.resolve(ctx))?,
         })
     }
 
@@ -350,6 +353,7 @@ impl UnresolvedOpenAI {
         let headers = properties.ensure_headers().unwrap_or_default();
         let finish_reason_filter = properties.ensure_finish_reason_filter();
         let query_params = properties.ensure_query_params().unwrap_or_default();
+        let client_response_type = properties.ensure_client_response_type();
 
         let (properties, errors) = properties.finalize();
         if !errors.is_empty() {
@@ -366,6 +370,7 @@ impl UnresolvedOpenAI {
             properties,
             query_params,
             finish_reason_filter,
+            client_response_type,
         })
     }
 }
diff --git a/engine/baml-lib/llm-client/src/clientspec.rs b/engine/baml-lib/llm-client/src/clientspec.rs
index 28d3dbaab..3f1f2e2ac 100644
--- a/engine/baml-lib/llm-client/src/clientspec.rs
+++ b/engine/baml-lib/llm-client/src/clientspec.rs
@@ -406,3 +406,35 @@ impl AllowedRoleMetadata {
         }
     }
 }
+
+
+#[derive(Clone, Debug)]
+pub enum UnresolvedResponseType {
+    OpenAI,
+    Anthropic,
+    Google,
+    Vertex,
+}
+
+#[derive(Clone, Debug)]
+pub enum ResponseType {
+    OpenAI,
+    Anthropic,
+    Google,
+    Vertex,
+}
+
+impl UnresolvedResponseType {
+    pub fn required_env_vars(&self) -> HashSet<String> {
+        HashSet::new()
+    }
+
+    pub fn resolve(&self, _: &impl GetEnvVar) -> Result<ResponseType> {
+        match self {
+            Self::OpenAI => Ok(ResponseType::OpenAI),
+            Self::Anthropic => Ok(ResponseType::Anthropic),
+            Self::Google => Ok(ResponseType::Google),
+            Self::Vertex => Ok(ResponseType::Vertex),
+        }
+    }
+}
diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs b/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs
index 8c8dfccb4..d4bc5aa3a 100644
--- a/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs
+++ b/engine/baml-runtime/src/internal/llm_client/primitive/openai/openai_client.rs
@@ -99,7 +99,7 @@ impl WithChat for OpenAIClient {
             model_name,
             either::Either::Right(prompt),
             false,
-            ResponseType::OpenAI,
+            self.properties.client_response_type.clone(),
         )
         .await
     }
 }
diff --git a/engine/baml-runtime/src/internal/llm_client/primitive/request.rs b/engine/baml-runtime/src/internal/llm_client/primitive/request.rs
index 772c940aa..102eac1a5 100644
--- a/engine/baml-runtime/src/internal/llm_client/primitive/request.rs
+++ b/engine/baml-runtime/src/internal/llm_client/primitive/request.rs
@@ -3,6 +3,7 @@ use std::collections::HashMap;
 use anyhow::{Context, Result};
 use baml_types::BamlMap;
 use internal_baml_jinja::RenderedChatMessage;
+pub use internal_llm_client::ResponseType;
 use reqwest::Response;
 use serde::de::DeserializeOwned;
 
@@ -119,13 +120,6 @@ pub async fn make_request(
     Ok((response, system_now, instant_now))
 }
 
-pub enum ResponseType {
-    OpenAI,
-    Anthropic,
-    Google,
-    Vertex,
-}
-
 pub async fn make_parsed_request(
     client: &(impl WithClient + RequestBuilder),
     model_name: Option<String>,
diff --git a/fern/03-reference/baml/clients/providers/azure.mdx b/fern/03-reference/baml/clients/providers/azure.mdx
index 5ad841631..6015b3a01 100644
--- a/fern/03-reference/baml/clients/providers/azure.mdx
+++ b/fern/03-reference/baml/clients/providers/azure.mdx
@@ -99,6 +99,8 @@ client MyClient {
 
+<Markdown src="/snippets/client-response-type.mdx" />
+
 ## Provider request parameters
 
 These are other `options` that are passed through to the provider, without modification by BAML. For example if the request has a `temperature` field, you can define it in the client here so every call has that set.
diff --git a/fern/03-reference/baml/clients/providers/openai-generic.mdx b/fern/03-reference/baml/clients/providers/openai-generic.mdx
index 7b93e180c..5226262df 100644
--- a/fern/03-reference/baml/clients/providers/openai-generic.mdx
+++ b/fern/03-reference/baml/clients/providers/openai-generic.mdx
@@ -66,6 +66,8 @@ client MyClient {
 
+<Markdown src="/snippets/client-response-type.mdx" />
+
 ## Provider request parameters
 
 These are other parameters that are passed through to the provider, without modification by BAML. For example if the request has a `temperature` field, you can define it in the client here so every call has that set.
diff --git a/fern/03-reference/baml/clients/providers/openai.mdx b/fern/03-reference/baml/clients/providers/openai.mdx
index 5bb0341b0..7bf787b85 100644
--- a/fern/03-reference/baml/clients/providers/openai.mdx
+++ b/fern/03-reference/baml/clients/providers/openai.mdx
@@ -72,6 +72,8 @@ client MyClient {
 
+<Markdown src="/snippets/client-response-type.mdx" />
+
 ## Provider request parameters
 
 These are other parameters that are passed through to the provider, without modification by BAML. For example if the request has a `temperature` field, you can define it in the client here so every call has that set.
diff --git a/fern/snippets/client-response-type.mdx b/fern/snippets/client-response-type.mdx
new file mode 100644
index 000000000..6c714980b
--- /dev/null
+++ b/fern/snippets/client-response-type.mdx
@@ -0,0 +1,13 @@
+<ParamField path="options.client_response_type" type="string">
+  <Warning>
+    Please let [us know on Discord](https://www.boundaryml.com/discord) if you have this use case! This is in alpha and we'd like to make sure we continue to cover your use cases.
+  </Warning>
+
+  The type of response to return from the client.
+
+  Sometimes you may expect a different response format than the provider default.
+  For example, using Azure you may be proxying to an endpoint that returns a different format than the OpenAI default.
+
+  **Default: `openai`**
+
+</ParamField>
diff --git a/integ-tests/python/tests/test_functions.py b/integ-tests/python/tests/test_functions.py
index d7c2f900a..32c85162a 100644
--- a/integ-tests/python/tests/test_functions.py
+++ b/integ-tests/python/tests/test_functions.py
@@ -1903,3 +1903,11 @@ async def test_semantic_streaming():
     final = await stream.get_final_response()
 
     print(final)
+
+@pytest.mark.asyncio
+async def test_client_response_type():
+    cr = baml_py.ClientRegistry()
+    cr.add_llm_client("temp_client", "openai", { "client_response_type": "anthropic", "model": "gpt-4o" })
+    cr.set_primary("temp_client")
+    with pytest.raises(errors.BamlClientError):
+        _ = await b.TestOpenAI("test", baml_options={ "client_registry": cr })
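
Usage sketch of the new option in a BAML client definition. This is illustrative only: the client name, base_url, and proxy endpoint are hypothetical; the option name `client_response_type`, its allowed values, and the `openai` default come from this patch.

client MyProxiedClient {
  provider "openai-generic"
  options {
    // Hypothetical proxy endpoint that returns Anthropic-shaped responses.
    base_url "https://my-llm-proxy.example.com/v1"
    model "gpt-4o"
    // Parse responses as the Anthropic format instead of the provider default ("openai").
    client_response_type "anthropic"
  }
}

When the option is omitted, responses are parsed as the OpenAI format, matching the map_or(Ok(ResponseType::OpenAI), ...) fallback when resolving ResolvedOpenAI.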