@@ -22,6 +22,7 @@
 
 DEFAULT_CHROMA_BM25_STOPWORDS: List[str] = list(_DEFAULT_STOPWORDS)
 
+
 class _HashedToken:
     __slots__ = ("hash", "label")
 
@@ -47,7 +48,7 @@ class ChromaBm25Config(TypedDict, total=False):
     avg_doc_length: float
     token_max_length: int
     stopwords: List[str]
-    store_tokens: bool
+    include_tokens: bool
 
 
 class ChromaBm25EmbeddingFunction(SparseEmbeddingFunction[Documents]):
@@ -58,15 +59,15 @@ def __init__(
         avg_doc_length: float = DEFAULT_AVG_DOC_LENGTH,
         token_max_length: int = DEFAULT_TOKEN_MAX_LENGTH,
         stopwords: Optional[Iterable[str]] = None,
-        store_tokens: bool = False,
+        include_tokens: bool = False,
     ) -> None:
         """Initialize the BM25 sparse embedding function."""
 
         self.k = float(k)
         self.b = float(b)
         self.avg_doc_length = float(avg_doc_length)
         self.token_max_length = int(token_max_length)
-        self.store_tokens = bool(store_tokens)
+        self.include_tokens = bool(include_tokens)
 
         if stopwords is not None:
             self.stopwords: Optional[List[str]] = [str(word) for word in stopwords]
@@ -87,28 +88,30 @@ def _encode(self, text: str) -> SparseVector:
 
         doc_len = float(len(tokens))
         counts = Counter(
-            _HashedToken(self._hasher.hash(token), token if self.store_tokens else None)
+            _HashedToken(
+                self._hasher.hash(token), token if self.include_tokens else None
+            )
             for token in tokens
         )
 
         sorted_keys = sorted(counts.keys())
         indices: List[int] = []
         values: List[float] = []
-        tokens: Optional[List[str]] = [] if self.store_tokens else None
+        labels: Optional[List[str]] = [] if self.include_tokens else None
 
         for key in sorted_keys:
             tf = float(counts[key])
             denominator = tf + self.k * (
                 1 - self.b + (self.b * doc_len) / self.avg_doc_length
             )
             score = tf * (self.k + 1) / denominator
 
             indices.append(key.hash)
             values.append(score)
-            if tokens is not None:
-                tokens.append(key.label)
+            if labels is not None and key.label is not None:
+                labels.append(key.label)
 
-        return SparseVector(indices=indices, values=values, labels=tokens)
+        return SparseVector(indices=indices, values=values, labels=labels)
 
     def __call__(self, input: Documents) -> SparseVectors:
         sparse_vectors: SparseVectors = []
@@ -138,7 +141,7 @@ def build_from_config(
             avg_doc_length=config.get("avg_doc_length", DEFAULT_AVG_DOC_LENGTH),
             token_max_length=config.get("token_max_length", DEFAULT_TOKEN_MAX_LENGTH),
             stopwords=config.get("stopwords"),
-            store_tokens=config.get("store_tokens", False),
+            include_tokens=config.get("include_tokens", False),
         )
 
     def get_config(self) -> Dict[str, Any]:
@@ -147,7 +150,7 @@ def get_config(self) -> Dict[str, Any]:
             "b": self.b,
             "avg_doc_length": self.avg_doc_length,
             "token_max_length": self.token_max_length,
-            "store_tokens": self.store_tokens,
+            "include_tokens": self.include_tokens,
         }
 
         if self.stopwords is not None:
@@ -158,11 +161,18 @@ def get_config(self) -> Dict[str, Any]:
     def validate_config_update(
         self, old_config: Dict[str, Any], new_config: Dict[str, Any]
     ) -> None:
-        mutable_keys = {"k", "b", "avg_doc_length", "token_max_length", "stopwords", "store_tokens"}
+        mutable_keys = {
+            "k",
+            "b",
+            "avg_doc_length",
+            "token_max_length",
+            "stopwords",
+            "include_tokens",
+        }
         for key in new_config:
             if key not in mutable_keys:
                 raise ValueError(f"Updating '{key}' is not supported for {NAME}")
 
     @staticmethod
     def validate_config(config: Dict[str, Any]) -> None:
-        validate_config_schema(config, NAME)
\ No newline at end of file
+        validate_config_schema(config, NAME)
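
Note: the BM25 weighting in _encode above is unchanged by this PR; each term scores
tf * (k + 1) / (tf + k * (1 - b + b * doc_len / avg_doc_length)), and the renamed
include_tokens flag only controls whether token labels ride along in the output.
A minimal usage sketch follows; the class, fields, and flag come from this diff,
while the import path and sample text are assumptions.

    # Sketch only: the import path is an assumption, not taken from this PR.
    from chromadb.utils.embedding_functions import ChromaBm25EmbeddingFunction

    ef = ChromaBm25EmbeddingFunction(include_tokens=True)
    [vec] = ef(["sparse vectors keep only non-zero dimensions"])

    # labels is populated only because include_tokens=True; otherwise it is None.
    for idx, val, label in zip(vec.indices, vec.values, vec.labels or []):
        print(idx, round(val, 4), label)
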
@@ -1,5 +1,6 @@
 from chromadb.api.types import (
     SparseEmbeddingFunction,
+    SparseVector,
     SparseVectors,
     Documents,
 )
@@ -21,7 +22,7 @@ def __init__(
         self,
         api_key_env_var: str = "CHROMA_API_KEY",
         model: ChromaCloudSpladeEmbeddingModel = ChromaCloudSpladeEmbeddingModel.SPLADE_PP_EN_V1,
-        store_tokens: bool = False,
+        include_tokens: bool = False,
     ):
         """
         Initialize the ChromaCloudSpladeEmbeddingFunction.
@@ -50,7 +51,7 @@ def __init__(
                 f"or in any existing client instances"
             )
         self.model = model
-        self.store_tokens = bool(store_tokens)
+        self.include_tokens = bool(include_tokens)
         self._api_url = "https://embed.trychroma.com/embed_sparse"
         self._session = httpx.Client()
         self._session.headers.update(
@@ -89,7 +90,7 @@ def __call__(self, input: Documents) -> SparseVectors:
             "texts": list(input),
             "task": "",
             "target": "",
-            "fetch_tokens": "true" if self.store_tokens is True else "false",
+            "fetch_tokens": "true" if self.include_tokens is True else "false",
         }
 
         try:
@@ -123,14 +124,14 @@ def _parse_response(self, response: Any) -> SparseVectors:
             if isinstance(emb, dict):
                 indices = emb.get("indices", [])
                 values = emb.get("values", [])
-                raw_labels = emb.get("labels") if self.store_tokens else None
+                raw_labels = emb.get("labels") if self.include_tokens else None
                 labels: Optional[List[str]] = raw_labels if raw_labels else None
             else:
                 # Already a SparseVector, extract its data
-                assert(isinstance(emb, SparseVector))
+                assert isinstance(emb, SparseVector)
                 indices = emb.indices
                 values = emb.values
-                labels = emb.labels if self.store_tokens else None
+                labels = emb.labels if self.include_tokens else None
 
             normalized_vectors.append(
                 normalize_sparse_vector(indices=indices, values=values, labels=labels)
@@ -155,23 +156,25 @@ def build_from_config(
         return ChromaCloudSpladeEmbeddingFunction(
             api_key_env_var=api_key_env_var,
             model=ChromaCloudSpladeEmbeddingModel(model),
-            store_tokens=config.get("store_tokens", False),
+            include_tokens=config.get("include_tokens", False),
         )
 
     def get_config(self) -> Dict[str, Any]:
         return {
             "api_key_env_var": self.api_key_env_var,
             "model": self.model.value,
-            "store_tokens": self.store_tokens,
+            "include_tokens": self.include_tokens,
         }
 
     def validate_config_update(
         self, old_config: Dict[str, Any], new_config: Dict[str, Any]
     ) -> None:
-        immutable_keys = {"store_tokens", "model"}
+        immutable_keys = {"include_tokens", "model"}
         for key in immutable_keys:
             if key in new_config and new_config[key] != old_config.get(key):
-                raise ValueError(f"Updating '{key}' is not supported for chroma-cloud-splade")
+                raise ValueError(
+                    f"Updating '{key}' is not supported for chroma-cloud-splade"
+                )
 
     @staticmethod
     def validate_config(config: Dict[str, Any]) -> None:
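
Note: the two embedding functions treat the renamed flag differently. The BM25
function above lists include_tokens among its mutable_keys, while the SPLADE
function keeps it in immutable_keys, so a persisted SPLADE config cannot flip it
after creation. A hedged illustration of what validate_config_update rejects;
the dict keys mirror get_config() in this diff, but the values and the model id
string are placeholders.

    # Illustrative config dicts; the model string stands in for the enum's .value.
    old_config = {
        "api_key_env_var": "CHROMA_API_KEY",
        "model": "splade-pp-en-v1",  # placeholder
        "include_tokens": False,
    }
    new_config = {"include_tokens": True}

    # validate_config_update(old_config, new_config) would raise:
    #   ValueError: Updating 'include_tokens' is not supported for chroma-cloud-splade
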
6 changes: 3 additions & 3 deletions rust/chroma/src/embed/bm25.rs
@@ -31,7 +31,7 @@ where
     H: TokenHasher,
 {
     /// Whether to store tokens in the created sparse vectors.
-    pub store_tokens: bool,
+    pub include_tokens: bool,
     /// Tokenizer for converting text into tokens.
     pub tokenizer: T,
     /// Hasher for converting tokens into u32 identifiers.
@@ -57,7 +57,7 @@ impl BM25SparseEmbeddingFunction<Bm25Tokenizer, Murmur3AbsHasher> {
     /// - hasher: Murmur3 with seed 0, abs() behavior
     pub fn default_murmur3_abs() -> Self {
         Self {
-            store_tokens: true,
+            include_tokens: true,
             tokenizer: Bm25Tokenizer::default(),
             hasher: Murmur3AbsHasher::default(),
             k: 1.2,
@@ -78,7 +78,7 @@ where
 
         let doc_len = tokens.len() as f32;
 
-        if self.store_tokens {
+        if self.include_tokens {
             let mut token_ids = Vec::with_capacity(tokens.len());
             for token in tokens {
                 let id = self.hasher.hash(&token);
4 changes: 2 additions & 2 deletions schemas/embedding_functions/chroma-cloud-splade.json
@@ -17,7 +17,7 @@
       "description": "Environment variable name that contains your API key for the Chroma Embedding API",
       "default": "CHROMA_API_KEY"
     },
-    "store_tokens": {
+    "include_tokens": {
       "type": "boolean",
       "description": "Whether to store token labels in the sparse vector output",
       "default": false
@@ -28,4 +28,4 @@
     "model"
   ],
   "additionalProperties": false
-}
\ No newline at end of file
+}
4 changes: 2 additions & 2 deletions schemas/embedding_functions/chroma_bm25.json
@@ -28,10 +28,10 @@
         "type": "string"
       }
     },
-    "store_tokens": {
+    "include_tokens": {
       "type": "boolean",
       "description": "Whether to store token strings in the sparse vectors (default: true)"
     }
   },
   "additionalProperties": false
-}
\ No newline at end of file
+}
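
Note: with both schemas renamed, a stored configuration now carries include_tokens
rather than store_tokens. An illustrative config of the shape the updated
chroma_bm25.json accepts; the values are examples only (k = 1.2 matches the Rust
default shown above, the rest are invented for the sketch).

    # Illustrative only: example values for the renamed schema key.
    bm25_config = {
        "k": 1.2,
        "b": 0.75,
        "avg_doc_length": 256.0,
        "token_max_length": 40,
        "stopwords": ["the", "a", "an"],
        "include_tokens": False,
    }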