Skip to content

Commit 39b210d

Browse files
authored
Merge pull request #494 from dyashuni/get_items_numpy
Fix get_items
2 parents f6d170c + db19931 commit 39b210d

File tree

4 files changed

+23
-8
lines changed

4 files changed

+23
-8
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ For other spaces use the nmslib library https://github.com/nmslib/nmslib.
7979

8080
* `set_num_threads(num_threads)` set the default number of cpu threads used during data insertion/querying.
8181

82-
* `get_items(ids)` - returns a numpy array (shape:`N*dim`) of vectors that have integer identifiers specified in `ids` numpy vector (shape:`N`). Note that for cosine similarity it currently returns **normalized** vectors.
82+
* `get_items(ids, return_type = 'numpy')` - returns a numpy array (shape:`N*dim`) of vectors that have integer identifiers specified in `ids` numpy vector (shape:`N`) if `return_type` is `list` return list of lists. Note that for cosine similarity it currently returns **normalized** vectors.
8383

8484
* `get_ids_list()` - returns a list of all elements' ids.
8585

python_bindings/bindings.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,11 @@ class Index {
304304
}
305305

306306

307-
std::vector<std::vector<data_t>> getDataReturnList(py::object ids_ = py::none()) {
307+
py::object getData(py::object ids_ = py::none(), std::string return_type = "numpy") {
308+
std::vector<std::string> return_types{"numpy", "list"};
309+
if (std::find(std::begin(return_types), std::end(return_types), return_type) == std::end(return_types)) {
310+
throw std::invalid_argument("return_type should be \"numpy\" or \"list\"");
311+
}
308312
std::vector<size_t> ids;
309313
if (!ids_.is_none()) {
310314
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
@@ -325,7 +329,12 @@ class Index {
325329
for (auto id : ids) {
326330
data.push_back(appr_alg->template getDataByLabel<data_t>(id));
327331
}
328-
return data;
332+
if (return_type == "list") {
333+
return py::cast(data);
334+
}
335+
if (return_type == "numpy") {
336+
return py::array_t< data_t, py::array::c_style | py::array::forcecast >(py::cast(data));
337+
}
329338
}
330339

331340

@@ -925,7 +934,7 @@ PYBIND11_PLUGIN(hnswlib) {
925934
py::arg("ids") = py::none(),
926935
py::arg("num_threads") = -1,
927936
py::arg("replace_deleted") = false)
928-
.def("get_items", &Index<float, float>::getDataReturnList, py::arg("ids") = py::none())
937+
.def("get_items", &Index<float>::getData, py::arg("ids") = py::none(), py::arg("return_type") = "numpy")
929938
.def("get_ids_list", &Index<float>::getIdsList)
930939
.def("set_ef", &Index<float>::set_ef, py::arg("ef"))
931940
.def("set_num_threads", &Index<float>::set_num_threads, py::arg("num_threads"))

tests/python/bindings_test_getdata.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,11 @@ def testGettingItems(self):
4545
self.assertRaises(ValueError, lambda: p.get_items(labels[0]))
4646

4747
# After adding them, all labels should be retrievable
48-
returned_items = p.get_items(labels)
49-
self.assertSequenceEqual(data.tolist(), returned_items)
48+
returned_items_np = p.get_items(labels)
49+
self.assertTrue((data == returned_items_np).all())
50+
51+
# check returned type of get_items
52+
self.assertTrue(isinstance(returned_items_np, np.ndarray))
53+
returned_items_list = p.get_items(labels, return_type="list")
54+
self.assertTrue(isinstance(returned_items_list, list))
55+
self.assertTrue(isinstance(returned_items_list[0], list))

tests/python/bindings_test_replace.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,10 @@ def testRandomSelf(self):
9494
remaining_data = comb_data[remaining_labels_list]
9595

9696
returned_items = hnsw_index.get_items(remaining_labels_list)
97-
self.assertSequenceEqual(remaining_data.tolist(), returned_items)
97+
self.assertTrue((remaining_data == returned_items).all())
9898

9999
returned_items = hnsw_index.get_items(labels3_tr)
100-
self.assertSequenceEqual(data3_tr.tolist(), returned_items)
100+
self.assertTrue((data3_tr == returned_items).all())
101101

102102
# Check index serialization
103103
# Delete batch 3

0 commit comments

Comments
 (0)