Skip to content

Commit edab0e1

Browse files
authored
[ENH]: Add token storage for chroma cloud splade sparse vectors (#5944)
## Description of changes Added token/label storage for Splade sparse vectors. - Improvements & Bug fixes - ... - New functionality - ... ## Test plan _How are these changes tested?_ - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent 33305f4 commit edab0e1

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

chromadb/utils/embedding_functions/chroma_cloud_splade_embedding_function.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
SparseVectors,
44
Documents,
55
)
6-
from typing import Dict, Any
6+
from typing import Dict, Any, List, Optional
77
from enum import Enum
88
from chromadb.utils.embedding_functions.schemas import validate_config_schema
99
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector
@@ -21,6 +21,7 @@ def __init__(
2121
self,
2222
api_key_env_var: str = "CHROMA_API_KEY",
2323
model: ChromaCloudSpladeEmbeddingModel = ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1,
24+
store_tokens: bool = False,
2425
):
2526
"""
2627
Initialize the ChromaCloudSpladeEmbeddingFunction.
@@ -49,6 +50,7 @@ def __init__(
4950
f"or in any existing client instances"
5051
)
5152
self.model = model
53+
self.store_tokens = bool(store_tokens)
5254
self._api_url = "https://embed.trychroma.com/embed_sparse"
5355
self._session = httpx.Client()
5456
self._session.headers.update(
@@ -87,6 +89,7 @@ def __call__(self, input: Documents) -> SparseVectors:
8789
"texts": list(input),
8890
"task": "",
8991
"target": "",
92+
"fetch_tokens": "true" if self.store_tokens is True else "false",
9093
}
9194

9295
try:
@@ -120,13 +123,17 @@ def _parse_response(self, response: Any) -> SparseVectors:
120123
if isinstance(emb, dict):
121124
indices = emb.get("indices", [])
122125
values = emb.get("values", [])
126+
raw_labels = emb.get("labels") if self.store_tokens else None
127+
labels: Optional[List[str]] = raw_labels if raw_labels else None
123128
else:
124129
# Already a SparseVector, extract its data
130+
assert(isinstance(emb, SparseVector))
125131
indices = emb.indices
126132
values = emb.values
133+
labels = emb.labels if self.store_tokens else None
127134

128135
normalized_vectors.append(
129-
normalize_sparse_vector(indices=indices, values=values)
136+
normalize_sparse_vector(indices=indices, values=values, labels=labels)
130137
)
131138

132139
return normalized_vectors
@@ -148,18 +155,23 @@ def build_from_config(
148155
return ChromaCloudSpladeEmbeddingFunction(
149156
api_key_env_var=api_key_env_var,
150157
model=ChromaCloudSpladeEmbeddingModel(model),
158+
store_tokens=config.get("store_tokens", False),
151159
)
152160

153161
def get_config(self) -> Dict[str, Any]:
154-
return {"api_key_env_var": self.api_key_env_var, "model": self.model.value}
162+
return {
163+
"api_key_env_var": self.api_key_env_var,
164+
"model": self.model.value,
165+
"store_tokens": self.store_tokens,
166+
}
155167

156168
def validate_config_update(
157169
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
158170
) -> None:
159-
if "model" in new_config:
160-
raise ValueError(
161-
"model cannot be changed after the embedding function has been initialized"
162-
)
171+
immutable_keys = {"store_tokens", "model"}
172+
for key in immutable_keys:
173+
if key in new_config and new_config[key] != old_config.get(key):
174+
raise ValueError(f"Updating '{key}' is not supported for chroma-cloud-splade")
163175

164176
@staticmethod
165177
def validate_config(config: Dict[str, Any]) -> None:

schemas/embedding_functions/chroma-cloud-splade.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
"type": "string",
1717
"description": "Environment variable name that contains your API key for the Chroma Embedding API",
1818
"default": "CHROMA_API_KEY"
19+
},
20+
"store_tokens": {
21+
"type": "boolean",
22+
"description": "Whether to store token labels in the sparse vector output",
23+
"default": false
1924
}
2025
},
2126
"required": [

0 commit comments

Comments
 (0)