Expose Prompt and Parser separately #1505

Draft · wants to merge 1 commit into base: canary
33 changes: 16 additions & 17 deletions engine/baml-lib/jinja-runtime/src/lib.rs
@@ -307,7 +307,7 @@ impl std::fmt::Display for RenderedPrompt {
                     message
                         .parts
                         .iter()
-                        .map(|p| p.to_string())
+                        .map(ChatMessagePart::to_string)
                         .collect::<Vec<String>>()
                         .join("")
                 )?;
@@ -499,7 +499,7 @@ mod render_tests {
         let ir = make_test_ir(
             "
             class C {
-
+
             }
             ",
         )?;
@@ -548,7 +548,7 @@ mod render_tests {
         let ir = make_test_ir(
             "
             class C {
-
+
             }
             ",
         )?;
@@ -619,7 +619,7 @@ mod render_tests {
         let ir = make_test_ir(
             "
             class C {
-
+
             }
             ",
         )?;
@@ -675,7 +675,7 @@ mod render_tests {
         let ir = make_test_ir(
             "
             class C {
-
+
             }
             ",
         )?;
@@ -761,7 +761,7 @@ mod render_tests {
         let ir = make_test_ir(
             "
             class C {
-
+
             }
             ",
         )?;
@@ -817,7 +817,7 @@ mod render_tests {
         let ir = make_test_ir(
             "
             class C {
-
+
             }
             ",
         )?;
@@ -857,7 +857,7 @@ mod render_tests {
         let ir = make_test_ir(
             "
             class C {
-
+
             }
             ",
         )?;
@@ -897,7 +897,7 @@ mod render_tests {
         let ir = make_test_ir(
             "
             class C {
-
+
             }
             ",
         )?;
@@ -937,7 +937,7 @@ mod render_tests {
         let ir = make_test_ir(
             "
             class C {
-
+
             }
             ",
         )?;
@@ -980,7 +980,7 @@ mod render_tests {
         let ir = make_test_ir(
             "
             class C {
-
+
             }
             ",
         )?;
@@ -1046,7 +1046,7 @@ mod render_tests {
         let ir = make_test_ir(
             "
             class C {
-
+
             }
             ",
         )?;
@@ -1118,7 +1118,6 @@ mod render_tests {
         Ok(())
     }

-
     #[test]
     fn render_with_kwargs_default_role() -> anyhow::Result<()> {
         setup_logging();
@@ -1131,7 +1130,7 @@ mod render_tests {
         let ir = make_test_ir(
             "
             class C {
-
+
             }
             ",
         )?;
@@ -1215,7 +1214,7 @@ mod render_tests {
         let ir = make_test_ir(
             "
             class C {
-
+
             }
             ",
         )?;
@@ -1276,7 +1275,7 @@ mod render_tests {
         let ir = make_test_ir(
             "
             class C {
-
+
             }
             ",
         )?;
@@ -1846,7 +1845,7 @@ mod render_tests {
         let ir = make_test_ir(
             r#"
             class A {
-                a_prop1 string
+                a_prop1 string
                 a_prop2 B[] @alias("alias_a_prop2")
             }

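The one substantive change in this file swaps a closure for a method path: `.map(|p| p.to_string())` becomes `.map(ChatMessagePart::to_string)`. The two are equivalent, since a path to a method satisfies the same `FnMut` bound as the closure wrapping it; the closure form is what clippy's `redundant_closure_for_method_calls` lint flags. A minimal sketch of the idiom, using a hypothetical `Part` type rather than the crate's real `ChatMessagePart`:

```rust
// A hypothetical `Part` type standing in for the crate's ChatMessagePart.
struct Part(String);

impl std::fmt::Display for Part {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}

fn join_parts(parts: &[Part]) -> String {
    parts
        .iter()
        // Equivalent to `.map(|p| p.to_string())`: `to_string` comes from the
        // blanket `ToString` impl that every `Display` type gets.
        .map(Part::to_string)
        .collect::<Vec<String>>()
        .join("")
}

fn main() {
    assert_eq!(join_parts(&[Part("a".into()), Part("b".into())]), "ab");
}
```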
13 changes: 13 additions & 0 deletions engine/baml-runtime/src/internal/llm_client/llm_provider.rs
@@ -2,6 +2,7 @@ use std::sync::Arc;

 use anyhow::Result;
 use internal_baml_core::ir::ClientWalker;
+use internal_baml_jinja::RenderedChatMessage;

 use crate::{
     client_registry::ClientProperty, runtime_interface::InternalClientLookup, RuntimeContext,
@@ -31,6 +32,18 @@ impl std::fmt::Debug for LLMProvider {
     }
 }

+impl LLMProvider {
+    pub fn chat_to_message(
+        &self,
+        chat: &[RenderedChatMessage],
+    ) -> Result<serde_json::Map<String, serde_json::Value>> {
+        match self {
+            LLMProvider::Primitive(provider) => provider.chat_to_message(chat),
+            LLMProvider::Strategy(provider) => todo!("Strategy provider"),
+        }
+    }
+}
+
 impl WithRetryPolicy for LLMProvider {
     fn retry_policy_name(&self) -> Option<&str> {
         match self {
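The new `chat_to_message` wrapper above is plain enum dispatch with the strategy arm stubbed out via `todo!`. A self-contained sketch of the shape, using illustrative stand-in types rather than the crate's real definitions:

```rust
use anyhow::Result;
use serde_json::{Map, Value};

// Illustrative stand-in for a concrete single-model client.
struct PrimitiveProvider;

impl PrimitiveProvider {
    fn chat_to_message(&self, chat: &[String]) -> Result<Map<String, Value>> {
        let mut body = Map::new();
        body.insert("content".into(), Value::String(chat.join("\n")));
        Ok(body)
    }
}

// Illustrative stand-in for LLMProvider's two variants.
enum Provider {
    Primitive(PrimitiveProvider),
    Strategy, // fallback / round-robin groups have no single wire format
}

impl Provider {
    fn chat_to_message(&self, chat: &[String]) -> Result<Map<String, Value>> {
        match self {
            Provider::Primitive(p) => p.chat_to_message(chat),
            // Returning an error keeps the failure recoverable, where a
            // `todo!()` would panic if a strategy client ever hit this path.
            Provider::Strategy => anyhow::bail!("strategy providers are not supported here yet"),
        }
    }
}
```

As a draft-PR nit, the `Strategy(provider)` arm binds `provider` without using it, which will trip the `unused_variables` warning; `Strategy(_)` would be quieter.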
@@ -176,7 +176,11 @@ impl AwsClient {
         // Exposing the secret key here is relatively safe. First, we expose it only
         // to check if it starts with $. If so, the remainder should be an env
         // var name, which is also safe to expose.
-        if aws_secret_access_key.api_key.expose_secret().starts_with("$") {
+        if aws_secret_access_key
+            .api_key
+            .expose_secret()
+            .starts_with("$")
+        {
             return Err(anyhow::anyhow!(
                 "AWS secret access key expected, please set: env.{}",
                 &aws_secret_access_key.api_key.expose_secret()[1..]
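The comment in this hunk documents a config convention: a secret whose value starts with `$` is an unresolved reference to an environment variable, so only the variable's name (the remainder after `$`) is safe to surface in an error. A hedged sketch of that resolution rule in isolation; the function name and error text are illustrative, not the crate's API:

```rust
use anyhow::{anyhow, Result};

/// Resolve a configured secret: a literal value is used verbatim, while a
/// `$NAME` value is treated as the name of an environment variable.
fn resolve_secret(configured: &str) -> Result<String> {
    match configured.strip_prefix('$') {
        // e.g. "$AWS_SECRET_ACCESS_KEY" -> look up AWS_SECRET_ACCESS_KEY.
        Some(var_name) => std::env::var(var_name)
            .map_err(|_| anyhow!("expected env var to be set: env.{var_name}")),
        // Anything else is the secret itself.
        None => Ok(configured.to_string()),
    }
}
```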
26 changes: 23 additions & 3 deletions engine/baml-runtime/src/internal/llm_client/primitive/mod.rs
@@ -3,6 +3,7 @@ use std::sync::Arc;
 use anyhow::Result;
 use baml_types::{BamlMap, BamlValue};
 use internal_baml_core::ir::{repr::IntermediateRepr, ClientWalker};
+use internal_baml_jinja::RenderedChatMessage;
 use internal_llm_client::{AllowedRoleMetadata, ClientProvider, OpenAIClientProviderVariant};

 use crate::{
@@ -21,8 +22,8 @@ use super::{
         OrchestratorNodeIterator,
     },
     traits::{
-        WithClient, WithClientProperties, WithPrompt, WithRenderRawCurl, WithRetryPolicy,
-        WithSingleCallable, WithStreamable,
+        ToProviderMessage, WithClient, WithClientProperties, WithPrompt, WithRenderRawCurl,
+        WithRetryPolicy, WithSingleCallable, WithStreamable,
     },
     LLMResponse,
 };
@@ -32,8 +33,8 @@ mod aws;
 mod google;
 mod openai;
 pub(super) mod request;
-mod vertex;
 mod stream_request;
+mod vertex;

 // use crate::internal::llm_client::traits::ambassador_impl_WithRenderRawCurl;
 // use crate::internal::llm_client::traits::ambassador_impl_WithRetryPolicy;
@@ -198,6 +199,25 @@ impl TryFrom<(&ClientWalker<'_>, &RuntimeContext)> for LLMPrimitiveProvider {
     }
 }

+impl LLMPrimitiveProvider {
+    pub fn chat_to_message(
+        &self,
+        chat: &[RenderedChatMessage],
+    ) -> Result<serde_json::Map<String, serde_json::Value>> {
+        use super::traits::ToProviderMessageExt;
+
+        match self {
+            LLMPrimitiveProvider::OpenAI(client) => client.chat_to_message(chat),
+            LLMPrimitiveProvider::Anthropic(client) => client.chat_to_message(chat),
+            LLMPrimitiveProvider::Google(client) => client.chat_to_message(chat),
+            LLMPrimitiveProvider::Vertex(client) => client.chat_to_message(chat),
+            LLMPrimitiveProvider::Aws(client) => {
+                todo!("AWS client does not implement ToProviderMessageExt::chat_to_message")
+            }
+        }
+    }
+}
+
 impl<'ir> WithPrompt<'ir> for LLMPrimitiveProvider {
     async fn render_prompt(
         &'ir self,
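Every arm above calls `chat_to_message` from the `ToProviderMessageExt` trait, which the AWS client evidently does not implement yet, hence the `todo!`. A minimal sketch of what such an extension trait could look like; the trait shape, message type, and JSON layout here are assumptions for illustration, not the crate's actual definitions:

```rust
use anyhow::Result;
use serde_json::{Map, Value};

// Illustrative stand-in for RenderedChatMessage.
struct Message {
    role: String,
    text: String,
}

// A provider only defines how to serialize one message...
trait ToProviderMessage {
    fn to_provider_message(&self, msg: &Message) -> Result<Map<String, Value>>;
}

// ...and inherits the whole-chat serializer from the extension trait.
trait ToProviderMessageExt: ToProviderMessage {
    fn chat_to_message(&self, chat: &[Message]) -> Result<Map<String, Value>> {
        let messages = chat
            .iter()
            .map(|m| self.to_provider_message(m).map(Value::Object))
            .collect::<Result<Vec<_>>>()?;
        let mut body = Map::new();
        body.insert("messages".into(), Value::Array(messages));
        Ok(body)
    }
}

// Blanket impl: implementing the single-message trait is enough.
impl<T: ToProviderMessage> ToProviderMessageExt for T {}
```

Each concrete client then only implements the single-message conversion and gets the whole-chat serializer for free.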
47 changes: 43 additions & 4 deletions engine/baml-runtime/src/lib.rs
@@ -8,11 +8,11 @@ pub(crate) mod internal;
 #[cfg(not(target_arch = "wasm32"))]
 pub mod cli;
 pub mod client_registry;
-pub mod test_constraints;
 pub mod errors;
 pub mod request;
 mod runtime;
 pub mod runtime_interface;
+pub mod test_constraints;
 pub mod tracing;
 pub mod type_builder;
 mod types;
@@ -30,12 +30,19 @@ use baml_types::Constraint;
 use cfg_if::cfg_if;
 use client_registry::ClientRegistry;
 use indexmap::IndexMap;
+use internal::llm_client::llm_provider::LLMProvider;
+use internal::llm_client::orchestrator::OrchestrationScope;
+use internal::prompt_renderer::PromptRenderer;
 use internal_baml_core::configuration::CloudProject;
 use internal_baml_core::configuration::CodegenGenerator;
 use internal_baml_core::configuration::Generator;
 use internal_baml_core::configuration::GeneratorOutputType;
 use internal_baml_core::ir::FunctionWalker;
+use internal_llm_client::AllowedRoleMetadata;
 use internal_llm_client::ClientSpec;
 use on_log_event::LogEventCallbackSync;
 use runtime::InternalBamlRuntime;
+use runtime_interface::InternalClientLookup;
 use std::sync::OnceLock;

 #[cfg(not(target_arch = "wasm32"))]
@@ -64,8 +71,8 @@ pub use internal_baml_core::internal_baml_diagnostics;
 pub use internal_baml_core::internal_baml_diagnostics::Diagnostics as DiagnosticsError;
 pub use internal_baml_core::ir::{scope_diagnostics, FieldType, IRHelper, TypeValue};

-use crate::test_constraints::{evaluate_test_constraints, TestConstraintsResult};
 use crate::internal::llm_client::LLMResponse;
+use crate::test_constraints::{evaluate_test_constraints, TestConstraintsResult};

 #[cfg(not(target_arch = "wasm32"))]
 static TOKIO_SINGLETON: OnceLock<std::io::Result<Arc<tokio::runtime::Runtime>>> = OnceLock::new();
@@ -184,6 +191,32 @@ impl BamlRuntime {
 }

 impl BamlRuntime {
+    pub async fn render_prompt(
+        &self,
+        function_name: &str,
+        ctx: &RuntimeContext,
+        params: &BamlMap<String, BamlValue>,
+        node_index: Option<usize>,
+    ) -> Result<(RenderedPrompt, OrchestrationScope, AllowedRoleMetadata)> {
+        self.inner
+            .render_prompt(function_name, ctx, params, node_index)
+            .await
+    }
+
+    pub fn llm_provider_from_function(
+        &self,
+        function_name: &str,
+        ctx: &RuntimeContext,
+    ) -> Result<Arc<LLMProvider>> {
+        let renderer = PromptRenderer::from_function(
+            &self.inner.get_function(&function_name, &ctx)?,
+            self.inner.ir(),
+            &ctx,
+        )?;
+
+        self.inner.get_llm_provider(renderer.client_spec(), ctx)
+    }
+
     pub fn get_test_params_and_constraints(
         &self,
         function_name: &str,
@@ -267,8 +300,14 @@ impl BamlRuntime {
         } else {
             match val {
                 Some(Ok(value)) => {
-                    let value_with_constraints = value.0.map_meta(|(_,constraints,_)| constraints.clone());
-                    evaluate_test_constraints(&params, &value_with_constraints, complete_resp, constraints)
+                    let value_with_constraints =
+                        value.0.map_meta(|(_, constraints, _)| constraints.clone());
+                    evaluate_test_constraints(
+                        &params,
+                        &value_with_constraints,
+                        complete_resp,
+                        constraints,
+                    )
                 }
                 _ => TestConstraintsResult::empty(),
             }
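Taken together, the new surface in `lib.rs` separates prompt rendering from request execution: a caller can render a function's prompt, resolve the provider its client spec points at, and serialize the rendered chat into a provider-ready JSON body without ever issuing the call. An end-to-end sketch under stated assumptions: the function name `ExtractResume` is hypothetical, `runtime`, `ctx`, and `params` are assumed to be built elsewhere with the existing constructors, and the types are assumed to be in scope as the runtime crate exposes them:

```rust
async fn prompt_and_payload(
    runtime: &BamlRuntime,
    ctx: &RuntimeContext,
    params: &BamlMap<String, BamlValue>,
) -> anyhow::Result<()> {
    // 1. Render the prompt for a BAML function without calling the model.
    let (prompt, _scope, _allowed_roles) = runtime
        .render_prompt("ExtractResume", ctx, params, None)
        .await?;

    // 2. Resolve the LLM provider behind the function's client spec.
    let provider = runtime.llm_provider_from_function("ExtractResume", ctx)?;

    // 3. Serialize the rendered chat into that provider's wire format
    //    (assuming RenderedPrompt has a Chat variant of rendered messages).
    if let RenderedPrompt::Chat(messages) = &prompt {
        let body = provider.chat_to_message(messages)?;
        println!("{}", serde_json::Value::Object(body));
    }
    Ok(())
}
```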