Skip to content

Commit d05aed4

Browse files
committed
[ENH]: Allow specifying multiple filter keys in get_statistics
1 parent cb626b1 commit d05aed4

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

chromadb/test/distributed/test_statistics_wrapper.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -234,7 +234,7 @@ def test_statistics_wrapper_key_filter(basic_http_client: System) -> None:
234234

235235
# Get statistics filtered by "category" key only
236236
category_stats = get_statistics(
237-
collection, "key_filter_test_statistics", key="category"
237+
collection, "key_filter_test_statistics", keys=["category"]
238238
)
239239
assert "category" in category_stats["statistics"]
240240
assert "score" not in category_stats["statistics"]
@@ -246,7 +246,9 @@ def test_statistics_wrapper_key_filter(basic_http_client: System) -> None:
246246
assert category_stats["summary"]["total_count"] == 3
247247

248248
# Get statistics filtered by "score" key only
249-
score_stats = get_statistics(collection, "key_filter_test_statistics", key="score")
249+
score_stats = get_statistics(
250+
collection, "key_filter_test_statistics", keys=["score"]
251+
)
250252
assert "score" in score_stats["statistics"]
251253
assert "category" not in score_stats["statistics"]
252254
assert "active" not in score_stats["statistics"]

chromadb/utils/statistics.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
>>> print(stats)
2727
"""
2828

29-
from typing import TYPE_CHECKING, Optional, Dict, Any, cast
29+
from typing import TYPE_CHECKING, Optional, Dict, Any, List, cast
3030
from collections import defaultdict
3131

3232
from chromadb.api.types import Where
@@ -121,7 +121,9 @@ def detach_statistics_function(
121121

122122

123123
def get_statistics(
124-
collection: "Collection", stats_collection_name: str, key: Optional[str] = None
124+
collection: "Collection",
125+
stats_collection_name: str,
126+
keys: Optional[List[str]] = None,
125127
) -> Dict[str, Any]:
126128
"""Get the current statistics for a collection.
127129
@@ -131,8 +133,8 @@ def get_statistics(
131133
Args:
132134
collection: The collection to get statistics for
133135
stats_collection_name: Name of the statistics collection to read from.
134-
key: Optional metadata key to filter statistics for. If provided,
135-
only returns statistics for that specific key.
136+
keys: Optional list of metadata keys to filter statistics for. If provided,
137+
only returns statistics for those specific keys.
136138
137139
Returns:
138140
Dict[str, Any]: A dictionary with the structure:
@@ -198,11 +200,9 @@ def get_statistics(
198200
summary: Dict[str, Any] = {}
199201

200202
offset = 0
201-
# When filtering by key, also include "summary" entries to get total_count
203+
# When filtering by keys, also include "summary" entries to get total_count
202204
where_filter: Optional[Where] = (
203-
cast(Where, {"$or": [{"key": key}, {"key": "summary"}]})
204-
if key is not None
205-
else None
205+
cast(Where, {"key": {"$in": keys + ["summary"]}}) if keys is not None else None
206206
)
207207

208208
while True:

0 commit comments

Comments
 (0)