Commit 21b4751

[CHORE]: Change store_tokens to get_tokens
1 parent e87ccb6 commit 21b4751

File tree

5 files changed: +41 -30 lines changed


chromadb/utils/embedding_functions/chroma_bm25_embedding_function.py

Lines changed: 21 additions & 13 deletions

@@ -22,6 +22,7 @@

 DEFAULT_CHROMA_BM25_STOPWORDS: List[str] = list(_DEFAULT_STOPWORDS)

+
 class _HashedToken:
     __slots__ = ("hash", "label")

@@ -47,7 +48,7 @@ class ChromaBm25Config(TypedDict, total=False):
     avg_doc_length: float
     token_max_length: int
     stopwords: List[str]
-    store_tokens: bool
+    get_tokens: bool


 class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
@@ -58,15 +59,15 @@ def __init__(
         avg_doc_length: float = DEFAULT_AVG_DOC_LENGTH,
         token_max_length: int = DEFAULT_TOKEN_MAX_LENGTH,
         stopwords: Optional[Iterable[str]] = None,
-        store_tokens: bool = False,
+        get_tokens: bool = False,
     ) -> None:
         """Initialize the BM25 sparse embedding function."""

         self.k = float(k)
         self.b = float(b)
         self.avg_doc_length = float(avg_doc_length)
         self.token_max_length = int(token_max_length)
-        self.store_tokens = bool(store_tokens)
+        self.get_tokens = bool(get_tokens)

         if stopwords is not None:
             self.stopwords: Optional[List[str]] = [str(word) for word in stopwords]
@@ -87,28 +88,28 @@ def _encode(self, text: str) -> SparseVector:

         doc_len = float(len(tokens))
         counts = Counter(
-            _HashedToken(self._hasher.hash(token), token if self.store_tokens else None)
+            _HashedToken(self._hasher.hash(token), token if self.get_tokens else None)
             for token in tokens
         )

         sorted_keys = sorted(counts.keys())
         indices: List[int] = []
         values: List[float] = []
-        tokens: Optional[List[str]] = [] if self.store_tokens else None
+        labels: Optional[List[str]] = [] if self.get_tokens else None

         for key in sorted_keys:
             tf = float(counts[key])
             denominator = tf + self.k * (
                 1 - self.b + (self.b * doc_len) / self.avg_doc_length
             )
             score = tf * (self.k + 1) / denominator
-
+
             indices.append(key.hash)
             values.append(score)
-            if tokens is not None:
-                tokens.append(key.label)
+            if labels is not None and key.label is not None:
+                labels.append(key.label)

-        return SparseVector(indices=indices, values=values, labels=tokens)
+        return SparseVector(indices=indices, values=values, labels=labels)

     def __call__(self, input: Documents) -> SparseVectors:
         sparse_vectors: SparseVectors = []
@@ -138,7 +139,7 @@ def build_from_config(
             avg_doc_length=config.get("avg_doc_length", DEFAULT_AVG_DOC_LENGTH),
             token_max_length=config.get("token_max_length", DEFAULT_TOKEN_MAX_LENGTH),
             stopwords=config.get("stopwords"),
-            store_tokens=config.get("store_tokens", False),
+            get_tokens=config.get("get_tokens", False),
         )

     def get_config(self) -> Dict[str, Any]:
@@ -147,7 +148,7 @@ def get_config(self) -> Dict[str, Any]:
             "b": self.b,
             "avg_doc_length": self.avg_doc_length,
             "token_max_length": self.token_max_length,
-            "store_tokens": self.store_tokens,
+            "get_tokens": self.get_tokens,
         }

         if self.stopwords is not None:
@@ -158,11 +159,18 @@ def get_config(self) -> Dict[str, Any]:
     def validate_config_update(
         self, old_config: Dict[str, Any], new_config: Dict[str, Any]
     ) -> None:
-        mutable_keys = {"k", "b", "avg_doc_length", "token_max_length", "stopwords", "store_tokens"}
+        mutable_keys = {
+            "k",
+            "b",
+            "avg_doc_length",
+            "token_max_length",
+            "stopwords",
+            "get_tokens",
+        }
         for key in new_config:
             if key not in mutable_keys:
                 raise ValueError(f"Updating '{key}' is not supported for {NAME}")

     @staticmethod
     def validate_config(config: Dict[str, Any]) -> None:
-        validate_config_schema(config, NAME)
+        validate_config_schema(config, NAME)
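The _encode hunk above contains the entire BM25 weighting: each distinct term scores tf * (k + 1) / (tf + k * (1 - b + b * doc_len / avg_doc_length)), and the renamed get_tokens flag only decides whether the token strings travel along as labels. Below is a minimal standalone sketch of that behavior; the whitespace tokenizer, the 31-bit hash, and the default parameter values are illustrative stand-ins, not Chroma's internals.

# Minimal sketch of the per-term BM25 weighting in _encode above.
# The tokenizer, hash, and defaults below are stand-ins for illustration.
from collections import Counter
from typing import List, Optional, Tuple

def bm25_sketch(
    text: str,
    k: float = 1.2,
    b: float = 0.75,
    avg_doc_length: float = 256.0,
    get_tokens: bool = False,
) -> Tuple[List[int], List[float], Optional[List[str]]]:
    tokens = text.lower().split()  # stand-in tokenizer
    doc_len = float(len(tokens))
    counts = Counter(tokens)
    indices: List[int] = []
    values: List[float] = []
    labels: Optional[List[str]] = [] if get_tokens else None
    for token in sorted(counts):
        tf = float(counts[token])
        denominator = tf + k * (1 - b + (b * doc_len) / avg_doc_length)
        indices.append(hash(token) & 0x7FFFFFFF)  # stand-in 31-bit hash
        values.append(tf * (k + 1) / denominator)
        if labels is not None:
            labels.append(token)
    return indices, values, labels

# With get_tokens=True the labels list mirrors the sorted distinct tokens;
# with the default False it stays None, matching the SparseVector above.
print(bm25_sketch("sparse vectors score sparse tokens", get_tokens=True))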

chromadb/utils/embedding_functions/chroma_cloud_splade_embedding_function.py

Lines changed: 13 additions & 10 deletions

@@ -1,5 +1,6 @@
 from chromadb.api.types import (
     SparseEmbeddingFunction,
+    SparseVector,
     SparseVectors,
     Documents,
 )
@@ -21,7 +22,7 @@ def __init__(
         self,
         api_key_env_var: str = "CHROMA_API_KEY",
         model: ChromaCloudSpladeEmbeddingModel = ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1,
-        store_tokens: bool = False,
+        get_tokens: bool = False,
     ):
         """
         Initialize the ChromaCloudSpladeEmbeddingFunction.
@@ -50,7 +51,7 @@ def __init__(
                 f"or in any existing client instances"
             )
         self.model = model
-        self.store_tokens = bool(store_tokens)
+        self.get_tokens = bool(get_tokens)
         self._api_url = "https://embed.trychroma.com/embed_sparse"
         self._session = httpx.Client()
         self._session.headers.update(
@@ -89,7 +90,7 @@ def __call__(self, input: Documents) -> SparseVectors:
             "texts": list(input),
             "task": "",
             "target": "",
-            "fetch_tokens": "true" if self.store_tokens is True else "false",
+            "fetch_tokens": "true" if self.get_tokens is True else "false",
         }

         try:
@@ -123,14 +124,14 @@ def _parse_response(self, response: Any) -> SparseVectors:
             if isinstance(emb, dict):
                 indices = emb.get("indices", [])
                 values = emb.get("values", [])
-                raw_labels = emb.get("labels") if self.store_tokens else None
+                raw_labels = emb.get("labels") if self.get_tokens else None
                 labels: Optional[List[str]] = raw_labels if raw_labels else None
             else:
                 # Already a SparseVector, extract its data
-                assert(isinstance(emb, SparseVector))
+                assert isinstance(emb, SparseVector)
                 indices = emb.indices
                 values = emb.values
-                labels = emb.labels if self.store_tokens else None
+                labels = emb.labels if self.get_tokens else None

             normalized_vectors.append(
                 normalize_sparse_vector(indices=indices, values=values, labels=labels)
@@ -155,23 +156,25 @@ def build_from_config(
         return ChromaCloudSpladeEmbeddingFunction(
             api_key_env_var=api_key_env_var,
             model=ChromaCloudSpladeEmbeddingModel(model),
-            store_tokens=config.get("store_tokens", False),
+            get_tokens=config.get("get_tokens", False),
         )

     def get_config(self) -> Dict[str, Any]:
         return {
             "api_key_env_var": self.api_key_env_var,
             "model": self.model.value,
-            "store_tokens": self.store_tokens,
+            "get_tokens": self.get_tokens,
         }

     def validate_config_update(
         self, old_config: Dict[str, Any], new_config: Dict[str, Any]
     ) -> None:
-        immutable_keys = {"store_tokens", "model"}
+        immutable_keys = {"get_tokens", "model"}
         for key in immutable_keys:
             if key in new_config and new_config[key] != old_config.get(key):
-                raise ValueError(f"Updating '{key}' is not supported for chroma-cloud-splade")
+                raise ValueError(
+                    f"Updating '{key}' is not supported for chroma-cloud-splade"
+                )

     @staticmethod
     def validate_config(config: Dict[str, Any]) -> None:
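The rename is API-visible when constructing the cloud SPLADE function. A hypothetical usage sketch follows; the import path mirrors the file path above, and CHROMA_API_KEY is assumed to be set in the environment.

# Hypothetical usage; assumes CHROMA_API_KEY is set in the environment.
# With get_tokens=True the request body carries "fetch_tokens": "true",
# so the returned sparse vectors keep token labels; with the default
# False, _parse_response drops labels entirely.
from chromadb.utils.embedding_functions.chroma_cloud_splade_embedding_function import (
    ChromaCloudSpladeEmbeddingFunction,
)

ef = ChromaCloudSpladeEmbeddingFunction(get_tokens=True)
vectors = ef(["sparse retrieval with splade"])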

rust/chroma/src/embed/bm25.rs

Lines changed: 3 additions & 3 deletions

@@ -31,7 +31,7 @@ where
     H: TokenHasher,
 {
     /// Whether to store tokens in the created sparse vectors.
-    pub store_tokens: bool,
+    pub get_tokens: bool,
     /// Tokenizer for converting text into tokens.
     pub tokenizer: T,
     /// Hasher for converting tokens into u32 identifiers.
@@ -57,7 +57,7 @@ impl BM25SparseEmbeddingFunction<Bm25Tokenizer, Murmur3AbsHasher> {
     /// - hasher: Murmur3 with seed 0, abs() behavior
     pub fn default_murmur3_abs() -> Self {
         Self {
-            store_tokens: true,
+            get_tokens: true,
             tokenizer: Bm25Tokenizer::default(),
             hasher: Murmur3AbsHasher::default(),
             k: 1.2,
@@ -78,7 +78,7 @@ where

         let doc_len = tokens.len() as f32;

-        if self.store_tokens {
+        if self.get_tokens {
             let mut token_ids = Vec::with_capacity(tokens.len());
             for token in tokens {
                 let id = self.hasher.hash(&token);
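The doc comment above pins the default Rust hasher to Murmur3 with seed 0 and abs() behavior. For intuition only, that could be approximated in Python with the third-party mmh3 package; this is an assumption about equivalent behavior, not Chroma's actual Rust hasher.

# Rough approximation of the documented Murmur3AbsHasher behavior
# (seed 0, abs() of the signed 32-bit hash). Requires `pip install mmh3`;
# illustration only, not Chroma's actual implementation.
import mmh3

def murmur3_abs(token: str) -> int:
    # mmh3.hash returns a signed 32-bit int; abs() folds it non-negative
    return abs(mmh3.hash(token, 0))

print(murmur3_abs("sparse"))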

schemas/embedding_functions/chroma-cloud-splade.json

Lines changed: 2 additions & 2 deletions

@@ -17,7 +17,7 @@
       "description": "Environment variable name that contains your API key for the Chroma Embedding API",
       "default": "CHROMA_API_KEY"
     },
-    "store_tokens": {
+    "get_tokens": {
       "type": "boolean",
       "description": "Whether to store token labels in the sparse vector output",
       "default": false
@@ -28,4 +28,4 @@
     "model"
   ],
   "additionalProperties": false
-}
+}
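Because this schema keeps additionalProperties set to false, a config that still carries the old store_tokens key now fails validation rather than being silently ignored. A sketch using the third-party jsonschema package; Chroma itself routes this through validate_config_schema, and the model string below is a hypothetical placeholder.

# Sketch only: validates a config dict against the schema file above.
# Chroma uses its own validate_config_schema helper; the model value
# here is a hypothetical placeholder.
import json
from jsonschema import validate, ValidationError

with open("schemas/embedding_functions/chroma-cloud-splade.json") as f:
    schema = json.load(f)

validate({"model": "splade-pp-en-v1", "get_tokens": True}, schema)  # ok

try:
    validate({"model": "splade-pp-en-v1", "store_tokens": True}, schema)
except ValidationError:
    print("old store_tokens key is rejected")  # additionalProperties: false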

schemas/embedding_functions/chroma_bm25.json

Lines changed: 2 additions & 2 deletions

@@ -28,10 +28,10 @@
           "type": "string"
         }
       },
-    "store_tokens": {
+    "get_tokens": {
       "type": "boolean",
       "description": "Whether to store token strings in the sparse vectors (default: true)"
     }
   },
   "additionalProperties": false
-}
+}
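A migration note that falls out of the build_from_config hunks above: only the new key is read back, so a previously persisted config that still says store_tokens silently falls back to get_tokens=False. A tiny sketch with hypothetical config values:

# The rename in one line: build_from_config reads only the new key,
# so an old persisted config loses its opt-in and falls back to False.
old_config = {"store_tokens": True}   # persisted before this commit
new_config = {"get_tokens": True}     # persisted after this commit

print(old_config.get("get_tokens", False))  # False: old key is ignored
print(new_config.get("get_tokens", False))  # True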
