@@ -719,6 +719,7 @@ class BFIndex {
719719 int dim;
720720 bool index_inited;
721721 bool normalize;
722+ int num_threads_default;
722723
723724 hnswlib::labeltype cur_l;
724725 hnswlib::BruteforceSearch<dist_t >* alg;
@@ -739,6 +740,8 @@ class BFIndex {
739740 }
740741 alg = NULL ;
741742 index_inited = false ;
743+
744+ num_threads_default = std::thread::hardware_concurrency ();
742745 }
743746
744747
@@ -749,6 +752,21 @@ class BFIndex {
749752 }
750753
751754
755+ size_t getMaxElements () const {
756+ return alg->maxelements_ ;
757+ }
758+
759+
760+ size_t getCurrentCount () const {
761+ return alg->cur_element_count ;
762+ }
763+
764+
765+ void set_num_threads (int num_threads) {
766+ this ->num_threads_default = num_threads;
767+ }
768+
769+
752770 void init_new_index (const size_t maxElements) {
753771 if (alg) {
754772 throw std::runtime_error (" The index is already initiated." );
@@ -820,15 +838,19 @@ class BFIndex {
820838 py::object knnQuery_return_numpy (
821839 py::object input,
822840 size_t k = 1 ,
841+ int num_threads = -1 ,
823842 const std::function<bool (hnswlib::labeltype)>& filter = nullptr) {
824843 py::array_t < dist_t , py::array::c_style | py::array::forcecast > items (input);
825844 auto buffer = items.request ();
826845 hnswlib::labeltype *data_numpy_l;
827846 dist_t *data_numpy_d;
828847 size_t rows, features;
848+
849+ if (num_threads <= 0 )
850+ num_threads = num_threads_default;
851+
829852 {
830853 py::gil_scoped_release l;
831-
832854 get_input_array_shapes (buffer, &rows, &features);
833855
834856 data_numpy_l = new hnswlib::labeltype[rows * k];
@@ -837,16 +859,16 @@ class BFIndex {
837859 CustomFilterFunctor idFilter (filter);
838860 CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr ;
839861
840- for ( size_t row = 0 ; row < rows; row++ ) {
862+ ParallelFor ( 0 , rows, num_threads, [&]( size_t row, size_t threadId ) {
841863 std::priority_queue<std::pair<dist_t , hnswlib::labeltype >> result = alg->searchKnn (
842- (void *) items.data (row), k, p_idFilter);
864+ (void *) items.data (row), k, p_idFilter);
843865 for (int i = k - 1 ; i >= 0 ; i--) {
844- auto & result_tuple = result.top ();
866+ auto & result_tuple = result.top ();
845867 data_numpy_d[row * k + i] = result_tuple.first ;
846868 data_numpy_l[row * k + i] = result_tuple.second ;
847869 result.pop ();
848870 }
849- }
871+ });
850872 }
851873
852874 py::capsule free_when_done_l (data_numpy_l, [](void *f) {
@@ -957,13 +979,22 @@ PYBIND11_PLUGIN(hnswlib) {
957979 py::class_<BFIndex<float >>(m, " BFIndex" )
958980 .def (py::init<const std::string &, const int >(), py::arg (" space" ), py::arg (" dim" ))
959981 .def (" init_index" , &BFIndex<float >::init_new_index, py::arg (" max_elements" ))
960- .def (" knn_query" , &BFIndex<float >::knnQuery_return_numpy, py::arg (" data" ), py::arg (" k" ) = 1 , py::arg (" filter" ) = py::none ())
982+ .def (" knn_query" ,
983+ &BFIndex<float >::knnQuery_return_numpy,
984+ py::arg (" data" ),
985+ py::arg (" k" ) = 1 ,
986+ py::arg (" num_threads" ) = -1 ,
987+ py::arg (" filter" ) = py::none ())
961988 .def (" add_items" , &BFIndex<float >::addItems, py::arg (" data" ), py::arg (" ids" ) = py::none ())
962989 .def (" delete_vector" , &BFIndex<float >::deleteVector, py::arg (" label" ))
990+ .def (" set_num_threads" , &BFIndex<float >::set_num_threads, py::arg (" num_threads" ))
963991 .def (" save_index" , &BFIndex<float >::saveIndex, py::arg (" path_to_index" ))
964992 .def (" load_index" , &BFIndex<float >::loadIndex, py::arg (" path_to_index" ), py::arg (" max_elements" ) = 0 )
965993 .def (" __repr__" , [](const BFIndex<float > &a) {
966994 return " <hnswlib.BFIndex(space='" + a.space_name + " ', dim=" +std::to_string (a.dim )+" )>" ;
967- });
995+ })
996+ .def (" get_max_elements" , &BFIndex<float >::getMaxElements)
997+ .def (" get_current_count" , &BFIndex<float >::getCurrentCount)
998+ .def_readwrite (" num_threads" , &BFIndex<float >::num_threads_default);
968999 return m.ptr ();
9691000}
0 commit comments