diff --git a/crates/app/src/lib.rs b/crates/app/src/lib.rs index 806894bfa1..8c5b84b0e5 100644 --- a/crates/app/src/lib.rs +++ b/crates/app/src/lib.rs @@ -141,7 +141,7 @@ impl App { pub fn triggers_with_type<'a>( &'a self, trigger_type: &'a str, - ) -> impl Iterator { + ) -> impl Iterator> { self.triggers() .filter(move |trigger| trigger.locked.trigger_type == trigger_type) } diff --git a/crates/key-value-azure/src/lib.rs b/crates/key-value-azure/src/lib.rs index f36fde8abf..82195fc5ab 100644 --- a/crates/key-value-azure/src/lib.rs +++ b/crates/key-value-azure/src/lib.rs @@ -2,20 +2,23 @@ mod store; use serde::Deserialize; use spin_factor_key_value::runtime_config::spin::MakeKeyValueStore; -use store::{ + +pub use store::{ KeyValueAzureCosmos, KeyValueAzureCosmosAuthOptions, KeyValueAzureCosmosRuntimeConfigOptions, }; /// A key-value store that uses Azure Cosmos as the backend. -#[derive(Default)] pub struct AzureKeyValueStore { - _priv: (), + app_id: Option, } impl AzureKeyValueStore { /// Creates a new `AzureKeyValueStore`. - pub fn new() -> Self { - Self::default() + /// + /// When `app_id` is provided, the store will a partition key of `$app_id/$store_name`, + /// otherwise the partition key will be `id`. + pub fn new(app_id: Option) -> Self { + Self { app_id } } } @@ -55,6 +58,7 @@ impl MakeKeyValueStore for AzureKeyValueStore { runtime_config.database, runtime_config.container, auth_options, + self.app_id.clone(), ) } } diff --git a/crates/key-value-azure/src/store.rs b/crates/key-value-azure/src/store.rs index 001864ea77..86f0a8a92d 100644 --- a/crates/key-value-azure/src/store.rs +++ b/crates/key-value-azure/src/store.rs @@ -1,6 +1,5 @@ use anyhow::Result; use azure_data_cosmos::prelude::Operation; -use azure_data_cosmos::resources::collection::PartitionKey; use azure_data_cosmos::{ prelude::{AuthorizationToken, CollectionClient, CosmosClient, Query}, CosmosEntity, @@ -13,6 +12,12 @@ use std::sync::{Arc, Mutex}; pub struct KeyValueAzureCosmos { client: CollectionClient, + /// An optional app id + /// + /// If provided, the store will handle multiple stores per container using a + /// partition key of `/$app_id/$store_name`, otherwise there will be one container + /// per store, and the partition key will be `/id`. + app_id: Option, } /// Azure Cosmos Key / Value runtime config literal options for authentication @@ -71,6 +76,7 @@ impl KeyValueAzureCosmos { database: String, container: String, auth_options: KeyValueAzureCosmosAuthOptions, + app_id: Option, ) -> Result { let token = match auth_options { KeyValueAzureCosmosAuthOptions::RuntimeConfigValues(config) => { @@ -86,15 +92,16 @@ impl KeyValueAzureCosmos { let database_client = cosmos_client.database_client(database); let client = database_client.collection_client(container); - Ok(Self { client }) + Ok(Self { client, app_id }) } } #[async_trait] impl StoreManager for KeyValueAzureCosmos { - async fn get(&self, _name: &str) -> Result, Error> { + async fn get(&self, name: &str) -> Result, Error> { Ok(Arc::new(AzureCosmosStore { client: self.client.clone(), + store_id: self.app_id.as_ref().map(|i| format!("{i}/{name}")), })) } @@ -114,13 +121,10 @@ impl StoreManager for KeyValueAzureCosmos { #[derive(Clone)] struct AzureCosmosStore { client: CollectionClient, -} - -struct CompareAndSwap { - key: String, - client: CollectionClient, - bucket_rep: u32, - etag: Mutex>, + /// An optional store id to use as a partition key for all operations. + /// + /// If the store id not set, the store will use `/id` as the partition key. + store_id: Option, } #[async_trait] @@ -134,6 +138,7 @@ impl Store for AzureCosmosStore { let pair = Pair { id: key.to_string(), value: value.to_vec(), + store_id: self.store_id.clone(), }; self.client .create_document(pair) @@ -145,7 +150,10 @@ impl Store for AzureCosmosStore { async fn delete(&self, key: &str) -> Result<(), Error> { if self.exists(key).await? { - let document_client = self.client.document_client(key, &key).map_err(log_error)?; + let document_client = self + .client + .document_client(key, &self.store_id) + .map_err(log_error)?; document_client.delete_document().await.map_err(log_error)?; } Ok(()) @@ -160,12 +168,7 @@ impl Store for AzureCosmosStore { } async fn get_many(&self, keys: Vec) -> Result>)>, Error> { - let in_clause: String = keys - .into_iter() - .map(|k| format!("'{}'", k)) - .collect::>() - .join(", "); - let stmt = Query::new(format!("SELECT * FROM c WHERE c.id IN ({})", in_clause)); + let stmt = Query::new(self.get_in_query(keys)); let query = self .client .query_documents(stmt) @@ -175,9 +178,11 @@ impl Store for AzureCosmosStore { let mut stream = query.into_stream::(); while let Some(resp) = stream.next().await { let resp = resp.map_err(log_error)?; - for (pair, _) in resp.results { - res.push((pair.id, Some(pair.value))); - } + res.extend( + resp.results + .into_iter() + .map(|(pair, _)| (pair.id, Some(pair.value))), + ); } Ok(res) } @@ -200,7 +205,7 @@ impl Store for AzureCosmosStore { let operations = vec![Operation::incr("/value", delta).map_err(log_error)?]; let _ = self .client - .document_client(key.clone(), &key.as_str()) + .document_client(key.clone(), &self.store_id) .map_err(log_error)? .patch_document(operations) .await @@ -227,10 +232,31 @@ impl Store for AzureCosmosStore { client: self.client.clone(), etag: Mutex::new(None), bucket_rep, + store_id: self.store_id.clone(), })) } } +struct CompareAndSwap { + key: String, + client: CollectionClient, + bucket_rep: u32, + etag: Mutex>, + store_id: Option, +} + +impl CompareAndSwap { + fn get_query(&self) -> String { + let mut query = format!("SELECT * FROM c WHERE c.id='{}'", self.key); + self.append_store_id(&mut query, true); + query + } + + fn append_store_id(&self, query: &mut String, condition_already_exists: bool) { + append_store_id_condition(query, self.store_id.as_deref(), condition_already_exists); + } +} + #[async_trait] impl Cas for CompareAndSwap { /// `current` will fetch the current value for the key and store the etag for the record. The @@ -238,10 +264,7 @@ impl Cas for CompareAndSwap { async fn current(&self) -> Result>, Error> { let mut stream = self .client - .query_documents(Query::new(format!( - "SELECT * FROM c WHERE c.id='{}'", - self.key - ))) + .query_documents(Query::new(self.get_query())) .query_cross_partition(true) .max_item_count(1) .into_stream::(); @@ -272,15 +295,15 @@ impl Cas for CompareAndSwap { /// `swap` updates the value for the key using the etag saved in the `current` function for /// optimistic concurrency. async fn swap(&self, value: Vec) -> Result<(), SwapError> { - let pk = PartitionKey::from(&self.key); let pair = Pair { id: self.key.clone(), value, + store_id: self.store_id.clone(), }; let doc_client = self .client - .document_client(&self.key, &pk) + .document_client(&self.key, &pair.partition_key()) .map_err(log_cas_error)?; let etag_value = self.etag.lock().unwrap().clone(); @@ -318,55 +341,97 @@ impl AzureCosmosStore { async fn get_pair(&self, key: &str) -> Result, Error> { let query = self .client - .query_documents(Query::new(format!("SELECT * FROM c WHERE c.id='{}'", key))) + .query_documents(Query::new(self.get_query(key))) .query_cross_partition(true) .max_item_count(1); // There can be no duplicated keys, so we create the stream and only take the first result. let mut stream = query.into_stream::(); - let res = stream.next().await; - match res { - Some(r) => { - let r = r.map_err(log_error)?; - match r.results.first().cloned() { - Some((p, _)) => Ok(Some(p)), - None => Ok(None), - } - } - None => Ok(None), - } + let Some(res) = stream.next().await else { + return Ok(None); + }; + Ok(res + .map_err(log_error)? + .results + .first() + .map(|(p, _)| p.clone())) } async fn get_keys(&self) -> Result, Error> { let query = self .client - .query_documents(Query::new("SELECT * FROM c".to_string())) + .query_documents(Query::new(self.get_keys_query())) .query_cross_partition(true); let mut res = Vec::new(); let mut stream = query.into_stream::(); while let Some(resp) = stream.next().await { let resp = resp.map_err(log_error)?; - for (pair, _) in resp.results { - res.push(pair.id); - } + res.extend(resp.results.into_iter().map(|(pair, _)| pair.id)); } Ok(res) } + + fn get_query(&self, key: &str) -> String { + let mut query = format!("SELECT * FROM c WHERE c.id='{}'", key); + self.append_store_id(&mut query, true); + query + } + + fn get_keys_query(&self) -> String { + let mut query = "SELECT * FROM c".to_owned(); + self.append_store_id(&mut query, false); + query + } + + fn get_in_query(&self, keys: Vec) -> String { + let in_clause: String = keys + .into_iter() + .map(|k| format!("'{k}'")) + .collect::>() + .join(", "); + + let mut query = format!("SELECT * FROM c WHERE c.id IN ({})", in_clause); + self.append_store_id(&mut query, true); + query + } + + fn append_store_id(&self, query: &mut String, condition_already_exists: bool) { + append_store_id_condition(query, self.store_id.as_deref(), condition_already_exists); + } +} + +/// Appends an option store id condition to the query. +fn append_store_id_condition( + query: &mut String, + store_id: Option<&str>, + condition_already_exists: bool, +) { + if let Some(s) = store_id { + if condition_already_exists { + query.push_str(" AND"); + } else { + query.push_str(" WHERE"); + } + query.push_str(" c.store_id='"); + query.push_str(s); + query.push('\'') + } } #[derive(Serialize, Deserialize, Clone, Debug)] pub struct Pair { - // In Azure CosmosDB, the default partition key is "/id", and this implementation assumes that partition ID is not changed. pub id: String, pub value: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub store_id: Option, } impl CosmosEntity for Pair { type Entity = String; fn partition_key(&self) -> Self::Entity { - self.id.clone() + self.store_id.clone().unwrap_or_else(|| self.id.clone()) } } diff --git a/crates/runtime-config/src/lib.rs b/crates/runtime-config/src/lib.rs index f1fb92e02d..3e7f22ada7 100644 --- a/crates/runtime-config/src/lib.rs +++ b/crates/runtime-config/src/lib.rs @@ -403,7 +403,7 @@ pub fn key_value_config_resolver( .register_store_type(spin_key_value_redis::RedisKeyValueStore::new()) .unwrap(); key_value - .register_store_type(spin_key_value_azure::AzureKeyValueStore::new()) + .register_store_type(spin_key_value_azure::AzureKeyValueStore::new(None)) .unwrap(); key_value .register_store_type(spin_key_value_aws::AwsDynamoKeyValueStore::new())