Skip to content

Commit 4f7b192

Browse files
committed
get_items return numpy array
1 parent f6d170c commit 4f7b192

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ For other spaces use the nmslib library https://github.com/nmslib/nmslib.
7979

8080
* `set_num_threads(num_threads)` set the default number of cpu threads used during data insertion/querying.
8181

82-
* `get_items(ids)` - returns a numpy array (shape:`N*dim`) of vectors that have integer identifiers specified in `ids` numpy vector (shape:`N`). Note that for cosine similarity it currently returns **normalized** vectors.
82+
* `get_items(ids, return_type = 'numpy')` - returns a numpy array (shape:`N*dim`) of vectors that have integer identifiers specified in `ids` numpy vector (shape:`N`) if `return_type` is `list` return list of lists. Note that for cosine similarity it currently returns **normalized** vectors.
8383

8484
* `get_ids_list()` - returns a list of all elements' ids.
8585

python_bindings/bindings.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,11 @@ class Index {
304304
}
305305

306306

307-
std::vector<std::vector<data_t>> getDataReturnList(py::object ids_ = py::none()) {
307+
py::object getData(py::object ids_ = py::none(), std::string return_type = "numpy") {
308+
std::vector<std::string> return_types{"numpy", "list"};
309+
if (std::find(std::begin(return_types), std::end(return_types), return_type) == std::end(return_types)) {
310+
throw std::invalid_argument("return_type should be \"numpy\" or \"list\"");
311+
}
308312
std::vector<size_t> ids;
309313
if (!ids_.is_none()) {
310314
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
@@ -325,7 +329,12 @@ class Index {
325329
for (auto id : ids) {
326330
data.push_back(appr_alg->template getDataByLabel<data_t>(id));
327331
}
328-
return data;
332+
if (return_type == "list") {
333+
return py::cast(data);
334+
}
335+
if (return_type == "numpy") {
336+
return py::array_t< data_t, py::array::c_style | py::array::forcecast >(py::cast(data));
337+
}
329338
}
330339

331340

@@ -925,7 +934,7 @@ PYBIND11_PLUGIN(hnswlib) {
925934
py::arg("ids") = py::none(),
926935
py::arg("num_threads") = -1,
927936
py::arg("replace_deleted") = false)
928-
.def("get_items", &Index<float, float>::getDataReturnList, py::arg("ids") = py::none())
937+
.def("get_items", &Index<float>::getData, py::arg("ids") = py::none(), py::arg("return_type") = "numpy")
929938
.def("get_ids_list", &Index<float>::getIdsList)
930939
.def("set_ef", &Index<float>::set_ef, py::arg("ef"))
931940
.def("set_num_threads", &Index<float>::set_num_threads, py::arg("num_threads"))

0 commit comments

Comments
 (0)