|
16 | 16 | #include <thread> |
17 | 17 | #include <VecSim/algorithms/hnsw/hnsw_single.h> |
18 | 18 | #include <VecSim/algorithms/brute_force/brute_force_single.h> |
19 | | -#include "tiered_index_mock.h" |
| 19 | +#include "mock_thread_pool.h" |
20 | 20 |
|
21 | 21 | namespace py = pybind11; |
22 | | -using namespace tiered_index_mock; |
23 | 22 |
|
24 | 23 | // Helper function that iterates over query results and wraps them in a Python numpy object - |
25 | 24 | // a tuple of two 2D arrays: (labels, distances) |
@@ -366,92 +365,64 @@ class PyHNSWLibIndex : public PyVecSimIndex { |
366 | 365 | }; |
367 | 366 |
|
368 | 367 | class PyTieredIndex : public PyVecSimIndex { |
369 | | -private: |
| 368 | +protected: |
| 369 | + tieredIndexMock mock_thread_pool; |
| 370 | + |
370 | 371 | VecSimIndexAbstract<float> *getFlatBuffer() { |
371 | 372 | return reinterpret_cast<VecSimTieredIndex<float, float> *>(this->index.get()) |
372 | | - ->getFlatbufferIndex(); |
| 373 | + ->getFlatBufferIndex(); |
373 | 374 | } |
374 | 375 |
|
375 | | -protected: |
376 | | - JobQueue jobQueue; // External queue that holds the jobs. |
377 | | - IndexExtCtx jobQueueCtx; // External context to be sent to the submit callback. |
378 | | - SubmitCB submitCb; // A callback that submits an array of jobs into a given jobQueue. |
379 | | - size_t flatBufferLimit; // Maximum size allowed for the flat buffer. If flat buffer is full, use |
380 | | - // in-place insertion. |
381 | | - bool run_thread; |
382 | | - std::bitset<MAX_POOL_SIZE> executions_status; |
383 | | - |
384 | | - TieredIndexParams TieredIndexParams_Init() { |
385 | | - TieredIndexParams ret = { |
386 | | - .jobQueue = &this->jobQueue, |
387 | | - .jobQueueCtx = &this->jobQueueCtx, |
388 | | - .submitCb = this->submitCb, |
389 | | - .flatBufferLimit = this->flatBufferLimit, |
| 376 | + TieredIndexParams getTieredIndexParams(size_t buffer_limit) { |
| 377 | + // Create TieredIndexParams using the mock thread pool. |
| 378 | + return TieredIndexParams{ |
| 379 | + .jobQueue = &(this->mock_thread_pool.jobQ), |
| 380 | + .jobQueueCtx = this->mock_thread_pool.ctx, |
| 381 | + .submitCb = tieredIndexMock::submit_callback, |
| 382 | + .flatBufferLimit = buffer_limit, |
390 | 383 | }; |
391 | | - |
392 | | - return ret; |
393 | 384 | } |
394 | 385 |
|
395 | 386 | public: |
396 | | - explicit PyTieredIndex(size_t BufferLimit = 3000000) |
397 | | - : submitCb(submit_callback), flatBufferLimit(BufferLimit), run_thread(true) { |
398 | | - |
399 | | - for (size_t i = 0; i < THREAD_POOL_SIZE; i++) { |
400 | | - ThreadParams params(run_thread, executions_status, i, jobQueue); |
401 | | - thread_pool.emplace_back(thread_main_loop, params); |
402 | | - } |
403 | | - } |
404 | | - |
405 | | - virtual ~PyTieredIndex() = 0; |
| 387 | + explicit PyTieredIndex() { mock_thread_pool.init_threads(); } |
406 | 388 |
|
407 | 389 | void WaitForIndex(size_t waiting_duration = 10) { |
408 | | - bool keep_wating = true; |
409 | | - while (keep_wating) { |
410 | | - std::this_thread::sleep_for(std::chrono::milliseconds(waiting_duration)); |
411 | | - std::unique_lock<std::mutex> lock(queue_guard); |
412 | | - if (jobQueue.empty()) { |
413 | | - while (true) { |
414 | | - if (executions_status.count() == 0) { |
415 | | - keep_wating = false; |
416 | | - break; |
417 | | - } |
418 | | - std::this_thread::sleep_for(std::chrono::milliseconds(waiting_duration)); |
419 | | - } |
420 | | - } |
421 | | - } |
| 390 | + mock_thread_pool.thread_pool_wait(waiting_duration); |
422 | 391 | } |
423 | 392 |
|
424 | 393 | size_t getFlatIndexSize() { return getFlatBuffer()->indexLabelCount(); } |
425 | 394 |
|
426 | | - static size_t GetThreadsNum() { return THREAD_POOL_SIZE; } |
| 395 | + size_t getThreadsNum() { return mock_thread_pool.thread_pool_size; } |
427 | 396 |
|
428 | | - size_t getBufferLimit() { return flatBufferLimit; } |
| 397 | + size_t getBufferLimit() { |
| 398 | + return reinterpret_cast<VecSimTieredIndex<float, float> *>(this->index.get()) |
| 399 | + ->getFlatBufferLimit(); |
| 400 | + } |
429 | 401 | }; |
430 | 402 |
|
431 | | -PyTieredIndex::~PyTieredIndex() { thread_pool_terminate(jobQueue, run_thread); } |
432 | 403 | class PyTiered_HNSWIndex : public PyTieredIndex { |
433 | 404 | public: |
434 | 405 | explicit PyTiered_HNSWIndex(const HNSWParams &hnsw_params, |
435 | | - const TieredHNSWParams &tiered_hnsw_params) { |
| 406 | + const TieredHNSWParams &tiered_hnsw_params, size_t buffer_limit) { |
436 | 407 |
|
437 | 408 | // Create primaryIndexParams and specific params for hnsw tiered index. |
438 | 409 | VecSimParams primary_index_params = {.algo = VecSimAlgo_HNSWLIB, |
439 | 410 | .algoParams = {.hnswParams = HNSWParams{hnsw_params}}}; |
440 | 411 |
|
441 | | - // create TieredIndexParams |
442 | | - TieredIndexParams tiered_params = TieredIndexParams_Init(); |
443 | | - |
| 412 | + auto tiered_params = this->getTieredIndexParams(buffer_limit); |
444 | 413 | tiered_params.primaryIndexParams = &primary_index_params; |
445 | 414 | tiered_params.specificParams.tieredHnswParams = tiered_hnsw_params; |
446 | 415 |
|
447 | | - // create VecSimParams for TieredIndexParams |
| 416 | + // Create VecSimParams for TieredIndexParams |
448 | 417 | VecSimParams params = {.algo = VecSimAlgo_TIERED, |
449 | 418 | .algoParams = {.tieredParams = TieredIndexParams{tiered_params}}}; |
450 | 419 |
|
451 | 420 | this->index = std::shared_ptr<VecSimIndex>(VecSimIndex_New(¶ms), VecSimIndex_Free); |
| 421 | + |
452 | 422 | // Set the created tiered index in the index external context. |
453 | | - this->jobQueueCtx.index_strong_ref = this->index; |
| 423 | + this->mock_thread_pool.ctx->index_strong_ref = this->index; |
454 | 424 | } |
| 425 | + |
455 | 426 | size_t HNSWLabelCount() { |
456 | 427 | return this->index->info().tieredInfo.backendCommonInfo.indexLabelCount; |
457 | 428 | } |
@@ -568,17 +539,17 @@ PYBIND11_MODULE(VecSim, m) { |
568 | 539 | py::arg("radius"), py::arg("query_param") = nullptr, py::arg("num_threads") = -1); |
569 | 540 |
|
570 | 541 | py::class_<PyTieredIndex, PyVecSimIndex>(m, "TieredIndex") |
571 | | - .def("wait_for_index", &PyTiered_HNSWIndex::WaitForIndex, py::arg("waiting_duration") = 10) |
572 | | - .def("get_curr_bf_size", &PyTiered_HNSWIndex::getFlatIndexSize) |
573 | | - .def("get_buffer_limit", &PyTiered_HNSWIndex::getBufferLimit) |
574 | | - .def_static("get_threads_num", &PyTieredIndex::GetThreadsNum); |
| 542 | + .def("wait_for_index", &PyTieredIndex::WaitForIndex, py::arg("waiting_duration") = 10) |
| 543 | + .def("get_curr_bf_size", &PyTieredIndex::getFlatIndexSize) |
| 544 | + .def("get_buffer_limit", &PyTieredIndex::getBufferLimit) |
| 545 | + .def("get_threads_num", &PyTieredIndex::getThreadsNum); |
575 | 546 |
|
576 | 547 | py::class_<PyTiered_HNSWIndex, PyTieredIndex>(m, "Tiered_HNSWIndex") |
577 | | - .def( |
578 | | - py::init([](const HNSWParams &hnsw_params, const TieredHNSWParams &tiered_hnsw_params) { |
579 | | - return new PyTiered_HNSWIndex(hnsw_params, tiered_hnsw_params); |
580 | | - }), |
581 | | - py::arg("hnsw_params"), py::arg("tiered_hnsw_params")) |
| 548 | + .def(py::init([](const HNSWParams &hnsw_params, const TieredHNSWParams &tiered_hnsw_params, |
| 549 | + size_t flat_buffer_size = DEFAULT_BLOCK_SIZE) { |
| 550 | + return new PyTiered_HNSWIndex(hnsw_params, tiered_hnsw_params, flat_buffer_size); |
| 551 | + }), |
| 552 | + py::arg("hnsw_params"), py::arg("tiered_hnsw_params"), py::arg("flat_buffer_size")) |
582 | 553 | .def("hnsw_label_count", &PyTiered_HNSWIndex::HNSWLabelCount); |
583 | 554 |
|
584 | 555 | py::class_<PyBFIndex, PyVecSimIndex>(m, "BFIndex") |
|
0 commit comments