diff --git a/CMakeLists.txt b/CMakeLists.txt
index 43e853e..5723748 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -86,6 +86,7 @@ add_extension_if_enabled_and_skip_32bit("delta")
 add_extension_if_enabled_and_skip_32bit("iceberg")
 add_extension_if_enabled_and_skip_32bit("azure")
 add_extension_if_enabled_and_skip_32bit("unity_catalog")
+add_extension_if_enabled_and_skip_32bit("lance")
 add_extension_if_enabled("json")
 add_extension_if_enabled("fts")
 add_extension_if_enabled("vector")
diff --git a/extension_config.cmake b/extension_config.cmake
index b182655..437a58b 100644
--- a/extension_config.cmake
+++ b/extension_config.cmake
@@ -1,4 +1,4 @@
-set(EXTENSION_LIST adbc azure delta duckdb fts httpfs iceberg json llm postgres sqlite unity_catalog vector neo4j algo)
+set(EXTENSION_LIST adbc azure delta duckdb fts httpfs iceberg json lance llm postgres sqlite unity_catalog vector neo4j algo)
 
 #set(EXTENSION_STATIC_LINK_LIST fts)
 foreach(extension IN LISTS EXTENSION_STATIC_LINK_LIST)
diff --git a/lance/CMakeLists.txt b/lance/CMakeLists.txt
new file mode 100644
index 0000000..92da0ad
--- /dev/null
+++ b/lance/CMakeLists.txt
@@ -0,0 +1,65 @@
+cmake_minimum_required(VERSION 3.22)
+
+# ── lance-c dependency ────────────────────────────────────────────────────────
+# Default location of lance-c: ${PROJECT_SOURCE_DIR}/../../lancedb/lance-c (override with LANCE_C_ROOT).
+# We use add_subdirectory so its Corrosion / prebuilt CMake logic runs in-tree.
+set(LANCE_C_ROOT "${PROJECT_SOURCE_DIR}/../../lancedb/lance-c"
+    CACHE PATH "Path to the lance-c source tree")
+
+if(NOT TARGET LanceC::lance_c) # NOTE(review): guard tests LanceC::lance_c but the link step below uses LanceC::lance_c_static; confirm lance-c defines both
+    if(NOT EXISTS "${LANCE_C_ROOT}/CMakeLists.txt")
+        message(FATAL_ERROR
+            "lance-c not found at ${LANCE_C_ROOT}. "
+            "Set LANCE_C_ROOT to the lance-c repository path.")
+    endif()
+    # Use prebuilt if LANCE_C_PREBUILT_DIR is set, otherwise build from source
+    if(LANCE_C_PREBUILT_DIR)
+        set(LANCE_C_USE_PREBUILT ON CACHE BOOL "" FORCE)
+    endif()
+    add_subdirectory(${LANCE_C_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/lance-c EXCLUDE_FROM_ALL)
+endif()
+
+# ── Extension source ──────────────────────────────────────────────────────────
+include_directories( # directory-scoped; the same four paths are repeated per-target below
+    ${PROJECT_SOURCE_DIR}/src/include
+    ${CMAKE_BINARY_DIR}/src/include
+    src/include
+    ${LANCE_C_ROOT}/include
+)
+
+set(LANCE_EXTENSION_OBJECT_FILES
+    $ # NOTE(review): looks like a truncated generator expression (TARGET_OBJECTS-style); confirm against the real file
+)
+
+add_library(lance_extension_objects OBJECT
+    src/lance_extension.cpp
+    src/lance_node_table.cpp
+    src/lance_rel_table.cpp
+    src/lance_functions.cpp
+)
+
+target_include_directories(lance_extension_objects PRIVATE
+    ${PROJECT_SOURCE_DIR}/src/include
+    ${CMAKE_BINARY_DIR}/src/include
+    src/include
+    ${LANCE_C_ROOT}/include
+)
+
+# ── Build the extension library ───────────────────────────────────────────────
+build_extension_lib(${BUILD_STATIC_EXTENSION} "lance")
+
+target_link_libraries(lbug_${EXTENSION_LIB_NAME}_extension
+    PRIVATE
+    LanceC::lance_c_static
+)
+
+# Platform-specific link dependencies for the Rust-backed lance-c library
+if(UNIX AND NOT APPLE) # any Unix other than macOS
+    target_link_libraries(lbug_${EXTENSION_LIB_NAME}_extension
+        PRIVATE pthread dl m)
+endif()
+
+# ── Tests ─────────────────────────────────────────────────────────────────────
+if(BUILD_EXTENSION_TESTS)
+    add_subdirectory(test)
+endif()
diff --git a/lance/src/include/lance_extension.h b/lance/src/include/lance_extension.h
new file mode 100644
index 0000000..33afae9
--- /dev/null
+++ b/lance/src/include/lance_extension.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include "extension/extension.h"
+
+namespace lbug {
+namespace lance_extension {
+
+class LanceExtension final : public extension::Extension { // entry point of the Lance extension
+public:
+    static constexpr char EXTENSION_NAME[] = "LANCE"; // returned by the exported name() hook in lance_extension.cpp
+
+    static void load(main::ClientContext* context); // registers the Lance table factories and the three search table functions
+};
+
+} // namespace 
lance_extension
+} // namespace lbug
diff --git a/lance/src/include/lance_functions.h b/lance/src/include/lance_functions.h
new file mode 100644
index 0000000..03e7e91
--- /dev/null
+++ b/lance/src/include/lance_functions.h
@@ -0,0 +1,32 @@
+#pragma once
+
+#include "function/table/table_function.h"
+
+namespace lbug {
+namespace lance_extension {
+
+using function::function_set;
+
+/// LANCE_VECTOR_SEARCH(dataset_path, column, query_vector, k [, metric [, nprobes]])
+/// Returns nearest-neighbour rows together with a '_distance' column.
+struct LanceVectorSearchFunction {
+    static constexpr char name[] = "LANCE_VECTOR_SEARCH";
+    static function_set getFunctionSet();
+};
+
+/// LANCE_FTS(dataset_path, query [, columns [, max_fuzzy_distance]])
+/// Returns full-text search result rows together with a '_score' column.
+struct LanceFTSFunction {
+    static constexpr char name[] = "LANCE_FTS";
+    static function_set getFunctionSet();
+};
+
+/// LANCE_HYBRID_SEARCH(dataset_path, column, query_vector, k, fts_query)
+/// Returns rows matching both vector and full-text criteria with fusion scoring.
+struct LanceHybridSearchFunction {
+    static constexpr char name[] = "LANCE_HYBRID_SEARCH";
+    static function_set getFunctionSet();
+};
+
+} // namespace lance_extension
+} // namespace lbug
diff --git a/lance/src/include/lance_node_table.h b/lance/src/include/lance_node_table.h
new file mode 100644
index 0000000..2b5d90e
--- /dev/null
+++ b/lance/src/include/lance_node_table.h
@@ -0,0 +1,165 @@
+#pragma once
+
+#include
+#include
+#include
+
+#include "catalog/catalog_entry/node_table_catalog_entry.h"
+#include "common/arrow/arrow.h"
+#include "common/exception/runtime.h"
+#include "common/types/internal_id_util.h"
+#include "storage/table/columnar_node_table_base.h"
+
+namespace lbug {
+namespace lance_extension {
+
+/// Owns a single Arrow RecordBatch obtained from an ArrowArrayStream.
+/// The ArrowArray is released in the destructor; ArrowSchema is NOT owned
+/// (it comes from the shared stream schema and is managed by the shared state).
+struct LanceBatchData {
+    ArrowSchema schema;
+    ArrowArray array;
+    uint64_t length = 0;
+
+    LanceBatchData() {
+        std::memset(&schema, 0, sizeof(schema));
+        std::memset(&array, 0, sizeof(array));
+    }
+    ~LanceBatchData() {
+        if (array.release) array.release(&array);
+        // schema.release is intentionally left null (owned by shared state)
+    }
+
+    LanceBatchData(const LanceBatchData&) = delete;
+    LanceBatchData& operator=(const LanceBatchData&) = delete;
+    LanceBatchData(LanceBatchData&&) = delete;
+    LanceBatchData& operator=(LanceBatchData&&) = delete;
+};
+
+/// Per-thread scan state for LanceNodeTable.
+struct LanceNodeTableScanState final : storage::ColumnarNodeTableScanState {
+    std::shared_ptr currentBatch;        // batch the current morsel was carved from
+    uint64_t batchStartGlobalOffset = 0; // global row offset of currentBatch's first row
+    uint64_t morselStart = 0;            // first row of the morsel within currentBatch
+    uint64_t morselEnd = 0;              // one past the morsel's last row within currentBatch
+
+    LanceNodeTableScanState(storage::MemoryManager& mm, common::ValueVector* nodeIDVector,
+        std::vector outputVectors,
+        std::shared_ptr outChunkState)
+        : storage::ColumnarNodeTableScanState{mm, nodeIDVector, std::move(outputVectors),
+              std::move(outChunkState)} {}
+};
+
+/// Shared scan coordination state for LanceNodeTable.
+///
+/// Owns the ArrowArrayStream being scanned together with the schema fetched from it, and
+/// hands out morsels of at most `morselSize` rows from the current batch under `mtx`.
+struct LanceNodeTableScanSharedState final : storage::ColumnarNodeTableScanSharedState {
+    explicit LanceNodeTableScanSharedState(size_t morselSize) : morselSize(morselSize) {
+        std::memset(&stream_, 0, sizeof(stream_));
+        // Zeroed so the destructor's release check is well-defined before any fetch.
+        std::memset(&streamSchema_, 0, sizeof(streamSchema_));
+    }
+
+    ~LanceNodeTableScanSharedState() override {
+        // Batches only carry a non-owning copy of the schema (release cleared), so the
+        // schema obtained from get_schema() has to be released here.
+        if (streamSchemaFetched && streamSchema_.release) streamSchema_.release(&streamSchema_);
+        if (stream_.release) stream_.release(&stream_);
+    }
+
+    /// Releases the current stream (if any), adopts `newStream`, and rewinds the batch/morsel
+    /// bookkeeping. The stream schema is fetched once and reused by later calls.
+    void reset(ArrowArrayStream newStream);
+
+    /// Assigns the next morsel of the current batch to `scanState`, reading further batches
+    /// as needed. Returns false when the stream is exhausted or `scanState` is not a
+    /// LanceNodeTableScanState.
+    bool getNextMorsel(storage::ColumnarNodeTableScanState* scanState) override;
+
+private:
+    /// Pulls the next batch from `stream_`; the caller must hold `mtx`.
+    bool readNextBatch();
+
+    std::mutex mtx;              // guards the stream and the batch/morsel cursor
+    ArrowArrayStream stream_;    // owned; released in the destructor and on reset()
+    bool streamExhausted = true; // stays true until reset() installs a stream
+
+    std::shared_ptr currentBatch;
+    uint64_t currentBatchGlobalOffset = 0; // sum of the lengths of all earlier batches
+    uint64_t currentMorselStart = 0;       // next unassigned row within currentBatch
+    const size_t morselSize;
+
+    bool streamSchemaFetched = false;
+    ArrowSchema streamSchema_; // owned once fetched; batches receive non-owning copies
+
+    LanceNodeTableScanSharedState(const LanceNodeTableScanSharedState&) = delete;
+    LanceNodeTableScanSharedState& operator=(const LanceNodeTableScanSharedState&) = delete;
+};
+
+/// A node table backed by a Lance dataset.
+class LanceNodeTable final : public storage::ColumnarNodeTableBase {
+public:
+    /// Opens the dataset named by the catalog entry's storage path and caches its row count
+    /// and Arrow schema. Throws common::RuntimeException if the path is empty or the open fails.
+    LanceNodeTable(const storage::StorageManager* storageManager,
+        const catalog::NodeTableCatalogEntry* nodeTableEntry,
+        storage::MemoryManager* memoryManager, main::ClientContext* context);
+
+    /// Releases the Arrow schema cached by the constructor, which this table owns.
+    ~LanceNodeTable() override {
+        if (schemaCached && cachedSchema_.release) cachedSchema_.release(&cachedSchema_);
+    }
+
+    void initializeScanCoordination(const transaction::Transaction* transaction) override;
+
+    void initScanState(transaction::Transaction* transaction, storage::TableScanState& scanState,
+        bool resetCachedBoundNodeSelVec = true) const override;
+
+    bool scanInternal(transaction::Transaction* transaction,
+        storage::TableScanState& scanState) override;
+
+    bool requiresExplicitScanInit() const override { return true; }
+    bool usesMorselScan() const override { return true; }
+    size_t getNumScanMorsels(const transaction::Transaction* transaction) const override;
+
+    std::unique_ptr createScanState(common::ValueVector* nodeIDVector,
+        const std::vector& outVectors,
+        storage::MemoryManager* memoryManager) const override;
+
+    bool isVisible(const transaction::Transaction* transaction,
+        common::offset_t offset) const override;
+    bool isVisibleNoLock(const transaction::Transaction* transaction,
+        common::offset_t offset) const override;
+
+    bool lookupPK(const transaction::Transaction* transaction, common::ValueVector* keyVector,
+        uint64_t vectorPos, common::offset_t& result) const override;
+
+    const std::string& getLanceDatasetPath() const { return datasetPath; }
+
+protected:
+    std::string getColumnarFormatName() const override { return "lance"; }
+    common::node_group_idx_t getNumBatches(
+        const transaction::Transaction* transaction) const override;
+    common::row_idx_t getTotalRowCount(const transaction::Transaction* transaction) const override;
+
+private:
+    std::vector getOutputToLanceColumnIdx(
+        const std::vector& columnIDs) const;
+
+    void copyLanceMorselToOutputVectors(const LanceBatchData& batch, uint64_t morselStart,
+        uint64_t numRows, const std::vector& outputVectors,
+        const std::vector& outputToLanceColumnIdx) const;
+
+private:
+    std::string datasetPath; // resolved dataset path, set by the constructor
+    mutable uint64_t cachedTotalRows = common::INVALID_ROW_IDX;
+    uint32_t numLanceColumns = 0;
+    ArrowSchema cachedSchema_{}; // owned once schemaCached is true; released in the destructor
+    bool schemaCached = false;
+
+    static constexpr size_t kDefaultMorselSize = 2048;
+};
+
+} // namespace lance_extension
+} // namespace lbug
diff --git a/lance/src/include/lance_rel_table.h b/lance/src/include/lance_rel_table.h
new file mode 100644
index 0000000..9387238
--- /dev/null
+++ b/lance/src/include/lance_rel_table.h
@@ -0,0 +1,90 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include "catalog/catalog_entry/rel_group_catalog_entry.h"
+#include "common/arrow/arrow.h"
+#include "common/exception/runtime.h"
+#include "common/types/internal_id_util.h"
+#include "storage/table/columnar_rel_table_base.h"
+
+namespace lbug {
+namespace lance_extension {
+
+struct LanceBatchData;
+
+/// Per-thread scan state for 
LanceRelTable. Each instance carries its own Arrow stream and stream schema.
+struct LanceRelTableScanState final : storage::RelTableScanState {
+    std::shared_ptr cachedBatchData; // NOTE(review): presumably the batch currently being scanned; confirm in lance_rel_table.cpp
+    uint64_t currentBatchStartOffset = 0;
+    uint64_t currentLocalRowIdx = 0;
+
+    std::unordered_map boundNodeOffsets; // replaced wholesale through reset()
+
+    ArrowArrayStream stream; // no in-class initializer; NOTE(review): presumably zeroed by the out-of-line constructor, confirm
+    bool streamExhausted = false;
+    bool streamInitialized = false;
+
+    ArrowSchema streamSchema; // no in-class initializer; NOTE(review): presumably zeroed by the out-of-line constructor, confirm
+    bool schemaFetched = false;
+
+    LanceRelTableScanState(storage::MemoryManager& mm, common::ValueVector* nodeIDVector,
+        std::vector outputVectors,
+        std::shared_ptr outChunkState);
+
+    ~LanceRelTableScanState() override; // defined out of line; NOTE(review): expected to release stream and streamSchema, confirm
+
+    void setToTable(const transaction::Transaction* transaction, storage::Table* table_,
+        std::vector columnIDs_,
+        std::vector columnPredicateSets_,
+        common::RelDataDirection direction_) override;
+
+    void reset(std::unordered_map boundNodeOffsets_);
+
+    LanceRelTableScanState(const LanceRelTableScanState&) = delete; // non-copyable: holds raw Arrow C structs
+    LanceRelTableScanState& operator=(const LanceRelTableScanState&) = delete;
+};
+
+/// A relationship table backed by a Lance dataset. 
+class LanceRelTable final : public storage::ColumnarRelTableBase {
+public:
+    LanceRelTable(catalog::RelGroupCatalogEntry* relGroupEntry, common::table_id_t fromTableID,
+        common::table_id_t toTableID, const storage::StorageManager* storageManager,
+        storage::MemoryManager* memoryManager, main::ClientContext* context);
+
+    ~LanceRelTable() override = default;
+
+    void initScanState(transaction::Transaction* transaction, storage::TableScanState& scanState,
+        bool resetCachedBoundNodeSelVec = true) const override;
+
+    bool scanInternal(transaction::Transaction* transaction,
+        storage::TableScanState& scanState) override;
+
+    const std::string& getLanceDatasetPath() const { return datasetPath; }
+
+protected:
+    std::string getColumnarFormatName() const override { return "lance"; }
+    common::row_idx_t getTotalRowCount(const transaction::Transaction* transaction) const override;
+    common::row_idx_t getActiveBoundNodeCount(const transaction::Transaction* transaction,
+        common::RelDataDirection direction) const override;
+    std::vector> getAllDegreeEntries(
+        const transaction::Transaction* transaction,
+        common::RelDataDirection direction) const override;
+    std::vector> getTopKDegreeEntries(
+        const transaction::Transaction* transaction, common::RelDataDirection direction,
+        common::idx_t k) const override;
+
+private:
+    bool scanFlat(transaction::Transaction* transaction, LanceRelTableScanState& scanState);
+
+    int32_t fromColumnIdx = -1; // -1 by default; NOTE(review): presumably the dataset column holding the source-node key, confirm
+    int32_t toColumnIdx = -1;   // -1 by default; NOTE(review): presumably the dataset column holding the destination-node key, confirm
+    std::string datasetPath;
+    mutable uint64_t cachedTotalRows = common::INVALID_ROW_IDX; // mutable, so const member functions may update it
+};
+
+} // namespace lance_extension
+} // namespace lbug
diff --git a/lance/src/lance_extension.cpp b/lance/src/lance_extension.cpp
new file mode 100644
index 0000000..2abc341
--- /dev/null
+++ b/lance/src/lance_extension.cpp
@@ -0,0 +1,67 @@
+#include "lance_extension.h"
+
+#include "common/enums/storage_format.h"
+#include "lance_functions.h"
+#include "lance_node_table.h"
+#include "lance_rel_table.h"
+#include "main/client_context.h"
+#include "main/database.h"
+#include "storage/storage_manager.h"
+
+namespace lbug {
+namespace lance_extension {
+
+using namespace extension;
+
+void LanceExtension::load(main::ClientContext* context) { // registers the Lance storage format handler and the search table functions
+    auto& db = *context->getDatabase();
+    auto* storageManager = storage::StorageManager::Get(*context);
+
+    // ── Register table factories ────────────────────────────────────────────
+    storage::NodeTableFactory nodeFactory = [](const storage::StorageManager* sm,
+                                                const catalog::NodeTableCatalogEntry* entry,
+                                                storage::MemoryManager* mm,
+                                                main::ClientContext* ctx)
+        -> std::unique_ptr {
+        return std::make_unique(sm, entry, mm, ctx); // NOTE(review): template argument missing in this listing; presumably LanceNodeTable
+    };
+
+    storage::RelTableFactory relFactory = [](catalog::RelGroupCatalogEntry* entry,
+                                              common::table_id_t fromTableID,
+                                              common::table_id_t toTableID,
+                                              const storage::StorageManager* sm,
+                                              storage::MemoryManager* mm,
+                                              main::ClientContext* ctx)
+        -> std::unique_ptr {
+        return std::make_unique(entry, fromTableID, toTableID, sm, mm, ctx); // NOTE(review): template argument missing in this listing; presumably LanceRelTable
+    };
+
+    storageManager->registerStorageFormatHandler( // both factories are registered under StorageFormat::LANCE
+        common::StorageFormat::LANCE, std::move(nodeFactory), std::move(relFactory));
+
+    // ── Register search functions ───────────────────────────────────────────
+    ExtensionUtils::addTableFunc(db); // NOTE(review): the three calls are textually identical here; the template argument naming each function struct appears to be missing
+    ExtensionUtils::addTableFunc(db);
+    ExtensionUtils::addTableFunc(db);
+}
+
+} // namespace lance_extension
+} // namespace lbug
+
+#if defined(BUILD_DYNAMIC_LOAD) // C-linkage entry points, compiled only when BUILD_DYNAMIC_LOAD is defined
+extern "C" {
+#if defined(_WIN32)
+#define INIT_EXPORT __declspec(dllexport)
+#else
+#define INIT_EXPORT __attribute__((visibility("default")))
+#endif
+
+INIT_EXPORT void init(lbug::main::ClientContext* context) { // forwards to LanceExtension::load
+    lbug::lance_extension::LanceExtension::load(context);
+}
+
+INIT_EXPORT const char* name() { // returns LanceExtension::EXTENSION_NAME, i.e. "LANCE"
+    return lbug::lance_extension::LanceExtension::EXTENSION_NAME;
+}
+} // extern "C"
+#endif
diff --git a/lance/src/lance_functions.cpp b/lance/src/lance_functions.cpp
new file mode 100644
index 0000000..85c94b0
--- /dev/null
+++ b/lance/src/lance_functions.cpp
@@ -0,0 +1,457 @@
+#include "lance_functions.h"
+
+#include
+#include
+#include
+
+#include "binder/binder.h"
+#include "common/arrow/arrow.h"
+#include "common/arrow/arrow_converter.h"
+#include "common/arrow/arrow_nullmask_tree.h"
+#include "common/exception/runtime.h"
+#include "common/file_system/virtual_file_system.h"
+#include "common/types/value/nested.h"
+#include "common/types/value/value.h"
+#include "function/table/bind_data.h"
+#include "function/table/bind_input.h"
+#include "function/table/table_function.h"
+#include "main/client_context.h"
+#include "processor/execution_context.h"
+
+#include "lance/lance.hpp"
+
+namespace lbug {
+namespace lance_extension {
+
+using namespace function;
+using namespace common;
+
+namespace {
+
+/// Owns an ArrowSchema and releases it when the scope exits, including via an exception.
+struct ArrowSchemaGuard {
+    ArrowSchema schema;
+
+    ArrowSchemaGuard() { std::memset(&schema, 0, sizeof(schema)); }
+    ~ArrowSchemaGuard() {
+        if (schema.release) schema.release(&schema);
+    }
+    ArrowSchemaGuard(const ArrowSchemaGuard&) = delete;
+    ArrowSchemaGuard& operator=(const ArrowSchemaGuard&) = delete;
+};
+
+/// Owns an ArrowArray and releases it when the scope exits, including via an exception.
+struct ArrowArrayGuard {
+    ArrowArray array;
+
+    ArrowArrayGuard() { std::memset(&array, 0, sizeof(array)); }
+    ~ArrowArrayGuard() {
+        if (array.release) array.release(&array);
+    }
+    ArrowArrayGuard(const ArrowArrayGuard&) = delete;
+    ArrowArrayGuard& operator=(const ArrowArrayGuard&) = delete;
+};
+
+/// Returns the stream's last error message, or a placeholder when it provides none.
+std::string describeStreamError(ArrowArrayStream& stream) {
+    const char* msg = stream.get_last_error ? stream.get_last_error(&stream) : nullptr;
+    return msg ? std::string(msg) : std::string("unknown error");
+}
+
+} // namespace
+
+// ─── Vector Search ───────────────────────────────────────────────────────────
+
+/// Bind-time arguments of LANCE_VECTOR_SEARCH; the base class holds the output columns.
+struct LanceVectorSearchBindData : TableFuncBindData {
+    std::string datasetPath;
+    std::string columnName;
+    std::vector queryVector;
+    uint32_t k;
+    std::string metric; // "cosine", "l2", "dot"
+    uint32_t nprobes;
+
+    LanceVectorSearchBindData(std::string path, std::string col, std::vector query,
+        uint32_t k, std::string metric, uint32_t nprobes,
+        binder::expression_vector columns)
+        : TableFuncBindData{std::move(columns)}, datasetPath{std::move(path)},
+          columnName{std::move(col)}, queryVector{std::move(query)}, k{k},
+          metric{std::move(metric)}, nprobes{nprobes} {}
+
+    std::unique_ptr copy() const override {
+        return std::make_unique(datasetPath, columnName, queryVector, k,
+            metric, nprobes, columns);
+    }
+};
+
+/// Execution state shared by the three search functions: the Arrow result stream, an
+/// end-of-stream flag, and the mutex that serializes get_next() on the stream.
+struct LanceSearchSharedState : TableFuncSharedState {
+    ArrowArrayStream stream;
+    bool exhausted = false;
+    std::mutex streamMtx;
+
+    LanceSearchSharedState() { std::memset(&stream, 0, sizeof(stream)); }
+
+    ~LanceSearchSharedState() override {
+        if (stream.release) stream.release(&stream);
+    }
+};
+
+/// Binds LANCE_VECTOR_SEARCH: reads the literal arguments, opens the dataset to discover its
+/// columns, and exposes them followed by a FLOAT `_distance` column.
+/// @throws RuntimeException on fewer than 4 arguments or when Lance reports an error.
+static std::unique_ptr bindVectorSearch(main::ClientContext* context,
+    const TableFuncBindInput* input) {
+    if (input->params.size() < 4) {
+        throw RuntimeException("LANCE_VECTOR_SEARCH requires at least 4 arguments: "
+                               "dataset_path, column, query_vector, k");
+    }
+    auto datasetPath = input->getLiteralVal(0);
+    auto columnName = input->getLiteralVal(1);
+    // query_vector is a LIST of floats (passed as a literal Value with children)
+    auto vecValue = input->getValue(2);
+    std::vector queryVec;
+    for (uint32_t i = 0; i < vecValue.getChildrenSize(); ++i) {
+        queryVec.push_back(
+            static_cast(NestedVal::getChildVal(&vecValue, i)->getValue()));
+    }
+    // NOTE(review): k and nprobes end up in uint32_t bind-data fields without a range check;
+    // confirm that negative literals are rejected before they can wrap.
+    auto k = static_cast(input->getLiteralVal(3));
+    std::string metric = (input->params.size() > 4) ? input->getLiteralVal(4) : "l2";
+    uint32_t nprobes = (input->params.size() > 5)
+                           ? static_cast(input->getLiteralVal(5))
+                           : 1;
+
+    // Open dataset to discover schema
+    auto resolvedPath = VirtualFileSystem::resolvePath(context, datasetPath);
+    try {
+        auto dataset = lance::Dataset::open(resolvedPath);
+        // Released by the guard on every exit path, including a throwing type conversion.
+        ArrowSchemaGuard schemaGuard;
+        ArrowSchema& schema = schemaGuard.schema;
+        dataset.schema(&schema);
+
+        std::vector returnTypes;
+        std::vector returnNames;
+        for (int32_t i = 0; i < schema.n_children; ++i) {
+            if (!schema.children[i] || !schema.children[i]->name) continue;
+            returnNames.push_back(schema.children[i]->name);
+            returnTypes.push_back(ArrowConverter::fromArrowSchema(schema.children[i]));
+        }
+        // Add the _distance column
+        returnNames.push_back("_distance");
+        returnTypes.push_back(LogicalType{LogicalTypeID::FLOAT});
+
+        auto columns = input->binder->createVariables(returnNames, returnTypes);
+        return std::make_unique(std::move(resolvedPath),
+            std::move(columnName), std::move(queryVec), k, std::move(metric), nprobes,
+            std::move(columns));
+    } catch (const lance::Error& e) {
+        throw RuntimeException(std::string("LANCE_VECTOR_SEARCH bind failed: ") + e.what());
+    }
+}
+
+/// Opens the dataset and starts the nearest-neighbour scan, exporting it as the shared stream.
+/// Any metric string other than "cosine" or "dot" selects L2.
+static std::shared_ptr initVectorSearchSharedState(
+    const TableFuncInitSharedStateInput& input) {
+    auto* bindData = input.bindData->constPtrCast();
+    auto state = std::make_shared();
+
+    try {
+        auto dataset = lance::Dataset::open(bindData->datasetPath);
+        auto scanner = dataset.scan();
+
+        LanceMetricType metric = LANCE_METRIC_L2;
+        if (bindData->metric == "cosine") metric = LANCE_METRIC_COSINE;
+        else if (bindData->metric == "dot") metric = LANCE_METRIC_DOT;
+
+        scanner.nearest(bindData->columnName, bindData->queryVector.data(),
+                   bindData->queryVector.size(), bindData->k)
+            .nprobes(bindData->nprobes)
+            .metric(metric);
+
+        scanner.to_arrow_stream(&state->stream);
+    } catch (const lance::Error& e) {
+        throw RuntimeException(std::string("LANCE_VECTOR_SEARCH init failed: ") + e.what());
+    }
+
+    return state;
+}
+
+/// Table function shared by LANCE_VECTOR_SEARCH, LANCE_FTS and LANCE_HYBRID_SEARCH: pulls the
+/// next Arrow batch from the shared stream and converts its columns into the output chunk.
+///
+/// @param input   Carries the LanceSearchSharedState that owns the result stream.
+/// @param output  Data chunk whose value vectors receive the batch columns by position.
+/// @return Number of rows produced, or 0 once the stream is exhausted.
+/// @throws RuntimeException if the stream reports an error instead of a batch or schema.
+static offset_t vectorSearchTableFunc(const TableFuncInput& input, TableFuncOutput& output) {
+    auto* sharedState = input.sharedState->ptrCast();
+    if (sharedState->exhausted) return 0;
+
+    // The guards release on every exit path, so a throwing conversion below cannot leak.
+    ArrowArrayGuard batchGuard;
+    ArrowSchemaGuard schemaGuard;
+    ArrowArray& batch = batchGuard.array;
+    ArrowSchema& schema = schemaGuard.schema;
+    {
+        std::lock_guard lock{sharedState->streamMtx};
+        auto& stream = sharedState->stream;
+        const int rc = stream.get_next(&stream, &batch);
+        if (rc != 0) {
+            // A non-zero code is a producer error, not end-of-stream; surfacing it avoids
+            // silently returning a truncated result set.
+            sharedState->exhausted = true;
+            throw RuntimeException("Lance search stream failed: " + describeStreamError(stream));
+        }
+        if (batch.release == nullptr) {
+            // rc == 0 with a released array is the end-of-stream marker.
+            sharedState->exhausted = true;
+            return 0;
+        }
+        if (stream.get_schema(&stream, &schema) != 0) {
+            sharedState->exhausted = true;
+            throw RuntimeException(
+                "Lance search stream schema unavailable: " + describeStreamError(stream));
+        }
+    }
+
+    // NOTE(review): batchLen is not checked against the capacity of the output vectors, and
+    // the search scanners set no batch_size; confirm batches cannot exceed that capacity.
+    auto batchLen = static_cast(batch.length);
+
+    // Copy each column from the batch to the output data chunk
+    for (uint64_t col = 0; col < static_cast(output.dataChunk.getNumValueVectors()) &&
+                           col < static_cast(batch.n_children) &&
+                           col < static_cast(schema.n_children);
+        ++col) {
+        auto* childArr = batch.children[col];
+        auto* childSchema = schema.children[col];
+        if (!childArr || !childSchema) continue;
+        ArrowNullMaskTree nullMask(childSchema, childArr, childArr->offset, childArr->length);
+        ArrowConverter::fromArrowArray(childSchema, childArr, output.dataChunk.getValueVectorMutable(col),
+            &nullMask, static_cast(childArr->offset), 0, batchLen);
+    }
+
+    output.dataChunk.state->getSelVectorUnsafe().setSelSize(batchLen);
+    return batchLen;
+}
+
+/// Builds the function set: one (STRING, STRING, LIST, INT64) signature, parallelism disabled.
+function_set LanceVectorSearchFunction::getFunctionSet() {
+    function_set functionSet;
+    auto func = std::make_unique(name,
+        std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING,
+            LogicalTypeID::LIST, LogicalTypeID::INT64});
+    func->bindFunc = bindVectorSearch;
+    func->initSharedStateFunc = initVectorSearchSharedState;
+    func->initLocalStateFunc = TableFunction::initEmptyLocalState;
+    func->tableFunc = vectorSearchTableFunc;
+    func->canParallelFunc = [] { return false; }; // stream is sequential
+    functionSet.push_back(std::move(func));
+    return functionSet;
+}
+
+// ─── FTS ─────────────────────────────────────────────────────────────────────
+
+/// Bind-time arguments of LANCE_FTS; the base class holds the output columns.
+struct LanceFTSBindData : TableFuncBindData {
+    std::string datasetPath;
+    std::string query;
+    std::vector searchColumns; // renamed from 'columns' to avoid base class clash
+    uint32_t maxFuzzyDistance;
+
+    LanceFTSBindData(std::string path, std::string query, std::vector cols,
+        uint32_t maxFuzzy, binder::expression_vector outputColumns)
+        : TableFuncBindData{std::move(outputColumns)}, datasetPath{std::move(path)},
+          query{std::move(query)}, searchColumns{std::move(cols)}, maxFuzzyDistance{maxFuzzy} {}
+
+    std::unique_ptr copy() const override {
+        return std::make_unique(datasetPath, query, searchColumns,
+            maxFuzzyDistance, columns);
+    }
+};
+
+/// Binds LANCE_FTS: reads the literal arguments, opens the dataset to discover its columns,
+/// and exposes them followed by a FLOAT `_score` column.
+/// @throws RuntimeException on fewer than 2 arguments or when Lance reports an error.
+static std::unique_ptr bindFTS(main::ClientContext* context,
+    const TableFuncBindInput* input) {
+    if (input->params.size() < 2) {
+        throw RuntimeException("LANCE_FTS requires at least 2 arguments: dataset_path, query");
+    }
+    auto datasetPath = input->getLiteralVal(0);
+    auto query = input->getLiteralVal(1);
+    std::vector searchCols;
+    if (input->params.size() > 2) {
+        searchCols.push_back(input->getLiteralVal(2));
+    }
+    uint32_t maxFuzzy = (input->params.size() > 3)
+                            ? static_cast(input->getLiteralVal(3))
+                            : 0;
+
+    auto resolvedPath = VirtualFileSystem::resolvePath(context, datasetPath);
+    try {
+        auto dataset = lance::Dataset::open(resolvedPath);
+        // Released by the guard on every exit path, including a throwing type conversion.
+        ArrowSchemaGuard schemaGuard;
+        ArrowSchema& schema = schemaGuard.schema;
+        dataset.schema(&schema);
+
+        std::vector returnTypes;
+        std::vector returnNames;
+        for (int32_t i = 0; i < schema.n_children; ++i) {
+            if (!schema.children[i] || !schema.children[i]->name) continue;
+            returnNames.push_back(schema.children[i]->name);
+            returnTypes.push_back(ArrowConverter::fromArrowSchema(schema.children[i]));
+        }
+        returnNames.push_back("_score");
+        returnTypes.push_back(LogicalType{LogicalTypeID::FLOAT});
+
+        auto columns = input->binder->createVariables(returnNames, returnTypes);
+        return std::make_unique(std::move(resolvedPath), std::move(query),
+            std::move(searchCols), maxFuzzy, std::move(columns));
+    } catch (const lance::Error& e) {
+        throw RuntimeException(std::string("LANCE_FTS bind failed: ") + e.what());
+    }
+}
+
+/// Opens the dataset and starts the full-text scan, exporting it as the shared stream.
+static std::shared_ptr initFTSSharedState(
+    const TableFuncInitSharedStateInput& input) {
+    auto* bindData = input.bindData->constPtrCast();
+    auto state = std::make_shared();
+    try {
+        auto dataset = lance::Dataset::open(bindData->datasetPath);
+        dataset.scan()
+            .full_text_search(bindData->query, bindData->searchColumns, bindData->maxFuzzyDistance)
+            .to_arrow_stream(&state->stream);
+    } catch (const lance::Error& e) {
+        throw RuntimeException(std::string("LANCE_FTS init failed: ") + e.what());
+    }
+    return state;
+}
+
+/// Builds the function set: one (STRING, STRING) signature, parallelism disabled.
+function_set LanceFTSFunction::getFunctionSet() {
+    function_set functionSet;
+    auto func = std::make_unique(
+        name, std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING});
+    func->bindFunc = bindFTS;
+    func->initSharedStateFunc = initFTSSharedState;
+    func->initLocalStateFunc = TableFunction::initEmptyLocalState;
+    func->tableFunc = vectorSearchTableFunc; // reuse same Arrow stream output logic
+    func->canParallelFunc = [] { return false; };
+    functionSet.push_back(std::move(func));
+    return functionSet;
+}
+
+// ─── Hybrid Search ───────────────────────────────────────────────────────────
+
+/// Bind-time arguments of LANCE_HYBRID_SEARCH; the base class holds the output columns.
+struct LanceHybridSearchBindData : TableFuncBindData {
+    std::string datasetPath;
+    std::string vectorColumn;
+    std::vector queryVector;
+    uint32_t k;
+    std::string ftsQuery;
+
+    LanceHybridSearchBindData(std::string path, std::string col, std::vector query,
+        uint32_t k, std::string ftsQuery, binder::expression_vector columns)
+        : TableFuncBindData{std::move(columns)}, datasetPath{std::move(path)},
+          vectorColumn{std::move(col)}, queryVector{std::move(query)}, k{k},
+          ftsQuery{std::move(ftsQuery)} {}
+
+    std::unique_ptr copy() const override {
+        return std::make_unique(datasetPath, vectorColumn, queryVector,
+            k, ftsQuery, columns);
+    }
+};
+
+/// Binds LANCE_HYBRID_SEARCH: reads the literal arguments, opens the dataset to discover its
+/// columns, and exposes them followed by a FLOAT `_score` column.
+/// @throws RuntimeException on fewer than 5 arguments or when Lance reports an error.
+static std::unique_ptr bindHybridSearch(main::ClientContext* context,
+    const TableFuncBindInput* input) {
+    if (input->params.size() < 5) {
+        throw RuntimeException(
+            "LANCE_HYBRID_SEARCH requires 5 arguments: dataset_path, column, query_vector, k, fts_query");
+    }
+    auto datasetPath = input->getLiteralVal(0);
+    auto columnName = input->getLiteralVal(1);
+    auto vecValue = input->getValue(2);
+    std::vector queryVec;
+    for (uint32_t i = 0; i < vecValue.getChildrenSize(); ++i) {
+        queryVec.push_back(
+            static_cast(NestedVal::getChildVal(&vecValue, i)->getValue()));
+    }
+    auto k = static_cast(input->getLiteralVal(3));
+    auto ftsQuery = input->getLiteralVal(4);
+
+    auto resolvedPath = VirtualFileSystem::resolvePath(context, datasetPath);
+    try {
+        auto dataset = lance::Dataset::open(resolvedPath);
+        // Released by the guard on every exit path, including a throwing type conversion.
+        ArrowSchemaGuard schemaGuard;
+        ArrowSchema& schema = schemaGuard.schema;
+        dataset.schema(&schema);
+
+        std::vector returnTypes;
+        std::vector returnNames;
+        for (int32_t i = 0; i < schema.n_children; ++i) {
+            if (!schema.children[i] || !schema.children[i]->name) continue;
+            returnNames.push_back(schema.children[i]->name);
+            returnTypes.push_back(ArrowConverter::fromArrowSchema(schema.children[i]));
+        }
+        returnNames.push_back("_score");
+        returnTypes.push_back(LogicalType{LogicalTypeID::FLOAT});
+
+        auto columns = input->binder->createVariables(returnNames, returnTypes);
+        return std::make_unique(std::move(resolvedPath),
+            std::move(columnName), std::move(queryVec), k, std::move(ftsQuery),
+            std::move(columns));
+    } catch (const lance::Error& e) {
+        throw RuntimeException(std::string("LANCE_HYBRID_SEARCH bind failed: ") + e.what());
+    }
+}
+
+/// Opens the dataset and chains nearest() with full_text_search() on one scanner, exporting
+/// the result as the shared stream.
+static std::shared_ptr initHybridSearchSharedState(
+    const TableFuncInitSharedStateInput& input) {
+    auto* bindData = input.bindData->constPtrCast();
+    auto state = std::make_shared();
+    try {
+        auto dataset = lance::Dataset::open(bindData->datasetPath);
+        // Run vector search + FTS; the scanner merges results internally
+        dataset.scan()
+            .nearest(bindData->vectorColumn, bindData->queryVector.data(),
+                bindData->queryVector.size(), bindData->k)
+            .full_text_search(bindData->ftsQuery)
+            .to_arrow_stream(&state->stream);
+    } catch (const lance::Error& e) {
+        throw RuntimeException(std::string("LANCE_HYBRID_SEARCH init failed: ") + e.what());
+    }
+    return state;
+}
+
+/// Builds the function set: one five-parameter signature with parallelism disabled.
+function_set LanceHybridSearchFunction::getFunctionSet() {
+    function_set functionSet;
+    auto func = std::make_unique(name,
+        std::vector{LogicalTypeID::STRING, LogicalTypeID::STRING,
+            LogicalTypeID::LIST, LogicalTypeID::INT64, LogicalTypeID::STRING});
+    func->bindFunc = bindHybridSearch;
+    func->initSharedStateFunc = initHybridSearchSharedState;
+    func->initLocalStateFunc = TableFunction::initEmptyLocalState;
+    func->tableFunc = vectorSearchTableFunc; // reuse Arrow stream output logic
+    func->canParallelFunc = [] { return false; };
+    functionSet.push_back(std::move(func));
+    return functionSet;
+}
+
+} // namespace lance_extension
+} // namespace lbug
diff --git a/lance/src/lance_node_table.cpp b/lance/src/lance_node_table.cpp
new 
file mode 100644 index 0000000..a23faec --- /dev/null +++ b/lance/src/lance_node_table.cpp @@ -0,0 +1,375 @@ +#include "lance_node_table.h" + +#include + +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "common/arrow/arrow_converter.h" +#include "common/arrow/arrow_nullmask_tree.h" +#include "common/data_chunk/sel_vector.h" +#include "common/exception/runtime.h" +#include "common/file_system/virtual_file_system.h" +#include "common/system_config.h" +#include "common/types/types.h" +#include "main/client_context.h" +#include "storage/storage_manager.h" +#include "storage/table/node_table.h" +#include "transaction/transaction.h" + +#include "lance/lance.hpp" + +namespace lbug { +namespace lance_extension { + +// ─── LanceNodeTableScanSharedState ─────────────────────────────────────────── + +void LanceNodeTableScanSharedState::reset(ArrowArrayStream newStream) { + std::lock_guard lock(mtx); + if (stream_.release) stream_.release(&stream_); + stream_ = newStream; + std::memset(&newStream, 0, sizeof(newStream)); // disarm the original + + streamExhausted = false; + currentBatch = nullptr; + currentBatchGlobalOffset = 0; + currentMorselStart = 0; + + if (!streamSchemaFetched) { + // Fetch schema once so each batch's children can be addressed later. + // The stream schema is the same for all batches. + if (stream_.get_schema && stream_.get_schema(&stream_, &streamSchema_) == 0) { + streamSchemaFetched = true; + } + } +} + +bool LanceNodeTableScanSharedState::readNextBatch() { + // Caller must hold mtx + auto newBatch = std::make_shared(); + int rc = stream_.get_next(&stream_, &newBatch->array); + if (rc != 0 || newBatch->array.release == nullptr) { + streamExhausted = true; + return false; + } + newBatch->length = static_cast(newBatch->array.length); + + // Copy the cached stream schema into the batch schema so each batch carries + // all the structural information ArrowConverter needs. 
+ if (streamSchemaFetched && streamSchema_.format != nullptr) { + newBatch->schema = streamSchema_; + // We do NOT want ~LanceBatchData to release the schema (it's owned by + // streamSchema_), so clear the release pointer. + newBatch->schema.release = nullptr; + } + + size_t prevBatchLength = currentBatch ? currentBatch->length : 0; + currentBatchGlobalOffset += prevBatchLength; + currentBatch = std::move(newBatch); + currentMorselStart = 0; + return true; +} + +bool LanceNodeTableScanSharedState::getNextMorsel( + storage::ColumnarNodeTableScanState* scanState) { + auto* lanceScanState = dynamic_cast(scanState); + if (!lanceScanState) return false; + + std::lock_guard lock(mtx); + + while (true) { + // If there's data remaining in the current batch, carve off a morsel. + if (currentBatch && currentMorselStart < currentBatch->length) { + lanceScanState->currentBatch = currentBatch; + lanceScanState->batchStartGlobalOffset = currentBatchGlobalOffset; + lanceScanState->morselStart = currentMorselStart; + lanceScanState->morselEnd = + std::min(currentMorselStart + morselSize, currentBatch->length); + currentMorselStart = lanceScanState->morselEnd; + return true; + } + // Need the next batch from the stream. + if (streamExhausted) return false; + if (!readNextBatch()) return false; + // Loop to assign a morsel from the freshly-read batch. + } +} + +// ─── LanceNodeTable ────────────────────────────────────────────────────────── + +LanceNodeTable::LanceNodeTable(const storage::StorageManager* storageManager, + const catalog::NodeTableCatalogEntry* nodeTableEntry, storage::MemoryManager* memoryManager, + main::ClientContext* context) + : storage::ColumnarNodeTableBase{storageManager, nodeTableEntry, memoryManager, + std::make_unique(kDefaultMorselSize)} { + std::memset(&cachedSchema_, 0, sizeof(cachedSchema_)); + + // The catalog stores the lance dataset path in the 'storage' field. 
+    const std::string& storagePath = nodeTableEntry->getStorage();
+    if (storagePath.empty()) {
+        throw common::RuntimeException(
+            "Lance node table has empty storage path. "
+            "Specify the dataset path via storage='path/to/dataset.lance'.");
+    }
+    datasetPath =
+        common::VirtualFileSystem::resolvePath(context, storagePath);
+
+    try {
+        auto dataset = lance::Dataset::open(datasetPath);
+        cachedTotalRows = dataset.count_rows();
+
+        // Cache the schema for later column mapping
+        // NOTE(review): no release of cachedSchema_ is visible in this chunk;
+        // confirm the destructor (declared elsewhere) releases it.
+        dataset.schema(&cachedSchema_);
+        schemaCached = true;
+        numLanceColumns = static_cast(cachedSchema_.n_children);
+    } catch (const lance::Error& e) {
+        throw common::RuntimeException(
+            std::string("Failed to open lance dataset '") + datasetPath + "': " + e.what());
+    }
+}
+
+// Starts a fresh scan: re-opens the dataset, builds a scanner with a batch size
+// of kDefaultMorselSize, exports it as an Arrow stream and hands that stream to
+// the shared scan state. lance errors are rethrown as common::RuntimeException.
+void LanceNodeTable::initializeScanCoordination(
+    const transaction::Transaction* transaction) {
+    auto* lanceSharedState =
+        static_cast(tableScanSharedState.get());
+
+    try {
+        auto dataset = lance::Dataset::open(datasetPath);
+
+        auto scanner = dataset.scan();
+        scanner.batch_size(static_cast(kDefaultMorselSize));
+
+        ArrowArrayStream stream;
+        std::memset(&stream, 0, sizeof(stream));
+        scanner.to_arrow_stream(&stream);
+
+        // reset() copies the struct; the local `stream` is never released here,
+        // so the stream is released only through the shared state.
+        lanceSharedState->reset(stream);
+    } catch (const lance::Error& e) {
+        throw common::RuntimeException(
+            std::string("Failed to initialize lance scan on '") + datasetPath + "': " + e.what());
+    }
+}
+
+// Prepares a per-thread scan state. The scan is marked as having work only when
+// the source is COMMITTED and a batch has already been assigned to the state;
+// in every other case it is marked completed.
+void LanceNodeTable::initScanState(transaction::Transaction* /*transaction*/,
+    storage::TableScanState& scanState, bool /*resetCachedBoundNodeSelVec*/) const {
+    auto& lanceScanState = scanState.cast();
+
+    lanceScanState.initialized = false;
+    lanceScanState.scanCompleted = true;
+
+    if (lanceScanState.source == storage::TableScanSource::COMMITTED &&
+        lanceScanState.currentBatch != nullptr) {
+        lanceScanState.scanCompleted = false;
+    }
+
+    lanceScanState.initialized = true;
+}
+
+// Emits the rows of the morsel currently assigned to `scanState` into its
+// output vectors. Returns false when there is nothing (left) to emit.
+bool LanceNodeTable::scanInternal(transaction::Transaction* /*transaction*/,
+    storage::TableScanState& scanState) {
+    auto& lanceScanState = scanState.cast();
+
+    if (lanceScanState.scanCompleted) return false;
+    // No batch, or the assigned morsel range is already consumed.
+    if (!lanceScanState.currentBatch ||
+        lanceScanState.morselStart >= lanceScanState.morselEnd) {
+        lanceScanState.scanCompleted = true;
+        return false;
+    }
+
+    const auto& batch = *lanceScanState.currentBatch;
+    const auto morselStart = lanceScanState.morselStart;
+    const auto morselEnd = lanceScanState.morselEnd;
+    const auto outputSize = static_cast(morselEnd - morselStart);
+    // Offset of the first morsel row, counted from the start of the stream.
+    const auto globalStartOffset =
+        lanceScanState.batchStartGlobalOffset + morselStart;
+
+    scanState.resetOutVectors();
+    scanState.outState->getSelVectorUnsafe().setSelSize(outputSize);
+
+    NodeTable::applySemiMaskFilter(scanState, globalStartOffset, outputSize,
+        scanState.outState->getSelVectorUnsafe());
+
+    if (scanState.outState->getSelVector().getSelSize() == 0) {
+        // Advance offset even if all rows are masked out so we don't loop forever.
+        lanceScanState.morselStart += outputSize;
+        return false;
+    }
+
+    // All rows of the morsel are converted; the selection vector set above
+    // decides which of them are visible downstream.
+    const auto outputToLanceColIdx = getOutputToLanceColumnIdx(scanState.columnIDs);
+    copyLanceMorselToOutputVectors(batch, morselStart, outputSize,
+        scanState.outputVectors, outputToLanceColIdx);
+
+    // Node IDs are synthesized from the row position in the stream.
+    const auto tableID = this->getTableID();
+    for (uint64_t i = 0; i < outputSize; ++i) {
+        auto& nodeID = scanState.nodeIDVector->getValue(i);
+        nodeID.tableID = tableID;
+        nodeID.offset = globalStartOffset + i;
+    }
+
+    lanceScanState.morselStart += outputSize;
+    return true;
+}
+
+// Number of morsels needed to cover the table: ceil(totalRows / kDefaultMorselSize).
+size_t LanceNodeTable::getNumScanMorsels(
+    const transaction::Transaction* transaction) const {
+    auto totalRows = getTotalRowCount(transaction);
+    return (totalRows + kDefaultMorselSize - 1) / kDefaultMorselSize;
+}
+
+// Creates a scan state bound to `nodeIDVector` and `outVectors`; the output
+// chunk state is taken from `nodeIDVector->state`.
+std::unique_ptr LanceNodeTable::createScanState(
+    common::ValueVector* nodeIDVector,
+    const std::vector& outVectors,
+    storage::MemoryManager* memoryManager) const {
+    return std::make_unique(*memoryManager, nodeIDVector, outVectors,
+        nodeIDVector->state);
+}
+
+// A row is visible when its offset lies below the cached row count.
+bool LanceNodeTable::isVisible(const
transaction::Transaction* /*transaction*/,
+    common::offset_t offset) const {
+    return offset < cachedTotalRows;
+}
+
+// Same check as isVisible(): the offset must lie below the cached row count.
+bool LanceNodeTable::isVisibleNoLock(const transaction::Transaction* /*transaction*/,
+    common::offset_t offset) const {
+    return offset < cachedTotalRows;
+}
+
+// Resolves a primary-key value to a row offset by scanning the whole dataset.
+//
+// @param keyVector  vector holding the key to look up
+// @param vectorPos  position of the key inside `keyVector`
+// @param result     receives the stream-global row offset of the first match
+// @return true when a matching row was found; false for a null key, a missing
+//         cached schema, an unmappable PK column, a failed schema fetch, or no match
+// @throws common::RuntimeException when lance reports an error
+//
+// The Arrow stream, schema and each batch are owned by scope guards, so they are
+// released on every exit path, including exceptions thrown by the converter.
+bool LanceNodeTable::lookupPK(const transaction::Transaction* /*transaction*/,
+    common::ValueVector* keyVector, uint64_t vectorPos, common::offset_t& result) const {
+    if (keyVector->isNull(vectorPos)) return false;
+    if (!schemaCached) return false;
+
+    // The catalog property index of the PK is used as the lance column index.
+    // NOTE(review): this assumes catalog property order equals lance column order.
+    auto pkColumnID = getPKColumnID();
+    int64_t pkLanceIdx = -1;
+    for (common::idx_t propIdx = 0; propIdx < nodeTableCatalogEntry->getNumProperties();
+         ++propIdx) {
+        if (nodeTableCatalogEntry->getColumnID(propIdx) == pkColumnID) {
+            pkLanceIdx = static_cast<int64_t>(propIdx);
+            break;
+        }
+    }
+    if (pkLanceIdx < 0 || pkLanceIdx >= cachedSchema_.n_children) return false;
+
+    // Single-slot vector that each candidate PK value is converted into.
+    auto keyToLookup = keyVector->getAsValue(vectorPos);
+    auto pkType = getColumn(pkColumnID).getDataType().copy();
+    auto singleState = common::DataChunkState::getSingleValueDataChunkState();
+    auto pkVector =
+        std::make_unique<common::ValueVector>(std::move(pkType), memoryManager, singleState);
+    pkVector->state->setToFlat();
+
+    // Deleters that invoke the Arrow release callback when one is still set.
+    struct StreamReleaser {
+        void operator()(ArrowArrayStream* s) const {
+            if (s->release) s->release(s);
+        }
+    };
+    struct SchemaReleaser {
+        void operator()(ArrowSchema* s) const {
+            if (s->release) s->release(s);
+        }
+    };
+    struct ArrayReleaser {
+        void operator()(ArrowArray* a) const {
+            if (a->release) a->release(a);
+        }
+    };
+
+    // Full table scan to find the PK — acceptable for a lookup on an immutable table.
+    // For better performance the user should build a lance scalar index on the PK column.
+    try {
+        auto dataset = lance::Dataset::open(datasetPath);
+        auto scanner = dataset.scan();
+        scanner.batch_size(4096);
+        ArrowArrayStream stream;
+        std::memset(&stream, 0, sizeof(stream));
+        scanner.to_arrow_stream(&stream);
+        const std::unique_ptr<ArrowArrayStream, StreamReleaser> streamGuard{&stream};
+
+        ArrowSchema schema;
+        std::memset(&schema, 0, sizeof(schema));
+        if (stream.get_schema && stream.get_schema(&stream, &schema) != 0) {
+            return false;
+        }
+        // Declared after streamGuard so the schema is released before the stream.
+        const std::unique_ptr<ArrowSchema, SchemaReleaser> schemaGuard{&schema};
+
+        // A stream without a get_next callback cannot produce any batch.
+        if (!stream.get_next) return false;
+
+        uint64_t globalOffset = 0;
+        while (true) {
+            ArrowArray batch;
+            std::memset(&batch, 0, sizeof(batch));
+            if (stream.get_next(&stream, &batch) != 0 || batch.release == nullptr) break;
+            const std::unique_ptr<ArrowArray, ArrayReleaser> batchGuard{&batch};
+
+            const auto batchLen = static_cast<uint64_t>(batch.length);
+            if (pkLanceIdx < batch.n_children && batch.children[pkLanceIdx] &&
+                schema.n_children > pkLanceIdx && schema.children[pkLanceIdx]) {
+                auto* childArr = batch.children[pkLanceIdx];
+                auto* childSchema = schema.children[pkLanceIdx];
+                common::ArrowNullMaskTree nullMask(
+                    childSchema, childArr, childArr->offset, childArr->length);
+                for (uint64_t rowIdx = 0; rowIdx < batchLen; ++rowIdx) {
+                    // Convert one row at a time into slot 0 and compare.
+                    common::ArrowConverter::fromArrowArray(
+                        childSchema, childArr, *pkVector, &nullMask,
+                        static_cast<uint64_t>(childArr->offset) + rowIdx, 0, 1);
+                    if (!pkVector->isNull(0) && *pkVector->getAsValue(0) == *keyToLookup) {
+                        result = globalOffset + rowIdx;
+                        return true;
+                    }
+                }
+            }
+            globalOffset += batchLen;
+        }
+    } catch (const lance::Error& e) {
+        throw common::RuntimeException(
+            std::string("Lance PK lookup failed: ") + e.what());
+    }
+    return false;
+}
+
+// Number of batches is computed as ceil(totalRows / kDefaultMorselSize).
+common::node_group_idx_t LanceNodeTable::getNumBatches(
+    const transaction::Transaction* transaction) const {
+    auto totalRows = getTotalRowCount(transaction);
+    return
static_cast(
+        (totalRows + kDefaultMorselSize - 1) / kDefaultMorselSize);
+}
+
+// Returns the cached row count, opening the dataset to count rows only when the
+// cache still holds INVALID_ROW_IDX. lance errors become common::RuntimeException.
+common::row_idx_t LanceNodeTable::getTotalRowCount(
+    const transaction::Transaction* /*transaction*/) const {
+    if (cachedTotalRows != common::INVALID_ROW_IDX) return cachedTotalRows;
+    try {
+        auto dataset = lance::Dataset::open(datasetPath);
+        cachedTotalRows = dataset.count_rows();
+    } catch (const lance::Error& e) {
+        throw common::RuntimeException(
+            std::string("Failed to count rows in lance dataset '") + datasetPath + "': " +
+            e.what());
+    }
+    return cachedTotalRows;
+}
+
+// Maps every requested output column to a lance column index. An entry is -1
+// for INVALID_COLUMN_ID / ROW_IDX_COLUMN_ID and for column IDs not found in the
+// catalog entry. The index used is the catalog property index.
+// NOTE(review): assumes catalog property order matches lance column order.
+std::vector LanceNodeTable::getOutputToLanceColumnIdx(
+    const std::vector& columnIDs) const {
+    std::vector result(columnIDs.size(), -1);
+    for (size_t col = 0; col < columnIDs.size(); ++col) {
+        const auto colID = columnIDs[col];
+        if (colID == common::INVALID_COLUMN_ID || colID == common::ROW_IDX_COLUMN_ID) continue;
+        // Linear search over the catalog properties for the matching column ID.
+        for (common::idx_t propIdx = 0;
+             propIdx < nodeTableCatalogEntry->getNumProperties(); ++propIdx) {
+            if (nodeTableCatalogEntry->getColumnID(propIdx) == colID) {
+                result[col] = static_cast(propIdx);
+                break;
+            }
+        }
+    }
+    return result;
+}
+
+// Converts `numRows` rows starting at `morselStart` of `batch` into the output
+// vectors. Output slots that are null, unmapped (-1), out of range, or whose
+// Arrow child/schema is missing are skipped silently.
+// NOTE(review): `outputToLanceColIdx` is indexed with the output-vector index
+// without a size check; callers must pass vectors of equal length.
+void LanceNodeTable::copyLanceMorselToOutputVectors(const LanceBatchData& batch,
+    uint64_t morselStart, uint64_t numRows,
+    const std::vector& outputVectors,
+    const std::vector& outputToLanceColIdx) const {
+    if (!batch.array.children || !batch.schema.children) return;
+    const auto numChildren = static_cast(batch.array.n_children);
+
+    for (uint64_t outCol = 0; outCol < outputVectors.size(); ++outCol) {
+        if (!outputVectors[outCol]) continue;
+        const auto lanceIdx = outputToLanceColIdx[outCol];
+        if (lanceIdx < 0 || static_cast(lanceIdx) >= numChildren) continue;
+        if (!batch.array.children[lanceIdx] || !batch.schema.children[lanceIdx]) continue;
+
+        auto* childArray = batch.array.children[lanceIdx];
+        auto* childSchema = batch.schema.children[lanceIdx];
+        common::ArrowNullMaskTree nullMask(
+            childSchema, childArray, childArray->offset,
childArray->length);
+        // Source position is the child's own Arrow offset plus the morsel start;
+        // rows are written to the output vector starting at position 0.
+        common::ArrowConverter::fromArrowArray(childSchema, childArray, *outputVectors[outCol],
+            &nullMask,
+            static_cast(childArray->offset) + morselStart,
+            0, numRows);
+    }
+}
+
+} // namespace lance_extension
+} // namespace lbug
diff --git a/lance/src/lance_rel_table.cpp b/lance/src/lance_rel_table.cpp
new file mode 100644
index 0000000..d1431e1
--- /dev/null
+++ b/lance/src/lance_rel_table.cpp
@@ -0,0 +1,357 @@
+#include "lance_rel_table.h"
+
+#include
+
+#include "catalog/catalog_entry/rel_group_catalog_entry.h"
+#include "common/arrow/arrow_converter.h"
+#include "common/arrow/arrow_nullmask_tree.h"
+#include "common/data_chunk/sel_vector.h"
+#include "common/exception/runtime.h"
+#include "common/file_system/virtual_file_system.h"
+#include "common/system_config.h"
+#include "common/types/internal_id_util.h"
+#include "common/types/types.h"
+#include "main/client_context.h"
+#include "storage/storage_manager.h"
+#include "transaction/transaction.h"
+
+#include "lance/lance.hpp"
+#include "lance_node_table.h"
+
+namespace lbug {
+namespace lance_extension {
+
+using namespace common;
+using namespace storage;
+using namespace transaction;
+
+// ─── LanceRelTableScanState ──────────────────────────────────────────────────
+
+// Builds the scan state with a zeroed (i.e. not yet opened) Arrow stream and
+// stream schema.
+LanceRelTableScanState::LanceRelTableScanState(MemoryManager& mm,
+    common::ValueVector* nodeIDVector, std::vector outputVectors,
+    std::shared_ptr outChunkState)
+    : RelTableScanState{mm, nodeIDVector, std::move(outputVectors), std::move(outChunkState)} {
+    std::memset(&stream, 0, sizeof(stream));
+    std::memset(&streamSchema, 0, sizeof(streamSchema));
+}
+
+// Releases the Arrow stream and schema if they are still owned.
+LanceRelTableScanState::~LanceRelTableScanState() {
+    if (stream.release) stream.release(&stream);
+    if (streamSchema.release) streamSchema.release(&streamSchema);
+}
+
+// Binds the scan state to a table, its column IDs and a scan direction.
+void LanceRelTableScanState::setToTable(const Transaction* transaction, Table* table_,
+    std::vector columnIDs_, std::vector columnPredicateSets_,
+    RelDataDirection direction_) {
+    // Call base class
+    // (skips local table setup which lance doesn't support)
+    TableScanState::setToTable(transaction, table_, std::move(columnIDs_),
+        std::move(columnPredicateSets_));
+    columns.resize(columnIDs.size());
+    direction = direction_;
+    // Resolve a column object per requested ID; the two special IDs have none.
+    for (size_t i = 0; i < columnIDs.size(); ++i) {
+        const auto colID = columnIDs[i];
+        if (colID == INVALID_COLUMN_ID || colID == ROW_IDX_COLUMN_ID) {
+            columns[i] = nullptr;
+        } else {
+            columns[i] = table->cast().getColumn(colID, direction);
+        }
+    }
+    csrOffsetColumn = table->cast().getCSROffsetColumn(direction);
+    csrLengthColumn = table->cast().getCSRLengthColumn(direction);
+    nodeGroupIdx = INVALID_NODE_GROUP_IDX;
+}
+
+// Rewinds the scan state for a new scan: drops the cached batch, installs the
+// new bound-node map, clears all stream flags and releases the Arrow stream and
+// schema, leaving both zeroed. The caller opens the new stream afterwards.
+void LanceRelTableScanState::reset(
+    std::unordered_map boundNodeOffsets_) {
+    cachedBatchData = nullptr;
+    currentBatchStartOffset = 0;
+    currentLocalRowIdx = 0;
+    boundNodeOffsets = std::move(boundNodeOffsets_);
+    // Re-open stream from the rel table
+    streamExhausted = false;
+    streamInitialized = false;
+    schemaFetched = false;
+    if (stream.release) stream.release(&stream);
+    std::memset(&stream, 0, sizeof(stream));
+    if (streamSchema.release) streamSchema.release(&streamSchema);
+    std::memset(&streamSchema, 0, sizeof(streamSchema));
+}
+
+// ─── LanceRelTable ───────────────────────────────────────────────────────────
+
+// Opens the lance dataset named by the rel group's storage path, caches its row
+// count and locates the columns named "from" and "to". Throws
+// common::RuntimeException when the path is empty, the dataset cannot be opened,
+// or either column is missing.
+LanceRelTable::LanceRelTable(catalog::RelGroupCatalogEntry* relGroupEntry,
+    common::table_id_t fromTableID, common::table_id_t toTableID,
+    const StorageManager* storageManager, MemoryManager* memoryManager,
+    main::ClientContext* context)
+    : ColumnarRelTableBase{relGroupEntry, fromTableID, toTableID, storageManager, memoryManager} {
+    const auto& storage = relGroupEntry->getStorage();
+    if (storage.empty()) {
+        throw common::RuntimeException(
+            "Lance rel table has empty storage path. "
+            "Specify the dataset path via storage='path/to/rel.lance'.");
+    }
+
+    datasetPath = common::VirtualFileSystem::resolvePath(context, storage);
+
+    try {
+        auto dataset = lance::Dataset::open(datasetPath);
+        cachedTotalRows = dataset.count_rows();
+
+        // Discover 'from' and 'to' column indices in the lance schema
+        ArrowSchema schema;
+        std::memset(&schema, 0, sizeof(schema));
+        dataset.schema(&schema);
+
+        for (int32_t i = 0; i < schema.n_children; ++i) {
+            if (!schema.children[i] || !schema.children[i]->name) continue;
+            std::string colName = schema.children[i]->name;
+            if (colName == "from") fromColumnIdx = i;
+            else if (colName == "to") toColumnIdx = i;
+        }
+        if (schema.release) schema.release(&schema);
+
+        // NOTE(review): relies on both indices starting out negative; their
+        // initial values are set outside this chunk.
+        if (fromColumnIdx < 0 || toColumnIdx < 0) {
+            throw common::RuntimeException(
+                "Lance rel table dataset '" + datasetPath +
+                "' must contain 'from' and 'to' columns. "
+                "Found fromColumnIdx=" + std::to_string(fromColumnIdx) +
+                " toColumnIdx=" + std::to_string(toColumnIdx));
+        }
+    } catch (const lance::Error& e) {
+        throw common::RuntimeException(
+            std::string("Failed to open lance rel dataset '") + datasetPath + "': " + e.what());
+    }
+}
+
+// Prepares a rel scan: optionally snapshots the bound-node selection vector,
+// builds the bound-offset lookup map and opens a fresh Arrow stream.
+void LanceRelTable::initScanState(Transaction* transaction, TableScanState& scanState,
+    bool resetCachedBoundNodeSelVec) const {
+    auto& relScanState = scanState.cast();
+    relScanState.source = TableScanSource::COMMITTED;
+    relScanState.nodeGroup = nullptr;
+    relScanState.nodeGroupIdx = INVALID_NODE_GROUP_IDX;
+
+    if (resetCachedBoundNodeSelVec) {
+        // Copy the current selection of the bound-node vector into the cache.
+        if (relScanState.nodeIDVector->state->getSelVector().isUnfiltered()) {
+            relScanState.cachedBoundNodeSelVector.setToUnfiltered();
+        } else {
+            relScanState.cachedBoundNodeSelVector.setToFiltered();
+            std::memcpy(relScanState.cachedBoundNodeSelVector.getMutableBuffer().data(),
+                relScanState.nodeIDVector->state->getSelVector().getMutableBuffer().data(),
+                relScanState.nodeIDVector->state->getSelVector().getSelSize() * sizeof(sel_t));
+        }
+        relScanState.cachedBoundNodeSelVector.setSelSize(
+            relScanState.nodeIDVector->state->getSelVector().getSelSize());
+    }
+
+    auto& lanceScanState = static_cast(relScanState);
+
+    // Build bound node offsets map
+    // Key: node offset of a bound node; value: its position in the node-ID
+    // vector. insert() keeps the first position when an offset repeats.
+    std::unordered_map boundNodeOffsets;
+    for (size_t i = 0; i < lanceScanState.cachedBoundNodeSelVector.getSelSize(); ++i) {
+        const sel_t idx = lanceScanState.cachedBoundNodeSelVector[i];
+        const auto nodeID = lanceScanState.nodeIDVector->getValue(idx);
+        boundNodeOffsets.insert({nodeID.offset, idx});
+    }
+    // reset() also releases any stream left over from a previous scan.
+    lanceScanState.reset(std::move(boundNodeOffsets));
+
+    // Open a new lance stream for this scan
+    try {
+        auto dataset = lance::Dataset::open(datasetPath);
+        auto scanner = dataset.scan();
+        scanner.batch_size(4096);
+        scanner.to_arrow_stream(&lanceScanState.stream);
+        lanceScanState.streamInitialized = true;
+        lanceScanState.streamExhausted = false;
+
+        // reset() cleared schemaFetched, so the schema is fetched on every call
+        // as long as the stream provides get_schema.
+        if (!lanceScanState.schemaFetched && lanceScanState.stream.get_schema) {
+            if (lanceScanState.stream.get_schema(&lanceScanState.stream,
+                    &lanceScanState.streamSchema) == 0) {
+                lanceScanState.schemaFetched = true;
+            }
+        }
+    } catch (const lance::Error& e) {
+        throw common::RuntimeException(
+            std::string("Failed to open lance rel scan on '") + datasetPath + "': " + e.what());
+    }
+}
+
+// Forwards to scanFlat() after casting the generic scan state to the lance one.
+bool LanceRelTable::scanInternal(Transaction* transaction, TableScanState& scanState) {
+    auto& lanceScanState = static_cast(scanState);
+    return scanFlat(transaction, lanceScanState);
+}
+
+// Helper: read a uint64 node offset from an ArrowArray column at a given local row index.
+// Lance stores node offsets as int64 or uint64 — both are safe to read as int64 and cast.
+// Returns INVALID_OFFSET when the array or its data buffer is missing, when
+// `localIdx` lies outside the array, or when the validity bitmap marks the slot
+// as null. The data buffer is interpreted as 64-bit integers.
+// NOTE(review): the column's Arrow format is not checked here; a 32-bit column
+// would be misread — confirm the dataset contract.
+static uint64_t readOffset(const ArrowArray* arr, uint64_t localIdx) {
+    if (!arr || !arr->buffers || arr->n_buffers < 2 || !arr->buffers[1]) return INVALID_OFFSET;
+    if (arr->offset < 0 || arr->length < 0 ||
+        localIdx >= static_cast<uint64_t>(arr->length)) {
+        return INVALID_OFFSET;
+    }
+    const auto physicalIdx = static_cast<uint64_t>(arr->offset) + localIdx;
+    // Arrow validity bitmap: least-significant-bit order, a cleared bit means null.
+    if (arr->null_count != 0 && arr->buffers[0]) {
+        const auto* validity = static_cast<const uint8_t*>(arr->buffers[0]);
+        if (((validity[physicalIdx >> 3] >> (physicalIdx & 7)) & 1u) == 0) {
+            return INVALID_OFFSET;
+        }
+    }
+    const auto* data = static_cast<const int64_t*>(arr->buffers[1]);
+    return static_cast<uint64_t>(data[physicalIdx]);
+}
+
+// Scans the relationship stream and emits up to DEFAULT_VECTOR_CAPACITY rows
+// whose bound endpoint ('from' when scanning forward, 'to' when backward) is
+// one of the offsets in `scanState.boundNodeOffsets`.
+//
+// Output vector 0 receives the neighbour node ID; further output vectors are
+// filled according to `scanState.columnIDs`; `nodeIDVector` receives the bound
+// node ID of each emitted row. The read position is kept in `scanState`, so a
+// later call continues where this one stopped.
+//
+// @return true when at least one row was emitted, false otherwise.
+bool LanceRelTable::scanFlat(Transaction* /*transaction*/, LanceRelTableScanState& scanState) {
+    scanState.resetOutVectors();
+
+    if (scanState.boundNodeOffsets.empty() || !scanState.streamInitialized ||
+        scanState.streamExhausted) {
+        scanState.outState->getSelVectorUnsafe().setToFiltered(0);
+        return false;
+    }
+
+    const bool isFwd = scanState.direction != RelDataDirection::BWD;
+    uint64_t totalRowsCollected = 0;
+    const uint64_t maxRowsPerCall = DEFAULT_VECTOR_CAPACITY;
+
+    while (totalRowsCollected < maxRowsPerCall) {
+        // Load next batch if current one is exhausted
+        if (!scanState.cachedBatchData ||
+            scanState.currentLocalRowIdx >= scanState.cachedBatchData->length) {
+            // Account for the rows of the batch being dropped.
+            if (scanState.cachedBatchData) {
+                scanState.currentBatchStartOffset += scanState.cachedBatchData->length;
+            }
+            scanState.currentLocalRowIdx = 0;
+            scanState.cachedBatchData = nullptr;
+
+            // Read next batch from stream
+            auto newBatch = std::make_shared<LanceBatchData>();
+            int rc = scanState.stream.get_next(&scanState.stream, &newBatch->array);
+            // NOTE(review): a stream error (rc != 0) ends the scan like a clean
+            // end of stream.
+            if (rc != 0 || newBatch->array.release == nullptr) {
+                scanState.streamExhausted = true;
+                break;
+            }
+            newBatch->length = static_cast<uint64_t>(newBatch->array.length);
+            if (scanState.schemaFetched && scanState.streamSchema.format != nullptr) {
+                newBatch->schema = scanState.streamSchema;
+                newBatch->schema.release = nullptr; // schema owned by scanState.streamSchema
+            }
+            scanState.cachedBatchData = std::move(newBatch);
+        }
+
+        const auto& batch = *scanState.cachedBatchData;
+        // A zero-length batch is not end-of-stream: skip it and read the next
+        // one instead of ending the scan early.
+        if (batch.length == 0) continue;
+        if (!batch.array.children || !batch.schema.children) break;
+
+        const auto numChildren = static_cast<uint64_t>(batch.array.n_children);
+        if (fromColumnIdx < 0 || toColumnIdx < 0 ||
+            static_cast<uint64_t>(fromColumnIdx) >= numChildren ||
+            static_cast<uint64_t>(toColumnIdx) >= numChildren) {
+            break;
+        }
+
+        auto* fromArr = batch.array.children[fromColumnIdx];
+        auto* toArr = batch.array.children[toColumnIdx];
+        if (!fromArr || !toArr) break;
+
+        for (; scanState.currentLocalRowIdx < batch.length &&
+               totalRowsCollected < maxRowsPerCall;
+             ++scanState.currentLocalRowIdx) {
+            const auto localIdx = scanState.currentLocalRowIdx;
+
+            // Read from/to offsets from the arrow arrays
+            const auto fromOffset = readOffset(fromArr, localIdx);
+            const auto toOffset = readOffset(toArr, localIdx);
+            // A relationship with a null or unreadable endpoint cannot be
+            // joined to a node, so it is skipped.
+            if (fromOffset == INVALID_OFFSET || toOffset == INVALID_OFFSET) continue;
+            const auto boundOffset = isFwd ? fromOffset : toOffset;
+
+            auto boundIt = scanState.boundNodeOffsets.find(boundOffset);
+            if (boundIt == scanState.boundNodeOffsets.end()) continue;
+
+            const auto nbrOffset = isFwd ? toOffset : fromOffset;
+            const auto nbrTableID = isFwd ? getToNodeTableID() : getFromNodeTableID();
+            const auto globalRowIdx =
+                scanState.currentBatchStartOffset + scanState.currentLocalRowIdx;
+
+            // Fill output vectors
+            if (!scanState.outputVectors.empty()) {
+                scanState.outputVectors[0]->setValue(
+                    totalRowsCollected, internalID_t{nbrOffset, nbrTableID});
+            }
+
+            for (uint64_t outCol = 1; outCol < scanState.outputVectors.size(); ++outCol) {
+                if (outCol >= scanState.columnIDs.size()) continue;
+                const auto colID = scanState.columnIDs[outCol];
+                if (colID == INVALID_COLUMN_ID || colID == ROW_IDX_COLUMN_ID ||
+                    colID == NBR_ID_COLUMN_ID)
+                    continue;
+                if (colID == REL_ID_COLUMN_ID) {
+                    // The relationship ID is the row position in the stream.
+                    scanState.outputVectors[outCol]->setValue(
+                        totalRowsCollected, internalID_t{globalRowIdx, getTableID()});
+                    continue;
+                }
+                // Property column: map colID → lance column index as colID + 2.
+                // NOTE(review): an earlier comment here described the mapping as
+                // "colID - 2", which does not match the code. The mapping also
+                // assumes 'from'/'to' occupy the first two lance columns although
+                // the constructor locates them by name. Confirm the intended
+                // layout; resolving by schema name would be more robust.
+                const int64_t lanceColIdx =
+                    static_cast<int64_t>(colID) + 2; // +2 to skip from, to
+                if (lanceColIdx < 0 || static_cast<uint64_t>(lanceColIdx) >= numChildren)
+                    continue;
+                auto* propArr = batch.array.children[lanceColIdx];
+                auto* propSchema = batch.schema.children[lanceColIdx];
+                if (!propArr || !propSchema) continue;
+
+                // NOTE(review): the null mask is rebuilt for every emitted row;
+                // building it once per batch and column may be cheaper.
+                common::ArrowNullMaskTree nullMask(propSchema, propArr, propArr->offset, propArr->length);
+                common::ArrowConverter::fromArrowArray(propSchema, propArr,
+                    *scanState.outputVectors[outCol], &nullMask,
+                    static_cast<uint64_t>(propArr->offset) + localIdx, totalRowsCollected, 1);
+            }
+
+            // Assign the bound node ID for join-back
+            if (scanState.nodeIDVector) {
+                scanState.nodeIDVector->setValue(
+                    totalRowsCollected,
+                    internalID_t{boundOffset,
+                        isFwd ? getFromNodeTableID() : getToNodeTableID()});
+            }
+
+            ++totalRowsCollected;
+        }
+
+        // If we read the entire batch without filling the output buffer, continue to next batch.
+        if (scanState.currentLocalRowIdx >= batch.length) continue;
+        // Otherwise, we filled the buffer; return what we have.
+        break;
+    }
+
+    if (totalRowsCollected == 0) {
+        scanState.outState->getSelVectorUnsafe().setToFiltered(0);
+        return false;
+    }
+    scanState.outState->getSelVectorUnsafe().setSelSize(totalRowsCollected);
+    return true;
+}
+
+// Returns the cached row count, opening the dataset to count rows only when the
+// cache still holds INVALID_ROW_IDX. lance errors become common::RuntimeException.
+common::row_idx_t LanceRelTable::getTotalRowCount(
+    const Transaction* /*transaction*/) const {
+    if (cachedTotalRows != INVALID_ROW_IDX) return cachedTotalRows;
+    try {
+        auto dataset = lance::Dataset::open(datasetPath);
+        cachedTotalRows = dataset.count_rows();
+    } catch (const lance::Error& e) {
+        throw common::RuntimeException(
+            std::string("Failed to count rows in lance rel dataset '") + datasetPath + "': " +
+            e.what());
+    }
+    return cachedTotalRows;
+}
+
+// Returns the cached relationship row count (0 when unknown) as the estimate.
+common::row_idx_t LanceRelTable::getActiveBoundNodeCount(
+    const Transaction* /*transaction*/, RelDataDirection /*direction*/) const {
+    // Return estimate: assume each bound node has at least one relationship
+    return cachedTotalRows != INVALID_ROW_IDX ? cachedTotalRows : 0;
+}
+
+std::vector> LanceRelTable::getAllDegreeEntries(
+    const Transaction* /*transaction*/, RelDataDirection /*direction*/) const {
+    // Full degree computation for lance tables: would require scanning 'from'/'to' columns
+    // and counting occurrences. This is only called for stats; return empty for now.
+    return {};
+}
+
+// Not implemented for lance tables: always returns an empty list.
+std::vector> LanceRelTable::getTopKDegreeEntries(
+    const Transaction* /*transaction*/, RelDataDirection /*direction*/,
+    common::idx_t /*k*/) const {
+    return {};
+}
+
+} // namespace lance_extension
+} // namespace lbug
diff --git a/lance/test/CMakeLists.txt b/lance/test/CMakeLists.txt
new file mode 100644
index 0000000..2747192
--- /dev/null
+++ b/lance/test/CMakeLists.txt
@@ -0,0 +1,16 @@
+if (${BUILD_EXTENSION_TESTS})
+    # Lance extension tests require the lance extension to be built
+    add_lbug_test(lance_api_test
+        lance_node_table_test.cpp
+        lance_rel_table_test.cpp
+        lance_error_scenarios_test.cpp
+        lance_vector_search_test.cpp
+    )
+    target_link_libraries(lance_api_test PRIVATE lbug_${EXTENSION_LIB_NAME}_extension)
+    target_include_directories(lance_api_test PRIVATE
+        ${PROJECT_SOURCE_DIR}/src/include
+        ${CMAKE_BINARY_DIR}/src/include
+        ${PROJECT_SOURCE_DIR}/extension/lance/src/include
+        ${LANCE_C_ROOT}/include
+    )
+endif()
diff --git a/lance/test/lance_error_scenarios_test.cpp b/lance/test/lance_error_scenarios_test.cpp
new file mode 100644
index 0000000..f3503f3
--- /dev/null
+++ b/lance/test/lance_error_scenarios_test.cpp
@@ -0,0 +1,82 @@
+/// Error scenario tests for lance extension.
+/// Tests extension-not-loaded errors, mixed-format rejection,
+/// missing/malformed files, and schema validation errors.
+
+#include
+#include
+
+#include "graph_test/private_graph_test.h"
+#include "gtest/gtest.h"
+#include "main/query_result.h"
+
+using namespace lbug;
+using namespace lbug::testing;
+namespace fs = std::filesystem;
+
+// ─── Fixture ─────────────────────────────────────────────────────────────────
+
+// Fixture: an empty database plus an open connection for every test.
+class LanceErrorScenariosTest : public EmptyDBTest {
+protected:
+    void SetUp() override {
+        EmptyDBTest::SetUp();
+        createDBAndConn();
+    }
+};
+
+// ─── Tests ───────────────────────────────────────────────────────────────────
+
+// Passes when the statement succeeds; on failure only the wording of the error
+// message is checked.
+TEST_F(LanceErrorScenariosTest, ExtensionNotLoadedGivesActionableError) {
+    // If lance extension is not loaded, CREATE TABLE format='lance' should give a clear error
+    auto result = conn->query(
+        "CREATE NODE TABLE LanceNode (id INT64 PRIMARY KEY) "
+        "storage='/tmp/some.lance' format='lance'");
+    // Either extension is loaded (success) or gives actionable error, never a cryptic crash
+    if (!result->isSuccess()) {
+        auto errMsg = result->getErrorMessage();
+        // Must mention extension or format in error
+        ASSERT_TRUE(errMsg.find("lance") != std::string::npos ||
+                    errMsg.find("extension") != std::string::npos ||
+                    errMsg.find("format") != std::string::npos)
+            << "Error message not actionable: " << errMsg;
+    }
+}
+
+// Creating a table over a path that does not exist must fail with a message
+// that mentions "lance".
+TEST_F(LanceErrorScenariosTest, MissingFileHandledGracefully) {
+    // If extension is loaded, opening a non-existent lance file should fail gracefully
+#if defined(BUILD_DYNAMIC_LOAD)
+    auto extPath = TestHelper::appendLbugRootPath("extension/lance/build/liblance.lbug_extension");
+    auto loadResult = conn->query("LOAD EXTENSION '" + extPath + "'");
+    if (!loadResult->isSuccess()) GTEST_SKIP() << "Lance extension not available";
+#endif
+
+    auto result = conn->query(
+        "CREATE NODE TABLE Ghost (id INT64 PRIMARY KEY) "
+        "storage='/nonexistent/path.lance' format='lance'");
+    ASSERT_FALSE(result->isSuccess());
+    ASSERT_NE(result->getErrorMessage().find("lance"), std::string::npos);
+}
+
+// An unknown format name must be rejected with a recognisable message.
+TEST_F(LanceErrorScenariosTest,
InvalidFormatParameterRejected) {
+    auto result = conn->query(
+        "CREATE NODE TABLE Bad (id INT64 PRIMARY KEY) format='lanceXXX'");
+    ASSERT_FALSE(result->isSuccess());
+    auto errMsg = result->getErrorMessage();
+    // Any one of these substrings is accepted as a recognisable rejection.
+    ASSERT_TRUE(errMsg.find("lanceXXX") != std::string::npos ||
+                errMsg.find("format") != std::string::npos ||
+                errMsg.find("Invalid") != std::string::npos ||
+                errMsg.find("Unsupported") != std::string::npos)
+        << "Unexpected error: " << errMsg;
+}
+
+// Loading the extension twice must succeed both times. Skipped when the
+// extension cannot be loaded and in builds without BUILD_DYNAMIC_LOAD.
+TEST_F(LanceErrorScenariosTest, DuplicateLoadIsIdempotent) {
+#if defined(BUILD_DYNAMIC_LOAD)
+    auto extPath = TestHelper::appendLbugRootPath("extension/lance/build/liblance.lbug_extension");
+    auto r1 = conn->query("LOAD EXTENSION '" + extPath + "'");
+    if (!r1->isSuccess()) GTEST_SKIP() << "Lance extension not available";
+    // Duplicate load must not fail
+    auto r2 = conn->query("LOAD EXTENSION '" + extPath + "'");
+    ASSERT_TRUE(r2->isSuccess()) << r2->getErrorMessage();
+#else
+    GTEST_SKIP() << "Static build — duplicate load test not applicable";
+#endif
+}
diff --git a/lance/test/lance_node_table_test.cpp b/lance/test/lance_node_table_test.cpp
new file mode 100644
index 0000000..cd8eed5
--- /dev/null
+++ b/lance/test/lance_node_table_test.cpp
@@ -0,0 +1,238 @@
+/// Tests for LanceNodeTable.
+/// Covers the same scenarios as ArrowNodeTableTest: creation, type conversions,
+/// empty tables, large data, and cypher query parity.
+///
+/// Each test writes a lance dataset to a temp directory, creates a ladybug
+/// graph, runs queries, and compares results to expected values.
+
+#include
+#include
+#include
+#include
+#include
+
+#include "common/arrow/arrow.h"
+#include "graph_test/private_graph_test.h"
+#include "gtest/gtest.h"
+#include "lance/lance.hpp"
+#include "main/query_result.h"
+
+using namespace lbug;
+using namespace lbug::testing;
+namespace fs = std::filesystem;
+
+// ─── Helpers ─────────────────────────────────────────────────────────────────
+
+/// Write a simple Int32 + String lance dataset to `path`.
+/// The data is wrapped in a hand-built, single-batch Arrow stream with two
+/// columns, "id" (format "i") and "name" (format "u"), and passed to
+/// lance::Dataset::write in Create mode.
+/// NOTE(review): `strData` is indexed up to `intData.size()`; callers must pass
+/// vectors of equal length.
+static void writeLanceNodeDataset(const std::string& path,
+    const std::vector& intData, const std::vector& strData) {
+    const size_t n = intData.size();
+
+    // Build Arrow schema: struct
+    // The child schemas get no release callback of their own; they are freed by
+    // the parent's release lambda.
+    ArrowSchema schema; std::memset(&schema, 0, sizeof(schema));
+    schema.format = "+s"; schema.n_children = 2;
+    schema.children = new ArrowSchema*[2];
+    auto* idSchema = new ArrowSchema(); std::memset(idSchema, 0, sizeof(*idSchema));
+    idSchema->format = "i"; idSchema->name = "id"; schema.children[0] = idSchema;
+    auto* nameSchema = new ArrowSchema(); std::memset(nameSchema, 0, sizeof(*nameSchema));
+    nameSchema->format = "u"; nameSchema->name = "name"; schema.children[1] = nameSchema;
+    schema.release = [](ArrowSchema* s) { for (int i = 0; i < s->n_children; ++i) delete s->children[i]; delete[] s->children; s->release = nullptr; };
+
+    // Int32 column
+    // buffers[0] (validity) is null: no nulls; buffers[1] holds the values.
+    auto* idBuf = new int32_t[n]; for (size_t i = 0; i < n; ++i) idBuf[i] = intData[i];
+    auto* idArr = new ArrowArray(); std::memset(idArr, 0, sizeof(*idArr));
+    idArr->length = n; idArr->n_buffers = 2; idArr->buffers = new const void*[2];
+    idArr->buffers[0] = nullptr; idArr->buffers[1] = idBuf;
+    idArr->release = [](ArrowArray* a) { delete[] static_cast(a->buffers[1]); delete[] a->buffers; a->release = nullptr; };
+
+    // String column (utf8)
+    // buffers[1] holds n + 1 int32 offsets, buffers[2] the concatenated bytes.
+    uint32_t totalBytes = 0; for (auto& s : strData) totalBytes += s.size();
+    auto* offsets = new int32_t[n + 1]; auto* values = new char[totalBytes + 1];
+    offsets[0] = 0; size_t pos = 0;
+    for (size_t i = 0; i < n; ++i) { std::memcpy(values + pos, strData[i].c_str(), strData[i].size()); pos += strData[i].size(); offsets[i + 1] = static_cast(pos); }
+    auto* nameArr = new ArrowArray(); std::memset(nameArr, 0, sizeof(*nameArr));
+    nameArr->length = n; nameArr->n_buffers = 3; nameArr->buffers = new const void*[3];
+    nameArr->buffers[0] = nullptr; nameArr->buffers[1] = offsets; nameArr->buffers[2] = values;
+    nameArr->release = [](ArrowArray* a) { delete[] static_cast(a->buffers[1]); delete[] static_cast(a->buffers[2]); delete[] a->buffers; a->release = nullptr; };
+
+    // Struct parent
+    ArrowArray parent; std::memset(&parent, 0, sizeof(parent));
+    parent.length = n; parent.n_buffers = 1; parent.buffers = new const void*[1]; parent.buffers[0] = nullptr;
+    parent.n_children = 2; parent.children = new ArrowArray*[2];
+    parent.children[0] = idArr; parent.children[1] = nameArr;
+    parent.release = [](ArrowArray* a) { for (int i = 0; i < a->n_children; ++i) { if (a->children[i] && a->children[i]->release) a->children[i]->release(a->children[i]); delete a->children[i]; } delete[] a->children; delete[] a->buffers; a->release = nullptr; };
+
+    // Stream state: owns the schema and the single batch; `done` flips after
+    // the batch has been handed out once.
+    struct SS { ArrowSchema schema; ArrowArray array; bool done = false; };
+    auto* ss = new SS(); ss->schema = schema; schema.release = nullptr; ss->array = parent; parent.release = nullptr; ss->array.children = parent.children;
+    ArrowArrayStream stream; std::memset(&stream, 0, sizeof(stream)); stream.private_data = ss;
+    // NOTE(review): get_schema and get_next hand out shallow copies with
+    // release == nullptr. In the Arrow C data/stream interfaces a null release
+    // marks a struct as released, and from get_next it signals end of stream,
+    // so a conforming consumer may see an empty stream. Verify against how
+    // lance-c consumes the stream.
+    stream.get_schema = [](ArrowArrayStream* s, ArrowSchema* o) -> int { *o = static_cast(s->private_data)->schema; o->release = nullptr; return 0; };
+    stream.get_next = [](ArrowArrayStream* s, ArrowArray* o) -> int { auto* st = static_cast(s->private_data); if (st->done) { std::memset(o, 0, sizeof(*o)); return 0; } *o = st->array; o->release = nullptr; o->children = st->array.children; st->done = true; return 0; };
+    // All memory is freed here, so it is reclaimed only if the stream's release
+    // callback is invoked; this function never calls it itself.
+    stream.release = [](ArrowArrayStream* s) { auto* st = static_cast(s->private_data); if (st->schema.release) st->schema.release(&st->schema); for (int i = 0; i < st->array.n_children; ++i) { if (st->array.children[i] && st->array.children[i]->release) st->array.children[i]->release(st->array.children[i]); delete st->array.children[i]; } delete[] st->array.children; delete[] st->array.buffers; delete st; s->release = nullptr; };
+    lance::Dataset::write(path, &stream, lance::WriteMode::Create);
+}
+
+/// RAII temp directory
+/// Created under the system temp directory with a name derived from the
+/// steady-clock tick count; removed recursively on destruction.
+struct TempDir {
+    fs::path path;
+    TempDir() : path(fs::temp_directory_path() / ("lance_test_" + std::to_string(
+        std::chrono::steady_clock::now().time_since_epoch().count()))) { fs::create_directories(path); }
+    ~TempDir() { fs::remove_all(path); }
+    std::string str() const { return path.string(); }
+};
+
+// ─── Test fixture ─────────────────────────────────────────────────────────────
+
+// Fixture: empty database plus connection. When BUILD_DYNAMIC_LOAD is defined
+// the extension is loaded first and the test is skipped if loading fails.
+class LanceNodeTableTest : public EmptyDBTest {
+protected:
+    void SetUp() override {
+        EmptyDBTest::SetUp();
+        createDBAndConn();
+
+#if defined(BUILD_DYNAMIC_LOAD)
+        auto extPath = TestHelper::appendLbugRootPath(
+            "extension/lance/build/liblance.lbug_extension");
+        auto result = conn->query("LOAD EXTENSION '" + extPath + "'");
+        if (!result->isSuccess()) {
+            GTEST_SKIP() << "Lance extension not found at " << extPath
+                         << " — skipping: " << result->getErrorMessage();
+        }
+#endif
+    }
+};
+
+// ─── Tests ───────────────────────────────────────────────────────────────────
+
+// Five rows written to lance must come back through a MATCH query, ordered by id.
+TEST_F(LanceNodeTableTest, CreateLanceNodeTableFromVectors) {
+    TempDir tmp;
+    auto datasetPath = (tmp.path / "nodes.lance").string();
+    writeLanceNodeDataset(datasetPath, {1, 2, 3, 4, 5}, {"a", "b", "c", "d", "e"});
+
+    ASSERT_TRUE(conn->query("CREATE NODE TABLE Person (id INT32 PRIMARY KEY, name STRING) "
+                            "storage='" + datasetPath + "' format='lance'")
+                    ->isSuccess());
+
+    auto result = conn->query("MATCH (p:Person) RETURN p.id, p.name ORDER BY p.id");
+    ASSERT_TRUE(result->isSuccess()) << result->getErrorMessage();
+
+    auto rows = TestHelper::convertResultToString(*result);
+    ASSERT_EQ(rows.size(), 5u);
+    ASSERT_EQ(rows[0], "1|a");
+
ASSERT_EQ(rows[4], "5|e"); +} + +TEST_F(LanceNodeTableTest, LanceTableCountRows) { + TempDir tmp; + auto datasetPath = (tmp.path / "nodes.lance").string(); + std::vector ids; + std::vector names; + for (int i = 0; i < 100; ++i) { ids.push_back(i); names.push_back("n" + std::to_string(i)); } + writeLanceNodeDataset(datasetPath, ids, names); + + ASSERT_TRUE(conn->query("CREATE NODE TABLE Big (id INT32 PRIMARY KEY, name STRING) " + "storage='" + datasetPath + "' format='lance'")->isSuccess()); + + auto result = conn->query("MATCH (n:Big) RETURN COUNT(*)"); + ASSERT_TRUE(result->isSuccess()) << result->getErrorMessage(); + auto rows = TestHelper::convertResultToString(*result); + ASSERT_EQ(rows.size(), 1u); + ASSERT_EQ(rows[0], "100"); +} + +TEST_F(LanceNodeTableTest, LanceTableWithFilter) { + TempDir tmp; + auto datasetPath = (tmp.path / "filtered.lance").string(); + writeLanceNodeDataset(datasetPath, {10, 20, 30, 40, 50}, {"x", "y", "z", "w", "v"}); + + ASSERT_TRUE(conn->query("CREATE NODE TABLE Item (id INT32 PRIMARY KEY, name STRING) " + "storage='" + datasetPath + "' format='lance'")->isSuccess()); + + auto result = conn->query("MATCH (i:Item) WHERE i.id > 20 RETURN i.id ORDER BY i.id"); + ASSERT_TRUE(result->isSuccess()) << result->getErrorMessage(); + auto rows = TestHelper::convertResultToString(*result); + ASSERT_EQ(rows.size(), 3u); + ASSERT_EQ(rows[0], "30"); + ASSERT_EQ(rows[2], "50"); +} + +TEST_F(LanceNodeTableTest, LanceTableAggregation) { + TempDir tmp; + auto datasetPath = (tmp.path / "agg.lance").string(); + writeLanceNodeDataset(datasetPath, {1, 2, 3, 4, 5}, {"a", "b", "c", "d", "e"}); + + ASSERT_TRUE(conn->query("CREATE NODE TABLE Numbers (id INT32 PRIMARY KEY, name STRING) " + "storage='" + datasetPath + "' format='lance'")->isSuccess()); + + auto result = conn->query("MATCH (n:Numbers) RETURN SUM(n.id)"); + ASSERT_TRUE(result->isSuccess()) << result->getErrorMessage(); + auto rows = TestHelper::convertResultToString(*result); + 
ASSERT_EQ(rows.size(), 1u); + ASSERT_EQ(rows[0], "15"); +} + +TEST_F(LanceNodeTableTest, LanceTableOrderByLimit) { + TempDir tmp; + auto datasetPath = (tmp.path / "order.lance").string(); + writeLanceNodeDataset(datasetPath, {5, 3, 1, 4, 2}, {"e", "c", "a", "d", "b"}); + + ASSERT_TRUE(conn->query("CREATE NODE TABLE Sorted (id INT32 PRIMARY KEY, name STRING) " + "storage='" + datasetPath + "' format='lance'")->isSuccess()); + + auto result = conn->query("MATCH (s:Sorted) RETURN s.id ORDER BY s.id LIMIT 3"); + ASSERT_TRUE(result->isSuccess()) << result->getErrorMessage(); + auto rows = TestHelper::convertResultToString(*result); + ASSERT_EQ(rows.size(), 3u); + ASSERT_EQ(rows[0], "1"); + ASSERT_EQ(rows[1], "2"); + ASSERT_EQ(rows[2], "3"); +} + +TEST_F(LanceNodeTableTest, LanceTableDistinct) { + TempDir tmp; + auto datasetPath = (tmp.path / "distinct.lance").string(); + writeLanceNodeDataset(datasetPath, {1, 2, 3, 4, 5}, {"a", "a", "b", "b", "c"}); + + ASSERT_TRUE(conn->query("CREATE NODE TABLE D (id INT32 PRIMARY KEY, cat STRING) " + "storage='" + datasetPath + "' format='lance'")->isSuccess()); + + auto result = conn->query("MATCH (d:D) RETURN DISTINCT d.cat ORDER BY d.cat"); + ASSERT_TRUE(result->isSuccess()) << result->getErrorMessage(); + auto rows = TestHelper::convertResultToString(*result); + ASSERT_EQ(rows.size(), 3u); + ASSERT_EQ(rows[0], "a"); + ASSERT_EQ(rows[1], "b"); + ASSERT_EQ(rows[2], "c"); +} + +TEST_F(LanceNodeTableTest, LanceTableLargeData) { + TempDir tmp; + auto datasetPath = (tmp.path / "large.lance").string(); + const size_t N = 10000; + std::vector ids(N); + std::vector names(N); + for (size_t i = 0; i < N; ++i) { ids[i] = static_cast(i); names[i] = "item_" + std::to_string(i); } + writeLanceNodeDataset(datasetPath, ids, names); + + ASSERT_TRUE(conn->query("CREATE NODE TABLE Large (id INT32 PRIMARY KEY, name STRING) " + "storage='" + datasetPath + "' format='lance'")->isSuccess()); + + auto result = conn->query("MATCH (n:Large) RETURN 
COUNT(*)"); + ASSERT_TRUE(result->isSuccess()) << result->getErrorMessage(); + auto rows = TestHelper::convertResultToString(*result); + ASSERT_EQ(rows[0], std::to_string(N)); +} + +TEST_F(LanceNodeTableTest, MixedFormatRejected) { + TempDir tmp; + auto lanceDataset = (tmp.path / "mixed_nodes.lance").string(); + writeLanceNodeDataset(lanceDataset, {1, 2}, {"a", "b"}); + + ASSERT_TRUE(conn->query("CREATE NODE TABLE LNode (id INT32 PRIMARY KEY, name STRING) " + "storage='" + lanceDataset + "' format='lance'")->isSuccess()); + ASSERT_TRUE(conn->query("CREATE NODE TABLE RegNode (id INT64 PRIMARY KEY)")->isSuccess()); + + // Attempt to create a lance rel table between lance and regular node tables — should fail + auto result = conn->query("CREATE REL TABLE BadRel (FROM LNode TO RegNode) " + "storage='some.lance' format='lance'"); + ASSERT_FALSE(result->isSuccess()); + ASSERT_NE(result->getErrorMessage().find("Cannot mix lance"), std::string::npos); +} diff --git a/lance/test/lance_rel_table_test.cpp b/lance/test/lance_rel_table_test.cpp new file mode 100644 index 0000000..716bafe --- /dev/null +++ b/lance/test/lance_rel_table_test.cpp @@ -0,0 +1,241 @@ +/// Tests for LanceRelTable. +/// Covers forward/backward scans, filtering, aggregation, and join patterns. + +#include +#include +#include +#include + +#include "common/arrow/arrow.h" +#include "graph_test/private_graph_test.h" +#include "gtest/gtest.h" +#include "lance/lance.hpp" +#include "main/query_result.h" + +using namespace lbug; +using namespace lbug::testing; +namespace fs = std::filesystem; + +// ─── Helpers (shared with node table test via inline) ──────────────────────── + +/// Write a lance dataset with columns [from: int64, to: int64, weight: float] +static void writeLanceRelDataset(const std::string& path, + const std::vector& from, const std::vector& to, + const std::vector& weight = {}) { + const size_t n = from.size(); + + // Build schema + int numCols = weight.empty() ? 
2 : 3; + ArrowSchema schema; + std::memset(&schema, 0, sizeof(schema)); + schema.format = "+s"; schema.name = ""; + schema.n_children = numCols; + schema.children = new ArrowSchema*[numCols]; + + auto makeField = [](const char* name, const char* fmt) { + auto* s = new ArrowSchema(); + std::memset(s, 0, sizeof(*s)); + s->format = fmt; s->name = name; return s; + }; + schema.children[0] = makeField("from", "l"); // int64 + schema.children[1] = makeField("to", "l"); // int64 + if (!weight.empty()) schema.children[2] = makeField("weight", "f"); // float32 + schema.release = [](ArrowSchema* s) { + for (int i = 0; i < s->n_children; ++i) delete s->children[i]; + delete[] s->children; s->release = nullptr; }; + + auto makeInt64Array = [&](const std::vector& data) { + auto* buf = new int64_t[data.size()]; + for (size_t i = 0; i < data.size(); ++i) buf[i] = data[i]; + auto* a = new ArrowArray(); std::memset(a, 0, sizeof(*a)); + a->length = data.size(); a->n_buffers = 2; + a->buffers = new const void*[2]; a->buffers[0] = nullptr; a->buffers[1] = buf; + a->release = [](ArrowArray* arr) { delete[] static_cast(arr->buffers[1]); + delete[] arr->buffers; arr->release = nullptr; }; + return a; + }; + auto makeFloatArray = [&](const std::vector& data) { + auto* buf = new float[data.size()]; + for (size_t i = 0; i < data.size(); ++i) buf[i] = data[i]; + auto* a = new ArrowArray(); std::memset(a, 0, sizeof(*a)); + a->length = data.size(); a->n_buffers = 2; + a->buffers = new const void*[2]; a->buffers[0] = nullptr; a->buffers[1] = buf; + a->release = [](ArrowArray* arr) { delete[] static_cast(arr->buffers[1]); + delete[] arr->buffers; arr->release = nullptr; }; + return a; + }; + + ArrowArray parent; std::memset(&parent, 0, sizeof(parent)); + parent.length = n; parent.n_buffers = 1; + parent.buffers = new const void*[1]; parent.buffers[0] = nullptr; + parent.n_children = numCols; + parent.children = new ArrowArray*[numCols]; + parent.children[0] = makeInt64Array(from); + 
parent.children[1] = makeInt64Array(to); + if (!weight.empty()) parent.children[2] = makeFloatArray(weight); + parent.release = [](ArrowArray* a) { + for (int i = 0; i < a->n_children; ++i) { if (a->children[i] && a->children[i]->release) a->children[i]->release(a->children[i]); delete a->children[i]; } + delete[] a->children; delete[] a->buffers; a->release = nullptr; }; + + /* Single-batch stream state: holds the schema and the one record batch until a consumer takes the batch. */ struct SS { ArrowSchema schema; ArrowArray array; bool done = false; }; + auto* ss = new SS(); ss->schema = schema; schema.release = nullptr; ss->array = parent; parent.release = nullptr; /* ss->array keeps parent's release callback; the locals are marked as moved-from */ + ss->array.children = parent.children; /* already set */ + + ArrowArrayStream stream; std::memset(&stream, 0, sizeof(stream)); + stream.private_data = ss; + /* get_schema hands out a shallow view whose memory the stream still owns (valid only while the stream is alive). A null release marks a struct as already released, so the view gets a callback that merely flags it released. */ stream.get_schema = [](ArrowArrayStream* s, ArrowSchema* out) -> int { *out = static_cast<SS*>(s->private_data)->schema; out->release = [](ArrowSchema* view) { view->release = nullptr; }; return 0; }; + stream.get_next = [](ArrowArrayStream* s, ArrowArray* out) -> int { + auto* st = static_cast<SS*>(s->private_data); + if (st->done) { std::memset(out, 0, sizeof(*out)); return 0; } /* zeroed out (release == nullptr) signals end of stream */ + /* Move the batch to the consumer together with its release callback; returning it with release == nullptr would be read as end-of-stream. Zeroing the stream's copy makes the cleanup in stream.release a no-op for it. */ *out = st->array; std::memset(&st->array, 0, sizeof(st->array)); st->done = true; return 0; }; + stream.release = [](ArrowArrayStream* s) { + auto* st = static_cast<SS*>(s->private_data); + if (st->schema.release) st->schema.release(&st->schema); + /* Free array children; only does work when the batch was never handed out by get_next */ + for (int i = 0; i < st->array.n_children; ++i) { + if (st->array.children[i] && st->array.children[i]->release) st->array.children[i]->release(st->array.children[i]); + delete st->array.children[i]; + } + delete[] st->array.children; delete[] st->array.buffers; delete st; s->release = nullptr; }; + + lance::Dataset::write(path, &stream, lance::WriteMode::Create); +} + +static void writeLanceNodeDatasetInt64(const std::string& path, + const std::vector<int64_t>& ids, const std::vector<std::string>& names) { + const size_t n = ids.size(); + // Same as node table helper but using int64 for id + auto makeField = [](const char* name, const char* fmt) { + auto* s = new 
ArrowSchema(); std::memset(s, 0, sizeof(*s)); s->format = fmt; s->name = name; return s; }; + + ArrowSchema schema; std::memset(&schema, 0, sizeof(schema)); + schema.format = "+s"; schema.n_children = 2; + schema.children = new ArrowSchema*[2]; + schema.children[0] = makeField("id", "l"); schema.children[1] = makeField("name", "u"); + schema.release = [](ArrowSchema* s) { for (int i = 0; i < s->n_children; ++i) delete s->children[i]; delete[] s->children; s->release = nullptr; }; + + auto* idBuf = new int64_t[n]; for (size_t i = 0; i < n; ++i) idBuf[i] = ids[i]; + auto* idArr = new ArrowArray(); std::memset(idArr, 0, sizeof(*idArr)); + idArr->length = n; idArr->n_buffers = 2; idArr->buffers = new const void*[2]; + idArr->buffers[0] = nullptr; idArr->buffers[1] = idBuf; + idArr->release = [](ArrowArray* a) { delete[] static_cast(a->buffers[1]); delete[] a->buffers; a->release = nullptr; }; + + uint32_t totalBytes = 0; for (auto& s : names) totalBytes += s.size(); + auto* offsets = new int32_t[n + 1]; auto* values = new char[totalBytes + 1]; + offsets[0] = 0; size_t pos = 0; + for (size_t i = 0; i < n; ++i) { std::memcpy(values + pos, names[i].c_str(), names[i].size()); pos += names[i].size(); offsets[i + 1] = static_cast(pos); } + auto* nameArr = new ArrowArray(); std::memset(nameArr, 0, sizeof(*nameArr)); + nameArr->length = n; nameArr->n_buffers = 3; nameArr->buffers = new const void*[3]; + nameArr->buffers[0] = nullptr; nameArr->buffers[1] = offsets; nameArr->buffers[2] = values; + nameArr->release = [](ArrowArray* a) { delete[] static_cast(a->buffers[1]); delete[] static_cast(a->buffers[2]); delete[] a->buffers; a->release = nullptr; }; + + ArrowArray parent; std::memset(&parent, 0, sizeof(parent)); + parent.length = n; parent.n_buffers = 1; parent.buffers = new const void*[1]; parent.buffers[0] = nullptr; + parent.n_children = 2; parent.children = new ArrowArray*[2]; + parent.children[0] = idArr; parent.children[1] = nameArr; + parent.release = [](ArrowArray* 
a) { for (int i = 0; i < a->n_children; ++i) { if (a->children[i] && a->children[i]->release) a->children[i]->release(a->children[i]); delete a->children[i]; } delete[] a->children; delete[] a->buffers; a->release = nullptr; }; + + /* Single-batch stream state: holds the schema and the one record batch until a consumer takes the batch. */ struct SS { ArrowSchema schema; ArrowArray array; bool done = false; }; + auto* ss = new SS(); ss->schema = schema; schema.release = nullptr; ss->array = parent; parent.release = nullptr; ss->array.children = parent.children; /* ss->array keeps parent's release callback; the locals are marked as moved-from */ + ArrowArrayStream stream; std::memset(&stream, 0, sizeof(stream)); stream.private_data = ss; + /* The schema is handed out as a shallow view still owned by the stream (valid only while the stream is alive); its release must be non-null because a null release marks a struct as already released. */ stream.get_schema = [](ArrowArrayStream* s, ArrowSchema* o) -> int { *o = static_cast<SS*>(s->private_data)->schema; o->release = [](ArrowSchema* view) { view->release = nullptr; }; return 0; }; + /* The batch is moved to the consumer with its release callback (a chunk with release == nullptr would be read as end-of-stream); zeroing the stream's copy makes the cleanup in stream.release a no-op for it. */ stream.get_next = [](ArrowArrayStream* s, ArrowArray* o) -> int { auto* st = static_cast<SS*>(s->private_data); if (st->done) { std::memset(o, 0, sizeof(*o)); return 0; } *o = st->array; std::memset(&st->array, 0, sizeof(st->array)); st->done = true; return 0; }; + /* Frees the schema and, only when the batch was never handed out, the batch itself. */ stream.release = [](ArrowArrayStream* s) { auto* st = static_cast<SS*>(s->private_data); if (st->schema.release) st->schema.release(&st->schema); for (int i = 0; i < st->array.n_children; ++i) { if (st->array.children[i] && st->array.children[i]->release) st->array.children[i]->release(st->array.children[i]); delete st->array.children[i]; } delete[] st->array.children; delete[] st->array.buffers; delete st; s->release = nullptr; }; + lance::Dataset::write(path, &stream, lance::WriteMode::Create); +} + +struct TempDir2 { + fs::path path; + TempDir2() : path(fs::temp_directory_path() / ("lance_rel_" + std::to_string(std::chrono::steady_clock::now().time_since_epoch().count()))) { fs::create_directories(path); } + ~TempDir2() { fs::remove_all(path); } + std::string str() const { return path.string(); } +}; + +// ─── Fixture ───────────────────────────────────────────────────────────────── + +class LanceRelTableTest : public EmptyDBTest { +protected: + void SetUp() override { + EmptyDBTest::SetUp(); + createDBAndConn(); 
+#if defined(BUILD_DYNAMIC_LOAD) + auto extPath = TestHelper::appendLbugRootPath("extension/lance/build/liblance.lbug_extension"); + auto result = conn->query("LOAD EXTENSION '" + extPath + "'"); + if (!result->isSuccess()) GTEST_SKIP() << "Lance extension not available: " << result->getErrorMessage(); +#endif + } +}; + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +TEST_F(LanceRelTableTest, FwdScanSimple) { + TempDir2 tmp; + auto nodePath = (tmp.path / "users.lance").string(); + auto relPath = (tmp.path / "follows.lance").string(); + writeLanceNodeDatasetInt64(nodePath, {0, 1, 2, 3}, {"alice", "bob", "carol", "dan"}); + writeLanceRelDataset(relPath, {0, 0, 1}, {1, 2, 3}, {1.0f, 2.0f, 3.0f}); + + ASSERT_TRUE(conn->query("CREATE NODE TABLE User (id INT64 PRIMARY KEY, name STRING) storage='" + nodePath + "' format='lance'")->isSuccess()); + ASSERT_TRUE(conn->query("CREATE REL TABLE Follows (FROM User TO User, weight FLOAT) storage='" + relPath + "' format='lance'")->isSuccess()); + + auto result = conn->query("MATCH (a:User)-[f:Follows]->(b:User) RETURN a.name, b.name ORDER BY a.name, b.name"); + ASSERT_TRUE(result->isSuccess()); + auto rows = TestHelper::convertResultToString(*result); + ASSERT_EQ(rows.size(), 3u); + ASSERT_EQ(rows[0], "alice|bob"); + ASSERT_EQ(rows[1], "alice|carol"); + ASSERT_EQ(rows[2], "bob|dan"); +} + +TEST_F(LanceRelTableTest, BwdScan) { + TempDir2 tmp; + auto nodePath = (tmp.path / "users.lance").string(); + auto relPath = (tmp.path / "follows.lance").string(); + writeLanceNodeDatasetInt64(nodePath, {0, 1, 2}, {"alice", "bob", "carol"}); + writeLanceRelDataset(relPath, {0, 1}, {1, 2}); + + ASSERT_TRUE(conn->query("CREATE NODE TABLE U2 (id INT64 PRIMARY KEY, name STRING) storage='" + nodePath + "' format='lance'")->isSuccess()); + ASSERT_TRUE(conn->query("CREATE REL TABLE F2 (FROM U2 TO U2) storage='" + relPath + "' format='lance'")->isSuccess()); + + // Backward: who follows carol? 
+ auto result = conn->query("MATCH (a:U2)<-[:F2]-(b:U2) WHERE a.name = 'carol' RETURN b.name"); + ASSERT_TRUE(result->isSuccess()); + auto rows = TestHelper::convertResultToString(*result); + ASSERT_EQ(rows.size(), 1u); + ASSERT_EQ(rows[0], "bob"); +} + +TEST_F(LanceRelTableTest, RelPropertyQuery) { + TempDir2 tmp; + auto nodePath = (tmp.path / "items.lance").string(); + auto relPath = (tmp.path / "edges.lance").string(); + writeLanceNodeDatasetInt64(nodePath, {0, 1, 2}, {"A", "B", "C"}); + writeLanceRelDataset(relPath, {0, 1}, {1, 2}, {10.5f, 20.5f}); + + ASSERT_TRUE(conn->query("CREATE NODE TABLE Item (id INT64 PRIMARY KEY, name STRING) storage='" + nodePath + "' format='lance'")->isSuccess()); + ASSERT_TRUE(conn->query("CREATE REL TABLE Edge (FROM Item TO Item, weight FLOAT) storage='" + relPath + "' format='lance'")->isSuccess()); + + auto result = conn->query("MATCH (a:Item)-[e:Edge]->(b:Item) RETURN e.weight ORDER BY e.weight"); + ASSERT_TRUE(result->isSuccess()); + auto rows = TestHelper::convertResultToString(*result); + ASSERT_EQ(rows.size(), 2u); +} + +TEST_F(LanceRelTableTest, RelAggregation) { + TempDir2 tmp; + auto nodePath = (tmp.path / "a.lance").string(); + auto relPath = (tmp.path / "b.lance").string(); + writeLanceNodeDatasetInt64(nodePath, {0, 1, 2, 3}, {"x", "y", "z", "w"}); + writeLanceRelDataset(relPath, {0, 0, 1, 1}, {1, 2, 2, 3}, {1.0f, 2.0f, 3.0f, 4.0f}); + + ASSERT_TRUE(conn->query("CREATE NODE TABLE N3 (id INT64 PRIMARY KEY, name STRING) storage='" + nodePath + "' format='lance'")->isSuccess()); + ASSERT_TRUE(conn->query("CREATE REL TABLE R3 (FROM N3 TO N3, w FLOAT) storage='" + relPath + "' format='lance'")->isSuccess()); + + auto result = conn->query("MATCH (a:N3)-[r:R3]->(b:N3) RETURN a.name, COUNT(*) ORDER BY a.name"); + ASSERT_TRUE(result->isSuccess()); + auto rows = TestHelper::convertResultToString(*result); + ASSERT_EQ(rows.size(), 2u); + // x has 2 outgoing, y has 2 outgoing + ASSERT_EQ(rows[0], "x|2"); + ASSERT_EQ(rows[1], 
"y|2"); +} diff --git a/lance/test/lance_vector_search_test.cpp b/lance/test/lance_vector_search_test.cpp new file mode 100644 index 0000000..ebe2a8d --- /dev/null +++ b/lance/test/lance_vector_search_test.cpp @@ -0,0 +1,158 @@ +/// Tests for lance vector search, FTS, and hybrid search functions. + +#include +#include +#include +#include + +#include "common/arrow/arrow.h" +#include "graph_test/private_graph_test.h" +#include "gtest/gtest.h" +#include "lance/lance.hpp" +#include "main/query_result.h" + +using namespace lbug; +using namespace lbug::testing; +namespace fs = std::filesystem; + +struct TempDirVec { + fs::path path; + TempDirVec() : path(fs::temp_directory_path() / ("lance_vec_" + std::to_string( + std::chrono::steady_clock::now().time_since_epoch().count()))) { fs::create_directories(path); } + ~TempDirVec() { fs::remove_all(path); } + std::string str() const { return path.string(); } +}; + +/// Write a lance dataset with a fixed-size vector column [id: int64, vec: fixed_size_list[4]] +static void writeLanceVectorDataset(const std::string& path, + const std::vector& ids, const std::vector>& vecs) { + const size_t n = ids.size(); + const int dim = vecs.empty() ? 
0 : static_cast(vecs[0].size()); + + // Schema + ArrowSchema schema; std::memset(&schema, 0, sizeof(schema)); + schema.format = "+s"; schema.n_children = 2; + schema.children = new ArrowSchema*[2]; + + auto* idF = new ArrowSchema(); std::memset(idF, 0, sizeof(*idF)); idF->format = "l"; idF->name = "id"; + schema.children[0] = idF; + + char dimFmt[32]; std::snprintf(dimFmt, sizeof(dimFmt), "+w:%d", dim); + auto* vecF = new ArrowSchema(); std::memset(vecF, 0, sizeof(*vecF)); + vecF->format = dimFmt; vecF->name = "vec"; vecF->n_children = 1; + vecF->children = new ArrowSchema*[1]; + auto* elemF = new ArrowSchema(); std::memset(elemF, 0, sizeof(*elemF)); elemF->format = "f"; elemF->name = "item"; + vecF->children[0] = elemF; + vecF->release = [](ArrowSchema* s) { delete s->children[0]; delete[] s->children; s->release = nullptr; }; + schema.children[1] = vecF; + schema.release = [](ArrowSchema* s) { + delete s->children[0]; + if (s->children[1]->release) s->children[1]->release(s->children[1]); + delete s->children[1]; + delete[] s->children; s->release = nullptr; }; + + // Id array + auto* idBuf = new int64_t[n]; for (size_t i = 0; i < n; ++i) idBuf[i] = ids[i]; + auto* idArr = new ArrowArray(); std::memset(idArr, 0, sizeof(*idArr)); + idArr->length = n; idArr->n_buffers = 2; idArr->buffers = new const void*[2]; + idArr->buffers[0] = nullptr; idArr->buffers[1] = idBuf; + idArr->release = [](ArrowArray* a) { delete[] static_cast(a->buffers[1]); delete[] a->buffers; a->release = nullptr; }; + + // Vec array (fixed_size_list wrapping flat float child) + auto* elemBuf = new float[n * dim]; + for (size_t i = 0; i < n; ++i) for (int j = 0; j < dim; ++j) elemBuf[i * dim + j] = vecs[i][j]; + auto* elemArr = new ArrowArray(); std::memset(elemArr, 0, sizeof(*elemArr)); + elemArr->length = static_cast(n * dim); elemArr->n_buffers = 2; elemArr->buffers = new const void*[2]; + elemArr->buffers[0] = nullptr; elemArr->buffers[1] = elemBuf; + elemArr->release = [](ArrowArray* a) { 
delete[] static_cast<const float*>(a->buffers[1]); delete[] a->buffers; a->release = nullptr; }; + + auto* vecArr = new ArrowArray(); std::memset(vecArr, 0, sizeof(*vecArr)); + vecArr->length = n; vecArr->n_buffers = 1; vecArr->buffers = new const void*[1]; vecArr->buffers[0] = nullptr; + vecArr->n_children = 1; vecArr->children = new ArrowArray*[1]; vecArr->children[0] = elemArr; + vecArr->release = [](ArrowArray* a) { + if (a->children[0] && a->children[0]->release) a->children[0]->release(a->children[0]); delete a->children[0]; + delete[] a->children; delete[] a->buffers; a->release = nullptr; }; + + ArrowArray parent; std::memset(&parent, 0, sizeof(parent)); + parent.length = n; parent.n_buffers = 1; parent.buffers = new const void*[1]; parent.buffers[0] = nullptr; + parent.n_children = 2; parent.children = new ArrowArray*[2]; + parent.children[0] = idArr; parent.children[1] = vecArr; + parent.release = [](ArrowArray* a) { + for (int i = 0; i < a->n_children; ++i) { if (a->children[i] && a->children[i]->release) a->children[i]->release(a->children[i]); delete a->children[i]; } + delete[] a->children; delete[] a->buffers; a->release = nullptr; }; + + /* Single-batch stream state: holds the schema and the one record batch until a consumer takes the batch. */ struct SS { ArrowSchema sc; ArrowArray ar; bool done = false; }; + auto* ss = new SS(); ss->sc = schema; schema.release = nullptr; ss->ar = parent; parent.release = nullptr; ss->ar.children = parent.children; /* ss->ar keeps parent's release callback; the locals are marked as moved-from */ + ArrowArrayStream stream; std::memset(&stream, 0, sizeof(stream)); stream.private_data = ss; + /* The schema is handed out as a shallow view still owned by the stream (valid only while the stream is alive); its release must be non-null because a null release marks a struct as already released. */ stream.get_schema = [](ArrowArrayStream* s, ArrowSchema* o) -> int { *o = static_cast<SS*>(s->private_data)->sc; o->release = [](ArrowSchema* view) { view->release = nullptr; }; return 0; }; + stream.get_next = [](ArrowArrayStream* s, ArrowArray* o) -> int { + auto* st = static_cast<SS*>(s->private_data); + if (st->done) { std::memset(o, 0, sizeof(*o)); return 0; } /* zeroed out (release == nullptr) signals end of stream */ + /* Move the batch to the consumer together with its release callback; returning it with release == nullptr would be read as end-of-stream. Zeroing the stream's copy makes the child/buffer cleanup in stream.release a no-op for it. */ *o = st->ar; std::memset(&st->ar, 0, sizeof(st->ar)); st->done = true; return 0; }; + stream.release = [](ArrowArrayStream* s) { + auto* st = static_cast<SS*>(s->private_data); + if (st->sc.release) 
st->sc.release(&st->sc); + for (int i = 0; i < st->ar.n_children; ++i) { if (st->ar.children[i] && st->ar.children[i]->release) st->ar.children[i]->release(st->ar.children[i]); delete st->ar.children[i]; } + delete[] st->ar.children; delete[] st->ar.buffers; delete st; s->release = nullptr; }; + + lance::Dataset::write(path, &stream, lance::WriteMode::Create); +} + +// ─── Fixture ───────────────────────────────────────────────────────────────── + +class LanceVectorSearchTest : public EmptyDBTest { +protected: + void SetUp() override { + EmptyDBTest::SetUp(); + createDBAndConn(); +#if defined(BUILD_DYNAMIC_LOAD) + auto extPath = TestHelper::appendLbugRootPath("extension/lance/build/liblance.lbug_extension"); + auto result = conn->query("LOAD EXTENSION '" + extPath + "'"); + if (!result->isSuccess()) GTEST_SKIP() << "Lance extension not available: " << result->getErrorMessage(); +#endif + } +}; + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +TEST_F(LanceVectorSearchTest, BasicVectorSearch) { + TempDirVec tmp; + auto dspath = (tmp.path / "vecs.lance").string(); + + writeLanceVectorDataset(dspath, + {0, 1, 2, 3}, + {{1.0f, 0.0f, 0.0f, 0.0f}, + {0.0f, 1.0f, 0.0f, 0.0f}, + {0.0f, 0.0f, 1.0f, 0.0f}, + {0.0f, 0.0f, 0.0f, 1.0f}}); + + // Query with vector closest to [1, 0, 0, 0] — should return id=0 + auto result = conn->query( + "CALL LANCE_VECTOR_SEARCH('" + dspath + "', 'vec', [1.0, 0.0, 0.0, 0.0], 1) " + "RETURN id"); + ASSERT_TRUE(result->isSuccess()) << result->getErrorMessage(); + auto rows = TestHelper::convertResultToString(*result); + ASSERT_EQ(rows.size(), 1u); + ASSERT_EQ(rows[0], "0"); +} + +TEST_F(LanceVectorSearchTest, VectorSearchTopK) { + TempDirVec tmp; + auto dspath = (tmp.path / "topk.lance").string(); + + writeLanceVectorDataset(dspath, + {0, 1, 2}, + {{1.0f, 0.0f, 0.0f, 0.0f}, + {0.9f, 0.1f, 0.0f, 0.0f}, + {0.0f, 0.0f, 1.0f, 0.0f}}); + + auto result = conn->query( + "CALL LANCE_VECTOR_SEARCH('" + dspath + "', 
'vec', [1.0, 0.0, 0.0, 0.0], 2) " + "RETURN id ORDER BY id"); + ASSERT_TRUE(result->isSuccess()) << result->getErrorMessage(); + auto rows = TestHelper::convertResultToString(*result); + ASSERT_EQ(rows.size(), 2u); + // Top-2 nearest to [1,0,0,0] should be 0 and 1 + ASSERT_NE(std::find(rows.begin(), rows.end(), "0"), rows.end()); + ASSERT_NE(std::find(rows.begin(), rows.end(), "1"), rows.end()); +} diff --git a/lance/test/test_files/lance_complex_queries.test b/lance/test/test_files/lance_complex_queries.test new file mode 100644 index 0000000..cd04805 --- /dev/null +++ b/lance/test/test_files/lance_complex_queries.test @@ -0,0 +1,132 @@ +-DATASET LANCE lance-test +-- + +-CASE LanceComplexQueries + +-LOG TwoHopCrossRel +-STATEMENT MATCH (a:user {id: 100})-[:follows]->(b)-[:livesin]->(c:city) RETURN b.name, c.name ORDER BY b.name; +---- 3 +Adam|Waterloo +Karissa|Waterloo +Zhang|Kitchener + +-LOG BackwardInNeighbors +-STATEMENT MATCH (u)<-[:follows]-(v) WHERE u.id = 300 RETURN v.name ORDER BY v.name; +---- 2 +Adam +Karissa + +-LOG UndirectedLivesIn +-STATEMENT MATCH (a:user)-[:livesin]-(c:city) RETURN a.name, c.name ORDER BY a.name; +---- 4 +Adam|Waterloo +Karissa|Waterloo +Noura|Guelph +Zhang|Kitchener + +-LOG CyclicTriangle +-STATEMENT MATCH (a:user)-[:follows]->(b:user)-[:follows]->(c:user), (a)-[:follows]->(c) WHERE a.id <> b.id AND b.id <> c.id AND a.id <> c.id RETURN a.name, b.name, c.name ORDER BY a.name, b.name; +---- 2 +Adam|Karissa|Zhang +Karissa|Adam|Zhang + +-LOG SemiMaskerVarLen +-STATEMENT MATCH (a:user {id: 100})-[:follows*1..2 (r, n | WHERE n.age > 40)]->(b:user) RETURN b.name ORDER BY b.name; +---- 4 +Adam +Karissa +Noura +Zhang + +-LOG VarLenThreeHop +-STATEMENT MATCH (a:user {id: 100})-[:follows*3..3]->(b:user) RETURN DISTINCT b.name ORDER BY b.name; +---- 4 +Adam +Karissa +Noura +Zhang + +-LOG FlattenMultiPartMatch +-STATEMENT MATCH (a:user)-[:follows]->(b:user) WITH a, b MATCH (b)-[:livesin]->(c:city) RETURN a.name, b.name, c.name ORDER BY 
a.name, b.name; +---- 7 +Adam|Adam|Waterloo +Adam|Karissa|Waterloo +Adam|Zhang|Kitchener +Karissa|Adam|Waterloo +Karissa|Zhang|Kitchener +Noura|Adam|Waterloo +Zhang|Noura|Guelph + +-LOG HashJoinSharedFollowee +-STATEMENT MATCH (a:user)-[:follows]->(b:user), (c:user)-[:follows]->(b) WHERE a.id < c.id RETURN a.name, b.name, c.name ORDER BY a.name, b.name, c.name; +---- 4 +Adam|Adam|Karissa +Adam|Adam|Noura +Adam|Zhang|Karissa +Karissa|Adam|Noura + +-LOG ProfileQuery +-STATEMENT PROFILE MATCH (u:user)-[:follows]->(v:user) RETURN u.name, v.name; +---- ok + +-LOG SelfLoopFollows +-STATEMENT MATCH (a:user)-[:follows]->(a) RETURN a.name; +---- 1 +Adam + +-LOG SelfLoopExcluded +-STATEMENT MATCH (a:user)-[:follows]->(b:user) WHERE a.id <> b.id RETURN COUNT(*); +---- 1 +6 + +-LOG BackwardMultiHopCityUserUser +-STATEMENT MATCH (c:city)<-[:livesin]-(u:user)<-[:follows]-(f:user) RETURN f.name, u.name, c.name ORDER BY f.name, u.name; +---- 7 +Adam|Adam|Waterloo +Adam|Karissa|Waterloo +Adam|Zhang|Kitchener +Karissa|Adam|Waterloo +Karissa|Zhang|Kitchener +Noura|Adam|Waterloo +Zhang|Noura|Guelph + +-LOG CrossRelCityFollowsCity +-STATEMENT MATCH (c1:city)<-[:livesin]-(u:user)-[:follows]->(v:user)-[:livesin]->(c2:city) RETURN c1.name, u.name, v.name, c2.name ORDER BY u.name, v.name; +---- 7 +Waterloo|Adam|Adam|Waterloo +Waterloo|Adam|Karissa|Waterloo +Waterloo|Adam|Zhang|Kitchener +Waterloo|Karissa|Adam|Waterloo +Waterloo|Karissa|Zhang|Kitchener +Guelph|Noura|Adam|Waterloo +Kitchener|Zhang|Noura|Guelph + +-LOG NodePropertyFilter +-STATEMENT MATCH (u:user) WHERE u.age > 35 RETURN u.name ORDER BY u.name; +---- 3 +Karissa +Noura +Zhang + +-LOG AggregateWithGroup +-STATEMENT MATCH (u:user)-[:follows]->(v:user) RETURN u.name, COUNT(*) AS cnt ORDER BY u.name; +---- 4 +Adam|3 +Karissa|2 +Noura|1 +Zhang|1 + +-LOG AllNodesScan +-STATEMENT MATCH (u:user) RETURN u.id, u.name ORDER BY u.id; +---- 4 +100|Adam +200|Karissa +300|Zhang +400|Noura + +-LOG AllCitiesScan +-STATEMENT MATCH (c:city) 
RETURN c.id, c.name ORDER BY c.id; +---- 3 +1|Waterloo +2|Kitchener +3|Guelph diff --git a/lance/test/test_files/lance_invalid_storage.test b/lance/test/test_files/lance_invalid_storage.test new file mode 100644 index 0000000..242762c --- /dev/null +++ b/lance/test/test_files/lance_invalid_storage.test @@ -0,0 +1,37 @@ +-DATASET CSV empty +-- + +-CASE LanceNonExistentDirNodeFails + +-LOG LanceNonExistentDirMissingDataset +-STATEMENT LOAD EXTENSION 'lance'; +---- ok +-STATEMENT CREATE NODE TABLE t(id INT64 PRIMARY KEY) WITH (storage = '/nonexistent/dir/that/does/not/exist/t.lance', format = 'lance') +---- ok +-STATEMENT MATCH (n:t) RETURN * +---- error(regex) +.*[Nn]o such file.*|.*[Cc]annot open.*|.*[Nn]ot found.*|.*[Ll]ance.* + +-CASE LanceRemoteURINodeFails + +-LOG LanceRemoteURIMissingDataset +-STATEMENT LOAD EXTENSION 'lance'; +---- ok +-STATEMENT CREATE NODE TABLE t2(id INT64 PRIMARY KEY) WITH (storage = 'https://example.com/path/t2.lance', format = 'lance') +---- ok +-STATEMENT MATCH (n:t2) RETURN * +---- error(regex) +.*[Nn]o such file.*|.*[Cc]annot open.*|.*[Nn]ot found.*|.*[Ll]ance.*|.*[Uu]nsupported.* + +-CASE LanceMixWithRegularFails + +-LOG LanceMixRelWithRegularNode +-STATEMENT LOAD EXTENSION 'lance'; +---- ok +-STATEMENT CREATE NODE TABLE person2(id INT64 PRIMARY KEY, name STRING) +---- ok +-STATEMENT CREATE NODE TABLE person3(id INT64 PRIMARY KEY) WITH (storage = '${LBUG_ROOT_DIRECTORY}/dataset/lance-test/user.lance', format = 'lance') +---- ok +-STATEMENT CREATE REL TABLE rel2(FROM person2 TO person3) +---- error(regex) +.*[Cc]annot mix.*[Ll]ance.*|.*[Ll]ance.*[Cc]annot mix.* diff --git a/lance/test/test_files/lance_large_scan.test b/lance/test/test_files/lance_large_scan.test new file mode 100644 index 0000000..7b55b7b --- /dev/null +++ b/lance/test/test_files/lance_large_scan.test @@ -0,0 +1,33 @@ +-DATASET LANCE lance-large-test +-- + +-CASE LanceLargeNodeScan + +-LOG CountBeyondOneVector +-STATEMENT MATCH (n:nodes) RETURN count(n); +---- 1 +3000 
+ +-LOG SkipPastFirstVector +-STATEMENT MATCH (n:nodes) RETURN n.id SKIP 2048 LIMIT 10; +---- 10 +2048 +2049 +2050 +2051 +2052 +2053 +2054 +2055 +2056 +2057 + +-LOG SumAllValues +-STATEMENT MATCH (n:nodes) RETURN sum(n.val); +---- 1 +4498500 + +-LOG FilterAndCount +-STATEMENT MATCH (n:nodes) WHERE n.id >= 2900 RETURN count(n); +---- 1 +100