Skip to content

Commit 21dfd10

Browse files
authored
[ENH] Add CMEK to collection schema (#5975)
## Description of changes _Summarize the changes made by this PR._ - Improvements & Bug fixes - N/A - New functionality - Add CMEK to collection schema, and with builder methods - Propagate CMEK from schema to downstream ## Test plan _How are these changes tested?_ - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent 50f776a commit 21dfd10

File tree

7 files changed

+133
-82
lines changed

7 files changed

+133
-82
lines changed

clients/new-js/packages/chromadb/src/api/types.gen.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,12 @@ export type RawWhereFields = {
414414
* This represents the server-side schema structure used for index management
415415
*/
416416
export type Schema = {
417+
/**
418+
* Customer-managed encryption key for collection data
419+
*/
420+
cmek?: {
421+
[key: string]: unknown;
422+
} | null;
417423
/**
418424
* Default index configurations for each value type
419425
*/

rust/chroma/src/types.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub use chroma_types::BoolInvertedIndexConfig;
2020
pub use chroma_types::BoolInvertedIndexType;
2121
pub use chroma_types::BoolValueType;
2222
pub use chroma_types::BooleanOperator;
23+
pub use chroma_types::Cmek;
2324
pub use chroma_types::Collection;
2425
pub use chroma_types::CollectionConfiguration;
2526
pub use chroma_types::CollectionConversionError;

rust/frontend/src/impls/service_based_frontend.rs

Lines changed: 56 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -182,20 +182,6 @@ impl ServiceBasedFrontend {
182182
.collection)
183183
}
184184

185-
async fn get_collection_dimension(
186-
&mut self,
187-
collection_id: CollectionUuid,
188-
) -> Result<Option<u32>, GetCollectionError> {
189-
Ok(self
190-
.collections_with_segments_provider
191-
.get_collection_with_segments(collection_id)
192-
.await
193-
.map_err(|err| Box::new(err) as Box<dyn ChromaError>)?
194-
.collection
195-
.dimension
196-
.map(|dim| dim as u32))
197-
}
198-
199185
async fn set_collection_dimension(
200186
&mut self,
201187
collection_id: CollectionUuid,
@@ -219,10 +205,11 @@ impl ServiceBasedFrontend {
219205
option_embeddings: Option<&Vec<Embedding>>,
220206
update_if_not_present: bool,
221207
read_length: F,
222-
) -> Result<(), ValidationError>
208+
) -> Result<Collection, ValidationError>
223209
where
224210
F: Fn(&Embedding) -> Option<usize>,
225211
{
212+
let collection = self.get_cached_collection(collection_id).await?;
226213
if let Some(embeddings) = option_embeddings {
227214
let emb_dims = embeddings
228215
.iter()
@@ -237,28 +224,23 @@ impl ServiceBasedFrontend {
237224
low as u32
238225
} else {
239226
// No embedding to check, return
240-
return Ok(());
227+
return Ok(collection);
241228
};
242-
match self.get_collection_dimension(collection_id).await {
243-
Ok(Some(expected_dim)) => {
229+
match collection.dimension.map(|dim| dim as u32) {
230+
Some(expected_dim) => {
244231
if expected_dim != emb_dim {
245232
return Err(ValidationError::DimensionMismatch(expected_dim, emb_dim));
246233
}
247-
248-
Ok(())
249234
}
250-
Ok(None) => {
235+
None => {
251236
if update_if_not_present {
252237
self.set_collection_dimension(collection_id, emb_dim)
253238
.await?;
254239
}
255-
Ok(())
256240
}
257-
Err(err) => Err(err.into()),
258-
}
259-
} else {
260-
Ok(())
241+
};
261242
}
243+
Ok(collection)
262244
}
263245

264246
pub async fn reset(&mut self) -> Result<ResetResponse, ResetError> {
@@ -830,31 +812,31 @@ impl ServiceBasedFrontend {
830812
..
831813
}: AddCollectionRecordsRequest,
832814
) -> Result<AddCollectionRecordsResponse, AddCollectionRecordsError> {
833-
self.validate_embedding(
834-
collection_id,
835-
Some(&embeddings),
836-
true,
837-
|embedding: &Vec<f32>| Some(embedding.len()),
838-
)
839-
.await
840-
.map_err(|err| err.boxed())?;
815+
let collection = self
816+
.validate_embedding(
817+
collection_id,
818+
Some(&embeddings),
819+
true,
820+
|embedding: &Vec<f32>| Some(embedding.len()),
821+
)
822+
.await
823+
.map_err(|err| err.boxed())?;
841824

842825
let embeddings = Some(embeddings.into_iter().map(Some).collect());
843826

844827
let (records, log_size_bytes) =
845828
to_records(ids, embeddings, documents, uris, metadatas, Operation::Add)
846829
.map_err(|err| Box::new(err) as Box<dyn ChromaError>)?;
847830

848-
// TODO: Extract CMEK from collection metadata
849-
// For now, pass None until collection-level CMEK storage is implemented
850-
let cmek = None;
851-
852831
let retries = Arc::new(AtomicUsize::new(0));
853832
let add_to_retry = || {
854833
let mut self_clone = self.clone();
855834
let records_clone = records.clone();
856835
let tenant_id_clone = tenant_id.clone();
857-
let cmek_clone = cmek.clone();
836+
let cmek_clone = collection
837+
.schema
838+
.as_ref()
839+
.and_then(|schema| schema.cmek.clone());
858840
async move {
859841
self_clone
860842
.retryable_push_logs(&tenant_id_clone, collection_id, records_clone, cmek_clone)
@@ -922,11 +904,12 @@ impl ServiceBasedFrontend {
922904
..
923905
}: UpdateCollectionRecordsRequest,
924906
) -> Result<UpdateCollectionRecordsResponse, UpdateCollectionRecordsError> {
925-
self.validate_embedding(collection_id, embeddings.as_ref(), true, |embedding| {
926-
embedding.as_ref().map(|emb| emb.len())
927-
})
928-
.await
929-
.map_err(|err| err.boxed())?;
907+
let collection = self
908+
.validate_embedding(collection_id, embeddings.as_ref(), true, |embedding| {
909+
embedding.as_ref().map(|emb| emb.len())
910+
})
911+
.await
912+
.map_err(|err| err.boxed())?;
930913

931914
let (records, log_size_bytes) = to_records(
932915
ids,
@@ -938,16 +921,15 @@ impl ServiceBasedFrontend {
938921
)
939922
.map_err(|err| Box::new(err) as Box<dyn ChromaError>)?;
940923

941-
// TODO: Extract CMEK from collection metadata
942-
// For now, pass None until collection-level CMEK storage is implemented
943-
let cmek = None;
944-
945924
let retries = Arc::new(AtomicUsize::new(0));
946925
let add_to_retry = || {
947926
let mut self_clone = self.clone();
948927
let records_clone = records.clone();
949928
let tenant_id_clone = tenant_id.clone();
950-
let cmek_clone = cmek.clone();
929+
let cmek_clone = collection
930+
.schema
931+
.as_ref()
932+
.and_then(|schema| schema.cmek.clone());
951933
async move {
952934
self_clone
953935
.retryable_push_logs(&tenant_id_clone, collection_id, records_clone, cmek_clone)
@@ -1015,14 +997,15 @@ impl ServiceBasedFrontend {
1015997
..
1016998
}: UpsertCollectionRecordsRequest,
1017999
) -> Result<UpsertCollectionRecordsResponse, UpsertCollectionRecordsError> {
1018-
self.validate_embedding(
1019-
collection_id,
1020-
Some(&embeddings),
1021-
true,
1022-
|embedding: &Vec<f32>| Some(embedding.len()),
1023-
)
1024-
.await
1025-
.map_err(|err| err.boxed())?;
1000+
let collection = self
1001+
.validate_embedding(
1002+
collection_id,
1003+
Some(&embeddings),
1004+
true,
1005+
|embedding: &Vec<f32>| Some(embedding.len()),
1006+
)
1007+
.await
1008+
.map_err(|err| err.boxed())?;
10261009

10271010
let embeddings = Some(embeddings.into_iter().map(Some).collect());
10281011

@@ -1036,16 +1019,15 @@ impl ServiceBasedFrontend {
10361019
)
10371020
.map_err(|err| Box::new(err) as Box<dyn ChromaError>)?;
10381021

1039-
// TODO: Extract CMEK from collection metadata
1040-
// For now, pass None until collection-level CMEK storage is implemented
1041-
let cmek = None;
1042-
10431022
let retries = Arc::new(AtomicUsize::new(0));
10441023
let add_to_retry = || {
10451024
let mut self_clone = self.clone();
10461025
let records_clone = records.clone();
10471026
let tenant_id_clone = tenant_id.clone();
1048-
let cmek_clone = cmek.clone();
1027+
let cmek_clone = collection
1028+
.schema
1029+
.as_ref()
1030+
.and_then(|schema| schema.cmek.clone());
10491031
async move {
10501032
self_clone
10511033
.retryable_push_logs(&tenant_id_clone, collection_id, records_clone, cmek_clone)
@@ -1229,10 +1211,13 @@ impl ServiceBasedFrontend {
12291211

12301212
let log_size_bytes = records.iter().map(OperationRecord::size_bytes).sum();
12311213

1232-
// TODO: Extract CMEK from collection metadata
1233-
// For now, pass None until collection-level CMEK storage is implemented
1234-
self.log_client
1235-
.push_logs(&tenant_id, collection_id, records, None)
1214+
let cmek = self
1215+
.get_cached_collection(collection_id)
1216+
.await
1217+
.map_err(|err| DeleteCollectionRecordsError::Internal(err.boxed()))?
1218+
.schema
1219+
.and_then(|schema| schema.cmek.clone());
1220+
self.retryable_push_logs(&tenant_id, collection_id, records, cmek)
12361221
.await
12371222
.map_err(|err| {
12381223
if err.code() == ErrorCodes::Unavailable {
@@ -1987,10 +1972,8 @@ impl ServiceBasedFrontend {
19871972
collection_id: CollectionUuid,
19881973
_attached_function_id: chroma_types::AttachedFunctionUuid,
19891974
) -> Result<(), chroma_types::AttachFunctionError> {
1990-
let embedding_dim = self
1991-
.get_collection_dimension(collection_id)
1992-
.await?
1993-
.unwrap_or(1);
1975+
let collection = self.get_cached_collection(collection_id).await?;
1976+
let embedding_dim = collection.dimension.unwrap_or(1);
19941977
let fake_embedding = vec![0.0; embedding_dim as usize];
19951978
// TODO(tanujnay112): Make this either a configurable or better yet a separate
19961979
// RPC to the logs service.
@@ -2006,10 +1989,9 @@ impl ServiceBasedFrontend {
20061989
};
20071990
num_fake_logs
20081991
];
2009-
// TODO: Extract CMEK from collection metadata
2010-
// For now, pass None until collection-level CMEK storage is implemented
2011-
self.log_client
2012-
.push_logs(&tenant, collection_id, logs, None)
1992+
1993+
let cmek = collection.schema.and_then(|schema| schema.cmek.clone());
1994+
self.retryable_push_logs(&tenant, collection_id, logs, cmek)
20131995
.await?;
20141996
Ok(())
20151997
}

0 commit comments

Comments
 (0)