Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ann_benchmarks/algorithms/elasticsearch/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ path.logs: /usr/share/elasticsearch/logs\n\
bootstrap.memory_lock: true\n\
thread_pool.write.size: 1\n\
thread_pool.search.size: 1\n\
thread_pool.search.queue_size: 1\n\
thread_pool.search.queue_size: 1000\n\
xpack.security.enabled: false\n\
' > config/elasticsearch.yml

Expand Down
27 changes: 25 additions & 2 deletions ann_benchmarks/algorithms/elasticsearch/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,31 @@ def query(self, q, n):
)
return [int(h["fields"]["id"][0]) for h in res["hits"]["hits"]]

def batch_query(self, X, n):
    """Answer every query in X one at a time and cache the answers.

    Each element of X is passed to ``self.query`` together with ``n``;
    the per-query results are stored, in order, on ``self.batch_res``.
    Nothing is returned.
    """
    collected = []
    for vector in X:
        collected.append(self.query(vector, n))
    self.batch_res = collected
def sub_batch_query(self, X, n):
    """Run a single multi-search request covering every vector in X.

    Args:
        X: Sequence of query vectors. Each vector must provide
            ``tolist()`` (e.g. a NumPy array) so it can be serialised.
        n: Number of neighbours requested for each query.

    Returns:
        A list with one entry per query, in the order of X; each entry
        is the list of integer ids read from the hits' ``_source``.
    """
    # Nothing to search: avoid sending an empty request with a zero timeout.
    if len(X) == 0:
        return []

    header = {"index": self.index_name}
    msearch_requests = []
    for q in X:
        body = {
            "knn": {
                "field": "vec",
                "query_vector": q.tolist(),
                "k": n,
                "num_candidates": self.num_candidates,
            },
            # The search API returns at most ``size`` hits and defaults
            # to 10, so without this any n > 10 would be truncated.
            "size": n,
        }
        # The multi-search body alternates header and search body entries.
        msearch_requests.extend([header, body])

    # The timeout grows with the batch: 10 per query vector.
    response = self.client.msearch(
        body=msearch_requests,
        request_timeout=10 * len(X),
    )

    # An entry without "hits" (presumably a failed sub-search) yields an
    # empty result list instead of raising.
    return [
        [int(hit["_source"]["id"]) for hit in resp.get("hits", {}).get("hits", [])]
        for resp in response["responses"]
    ]

def batch_query(self, X, n, sub_batch_size=1000):
    """Answer all queries in X by sending them in fixed-size sub-batches.

    Args:
        X: Sliceable sequence of query vectors.
        n: Number of neighbours requested for each query.
        sub_batch_size: Maximum number of queries handed to
            ``sub_batch_query`` at once. Must be at least 1.

    Returns:
        The flat list of per-query results, in the order of X. The same
        list is stored on ``self.batch_res``.

    Raises:
        ValueError: If ``sub_batch_size`` is smaller than 1.
    """
    # A negative step would make the range empty and silently drop every
    # query; zero would fail inside range() with an unrelated message.
    if sub_batch_size < 1:
        raise ValueError(
            f"sub_batch_size must be a positive integer, got {sub_batch_size!r}"
        )

    results = []
    for start in range(0, len(X), sub_batch_size):
        chunk = X[start:start + sub_batch_size]
        results.extend(self.sub_batch_query(chunk, n))

    # Assign once, so self.batch_res never holds a nested intermediate.
    self.batch_res = results
    return self.batch_res

def get_batch_results(self):
    """Return the results that ``batch_query`` stored on ``self.batch_res``."""
    cached = self.batch_res
    return cached
Expand Down