Skip to content

Commit 5e6d93b

Browse files
committed
[ENH]: Force statistics wrapper callers to specify an output collection
1 parent a1e139d commit 5e6d93b

File tree

2 files changed

+42
-43
lines changed

2 files changed

+42
-43
lines changed

chromadb/test/distributed/test_statistics_wrapper.py

Lines changed: 30 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44

55
import json
66
import time
7+
from typing import Any
8+
79
from chromadb.api.client import Client as ClientCreator
810
from chromadb.base_types import SparseVector
911
from chromadb.config import System
@@ -31,7 +33,7 @@ def test_statistics_wrapper(basic_http_client: System) -> None:
3133
)
3234

3335
# Enable statistics
34-
attached_fn = attach_statistics_function(collection)
36+
attached_fn = attach_statistics_function(collection, "test_collection_statistics")
3537
assert attached_fn is not None
3638
assert attached_fn.function_name == "statistics"
3739
assert attached_fn.output_collection == "test_collection_statistics"
@@ -54,7 +56,7 @@ def test_statistics_wrapper(basic_http_client: System) -> None:
5456
time.sleep(60)
5557

5658
# Get statistics
57-
stats = get_statistics(collection)
59+
stats = get_statistics(collection, "test_collection_statistics")
5860
print("\nStatistics output:")
5961
print(json.dumps(stats, indent=2))
6062

@@ -124,14 +126,14 @@ def test_backfill_statistics(basic_http_client: System) -> None:
124126
initial_version = get_collection_version(client, collection.name)
125127

126128
# Enable statistics
127-
attached_fn = attach_statistics_function(collection)
129+
attached_fn = attach_statistics_function(collection, "my_collection_statistics")
128130
assert attached_fn.function_name == "statistics"
129131
assert attached_fn.output_collection == "my_collection_statistics"
130132

131133
# Wait for statistics to be computed
132134
wait_for_version_increase(client, collection.name, initial_version)
133135

134-
stats = get_statistics(collection)
136+
stats = get_statistics(collection, "my_collection_statistics")
135137
assert stats is not None
136138
assert "statistics" in stats
137139
assert "summary" in stats
@@ -190,7 +192,7 @@ def test_statistics_wrapper_custom_output_collection(basic_http_client: System)
190192
wait_for_version_increase(client, collection.name, initial_version)
191193

192194
# Get statistics
193-
stats = get_statistics(collection)
195+
stats = get_statistics(collection, "my_custom_stats")
194196
assert "statistics" in stats
195197
assert "key" in stats["statistics"]
196198

@@ -206,7 +208,7 @@ def test_statistics_wrapper_key_filter(basic_http_client: System) -> None:
206208
collection = client.create_collection(name="key_filter_test")
207209

208210
# Enable statistics
209-
attach_statistics_function(collection)
211+
attach_statistics_function(collection, "key_filter_test_statistics")
210212

211213
initial_version = get_collection_version(client, collection.name)
212214

@@ -225,13 +227,15 @@ def test_statistics_wrapper_key_filter(basic_http_client: System) -> None:
225227
time.sleep(60)
226228

227229
# Get all statistics (no key filter)
228-
all_stats = get_statistics(collection)
230+
all_stats = get_statistics(collection, "key_filter_test_statistics")
229231
assert "category" in all_stats["statistics"]
230232
assert "score" in all_stats["statistics"]
231233
assert "active" in all_stats["statistics"]
232234

233235
# Get statistics filtered by "category" key only
234-
category_stats = get_statistics(collection, key="category")
236+
category_stats = get_statistics(
237+
collection, "key_filter_test_statistics", key="category"
238+
)
235239
assert "category" in category_stats["statistics"]
236240
assert "score" not in category_stats["statistics"]
237241
assert "active" not in category_stats["statistics"]
@@ -242,7 +246,7 @@ def test_statistics_wrapper_key_filter(basic_http_client: System) -> None:
242246
assert category_stats["summary"]["total_count"] == 3
243247

244248
# Get statistics filtered by "score" key only
245-
score_stats = get_statistics(collection, key="score")
249+
score_stats = get_statistics(collection, "key_filter_test_statistics", key="score")
246250
assert "score" in score_stats["statistics"]
247251
assert "category" not in score_stats["statistics"]
248252
assert "active" not in score_stats["statistics"]
@@ -263,7 +267,7 @@ def test_statistics_wrapper_incremental_updates(basic_http_client: System) -> No
263267
client.reset()
264268

265269
collection = client.create_collection(name="incremental_test")
266-
attach_statistics_function(collection)
270+
attach_statistics_function(collection, "incremental_test_statistics")
267271

268272
initial_version = get_collection_version(client, collection.name)
269273

@@ -278,7 +282,7 @@ def test_statistics_wrapper_incremental_updates(basic_http_client: System) -> No
278282
next_version = get_collection_version(client, collection.name)
279283

280284
# Check initial statistics
281-
stats = get_statistics(collection)
285+
stats = get_statistics(collection, "incremental_test_statistics")
282286
assert stats["statistics"]["category"]["A"]["count"] == 2
283287
assert stats["summary"]["total_count"] == 2
284288

@@ -295,7 +299,7 @@ def test_statistics_wrapper_incremental_updates(basic_http_client: System) -> No
295299
time.sleep(70)
296300

297301
# Check updated statistics
298-
stats = get_statistics(collection)
302+
stats = get_statistics(collection, "incremental_test_statistics")
299303
assert stats["statistics"]["category"]["A"]["count"] == 3
300304
assert stats["statistics"]["category"]["B"]["count"] == 1
301305
assert stats["summary"]["total_count"] == 4
@@ -333,14 +337,14 @@ def test_sparse_vector_statistics(basic_http_client: System) -> None:
333337
{"category": "A", "vec": sparse_vec3},
334338
],
335339
)
336-
attach_statistics_function(collection)
340+
attach_statistics_function(collection, "sparse_vector_test1_statistics")
337341

338342
initial_version = get_collection_version(client, collection.name)
339343

340344
wait_for_version_increase(client, collection.name, initial_version)
341345

342346
# Get statistics
343-
stats = get_statistics(collection)
347+
stats = get_statistics(collection, "sparse_vector_test1_statistics")
344348
print("\nSparse vector statistics output:")
345349
print(json.dumps(stats, indent=2))
346350

@@ -385,54 +389,54 @@ def test_statistics_high_cardinality(basic_http_client: System) -> None:
385389
num_fields = 10
386390
ids = [f"id{i}" for i in range(num_docs)]
387391
documents = [f"doc{i}" for i in range(num_docs)]
388-
389-
metadatas = []
392+
393+
metadatas: list[dict[str, Any]] = []
390394
for i in range(num_docs):
391-
meta = {}
395+
meta: dict[str, Any] = {}
392396
for j in range(num_fields):
393397
meta[f"field_{j}"] = f"value_{j}_{i}"
394398
metadatas.append(meta)
395399

396400
# Add in batches to avoid hitting request size limits
397401
batch_size = 100
398402
initial_version = get_collection_version(client, collection.name)
399-
403+
400404
for i in range(0, num_docs, batch_size):
401405
collection.add(
402406
ids=ids[i : i + batch_size],
403407
documents=documents[i : i + batch_size],
404-
metadatas=metadatas[i : i + batch_size],
408+
metadatas=metadatas[i : i + batch_size], # type: ignore[arg-type]
405409
)
406410

407411
# Let all data be compacted
408412
wait_for_version_increase(client, collection.name, initial_version)
409413
initial_version = get_collection_version(client, collection.name)
410414

411415
# Enable statistics
412-
attach_statistics_function(collection)
416+
attach_statistics_function(collection, "high_cardinality_test_statistics")
413417

414418
# Wait for statistics to be computed
415419
wait_for_version_increase(client, collection.name, initial_version)
416420

417421
# Get statistics
418-
stats = get_statistics(collection)
419-
422+
stats = get_statistics(collection, "high_cardinality_test_statistics")
423+
420424
assert "statistics" in stats
421-
425+
422426
# Verify we have stats for all fields
423427
for j in range(num_fields):
424428
field_key = f"field_{j}"
425429
assert field_key in stats["statistics"]
426-
430+
427431
field_stats = stats["statistics"][field_key]
428432
assert len(field_stats) == num_docs
429-
433+
430434
# Verify each value has count 1
431435
for i in range(num_docs):
432436
value = f"value_{j}_{i}"
433437
assert value in field_stats
434438
assert field_stats[value]["count"] == 1
435-
439+
436440
# Verify total count
437441
assert stats["summary"]["total_count"] == num_docs
438442

chromadb/utils/statistics.py

Lines changed: 12 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,8 @@
1111
>>> client = chromadb.Client()
1212
>>> collection = client.get_or_create_collection("my_collection")
1313
>>>
14-
>>> # Attach statistics function
15-
>>> attach_statistics_function(collection)
14+
>>> # Attach statistics function with output collection name
15+
>>> attach_statistics_function(collection, "my_collection_statistics")
1616
>>>
1717
>>> # Add some data
1818
>>> collection.add(
@@ -21,8 +21,8 @@
2121
... metadatas=[{"category": "A"}, {"category": "B"}]
2222
... )
2323
>>>
24-
>>> # Get statistics
25-
>>> stats = get_statistics(collection)
24+
>>> # Get statistics from the named output collection
25+
>>> stats = get_statistics(collection, "my_collection_statistics")
2626
>>> print(stats)
2727
"""
2828

@@ -49,7 +49,7 @@ def get_statistics_fn_name(collection: "Collection") -> str:
4949

5050

5151
def attach_statistics_function(
52-
collection: "Collection", stats_collection_name: Optional[str] = None
52+
collection: "Collection", stats_collection_name: str
5353
) -> "AttachedFunction":
5454
"""Attach statistics collection function to a collection.
5555
@@ -60,20 +60,16 @@ def attach_statistics_function(
6060
Args:
6161
collection: The collection to enable statistics for
6262
stats_collection_name: Name of the collection where statistics will be stored.
63-
If None, defaults to "{collection_name}_statistics".
6463
6564
Returns:
6665
AttachedFunction: The attached statistics function
6766
6867
Example:
69-
>>> attach_statistics_function(collection)
68+
>>> attach_statistics_function(collection, "my_collection_statistics")
7069
>>> collection.add(ids=["id1"], documents=["doc1"], metadatas=[{"key": "value"}])
7170
>>> # Statistics are automatically computed
72-
>>> stats = get_statistics(collection)
71+
>>> stats = get_statistics(collection, "my_collection_statistics")
7372
"""
74-
if stats_collection_name is None:
75-
stats_collection_name = f"{collection.name}_statistics"
76-
7773
return collection.attach_function(
7874
name=get_statistics_fn_name(collection),
7975
function_id="statistics",
@@ -125,7 +121,7 @@ def detach_statistics_function(
125121

126122

127123
def get_statistics(
128-
collection: "Collection", key: Optional[str] = None
124+
collection: "Collection", stats_collection_name: str, key: Optional[str] = None
129125
) -> Dict[str, Any]:
130126
"""Get the current statistics for a collection.
131127
@@ -134,6 +130,7 @@ def get_statistics(
134130
135131
Args:
136132
collection: The collection to get statistics for
133+
stats_collection_name: Name of the statistics collection to read from.
137134
key: Optional metadata key to filter statistics for. If provided,
138135
only returns statistics for that specific key.
139136
@@ -154,14 +151,14 @@ def get_statistics(
154151
}
155152
156153
Example:
157-
>>> attach_statistics_function(collection)
154+
>>> attach_statistics_function(collection, "my_collection_statistics")
158155
>>> collection.add(
159156
... ids=["id1", "id2"],
160157
... documents=["doc1", "doc2"],
161158
... metadatas=[{"category": "A", "score": 10}, {"category": "B", "score": 10}]
162159
... )
163160
>>> # Wait for statistics to be computed
164-
>>> stats = get_statistics(collection)
161+
>>> stats = get_statistics(collection, "my_collection_statistics")
165162
>>> print(stats)
166163
{
167164
"statistics": {
@@ -181,11 +178,9 @@ def get_statistics(
181178
# Import here to avoid circular dependency
182179
from chromadb.api.models.Collection import Collection
183180

184-
af = get_statistics_fn(collection)
185-
186181
# Get the statistics output collection model from the server
187182
stats_collection_model = collection._client.get_collection(
188-
name=af.output_collection,
183+
name=stats_collection_name,
189184
tenant=collection.tenant,
190185
database=collection.database,
191186
)

0 commit comments

Comments
 (0)