diff --git a/src/include/storage/table/arrow_node_table.h b/src/include/storage/table/arrow_node_table.h index febe70920..8f5a5c7e4 100644 --- a/src/include/storage/table/arrow_node_table.h +++ b/src/include/storage/table/arrow_node_table.h @@ -96,13 +96,15 @@ class ArrowNodeTable final : public ColumnarNodeTableBase { const ArrowSchemaWrapper& getArrowSchema() const { return schema; } const std::vector& getArrowArrays() const { return arrays; } - common::node_group_idx_t getNumBatches( + common::node_group_idx_t getNumScanMorsels( const transaction::Transaction* transaction) const override; - size_t getNumScanMorsels(const transaction::Transaction* transaction) const; - const catalog::NodeTableCatalogEntry* getCatalogEntry() const { return nodeTableCatalogEntry; } + std::unique_ptr createScanState(common::ValueVector* nodeIDVector, + const std::vector& outVectors, + MemoryManager* memoryManager) const override; + protected: std::string getColumnarFormatName() const override { return "Arrow"; } common::row_idx_t getTotalRowCount(const transaction::Transaction* transaction) const override; diff --git a/src/include/storage/table/arrow_rel_table.h b/src/include/storage/table/arrow_rel_table.h index 7f6d8c928..e6bdb68f0 100644 --- a/src/include/storage/table/arrow_rel_table.h +++ b/src/include/storage/table/arrow_rel_table.h @@ -41,6 +41,9 @@ class ArrowRelTable final : public ColumnarRelTableBase { bool resetCachedBoundNodeSelVec = true) const override; bool scanInternal(transaction::Transaction* transaction, TableScanState& scanState) override; + std::unique_ptr createScanState(common::ValueVector* nodeIDVector, + const std::vector& outVectors, MemoryManager* memoryManager, + std::shared_ptr outChunkState) const override; protected: std::string getColumnarFormatName() const override { return "Arrow"; } diff --git a/src/include/storage/table/columnar_node_table_base.h b/src/include/storage/table/columnar_node_table_base.h index a719aa840..9c91bc097 100644 --- a/src/include/storage/table/columnar_node_table_base.h +++ b/src/include/storage/table/columnar_node_table_base.h @@ -72,8 +72,6 @@ class ColumnarNodeTableBase : public NodeTable { // Template method pattern: subclasses implement format-specific operations virtual std::string getColumnarFormatName() const = 0; - virtual common::node_group_idx_t getNumBatches( - const transaction::Transaction* transaction) const = 0; virtual common::row_idx_t getTotalRowCount( const transaction::Transaction* transaction) const = 0; @@ -81,6 +79,11 @@ class ColumnarNodeTableBase : public NodeTable { ColumnarNodeTableScanSharedState* getTableScanSharedState() const { return tableScanSharedState.get(); } + virtual std::unique_ptr createScanState(common::ValueVector* nodeIDVector, + const std::vector& outVectors, + MemoryManager* memoryManager) const = 0; + virtual common::node_group_idx_t getNumScanMorsels( + const transaction::Transaction* transaction) const = 0; }; } // namespace storage diff --git a/src/include/storage/table/columnar_rel_table_base.h b/src/include/storage/table/columnar_rel_table_base.h index 0cbc6eb2f..686f92c6b 100644 --- a/src/include/storage/table/columnar_rel_table_base.h +++ b/src/include/storage/table/columnar_rel_table_base.h @@ -51,6 +51,9 @@ class ColumnarRelTableBase : public RelTable { std::vector> getTopKDegrees( const transaction::Transaction* transaction, common::RelDataDirection direction, common::idx_t k) override; + virtual std::unique_ptr createScanState(common::ValueVector* nodeIDVector, + const std::vector& outVectors, MemoryManager* memoryManager, + std::shared_ptr outChunkState) const = 0; protected: catalog::RelGroupCatalogEntry* relGroupEntry; diff --git a/src/include/storage/table/ice_disk_node_table.h b/src/include/storage/table/ice_disk_node_table.h index 2a6dc251f..e266f0a5a 100644 --- a/src/include/storage/table/ice_disk_node_table.h +++ b/src/include/storage/table/ice_disk_node_table.h @@ -76,12 +76,15 @@ class IceDiskNodeTable final : public ColumnarNodeTableBase { common::offset_t offset) const override; const std::string& getParquetFilePath() const { return parquetFilePath; } + std::unique_ptr createScanState(common::ValueVector* nodeIDVector, + const std::vector& outVectors, + MemoryManager* memoryManager) const override; + common::node_group_idx_t getNumScanMorsels( + const transaction::Transaction* transaction) const override; protected: // Implement ColumnarNodeTableBase interface std::string getColumnarFormatName() const override { return "icebug-disk"; } - common::node_group_idx_t getNumBatches( - const transaction::Transaction* transaction) const override; common::row_idx_t getTotalRowCount(const transaction::Transaction* transaction) const override; private: diff --git a/src/include/storage/table/ice_disk_rel_table.h b/src/include/storage/table/ice_disk_rel_table.h index e04bc75dc..22a1c719e 100644 --- a/src/include/storage/table/ice_disk_rel_table.h +++ b/src/include/storage/table/ice_disk_rel_table.h @@ -65,6 +65,9 @@ class IceDiskRelTable final : public ColumnarRelTableBase { bool resetCachedBoundNodeSelVec = true) const override; bool scanInternal(transaction::Transaction* transaction, TableScanState& scanState) override; + std::unique_ptr createScanState(common::ValueVector* nodeIDVector, + const std::vector& outVectors, MemoryManager* memoryManager, + std::shared_ptr outChunkState) const override; protected: // Implement ColumnarRelTableBase interface diff --git a/src/processor/operator/scan/scan_multi_rel_tables.cpp b/src/processor/operator/scan/scan_multi_rel_tables.cpp index f24eae254..05c31e847 100644 --- a/src/processor/operator/scan/scan_multi_rel_tables.cpp +++ b/src/processor/operator/scan/scan_multi_rel_tables.cpp @@ -67,34 +67,27 @@ void ScanMultiRelTable::initLocalStateInternal(ResultSet* resultSet, ExecutionCo auto nbrNodeIDVector = outVectors[0]; // Check if any table in any scanner is an external rel table with a custom scan state. - bool hasArrowTable = false; - bool hasIceDiskTable = false; + storage::ColumnarRelTableBase* columnarTable = nullptr; for (auto& [_, scanner] : scanners) { for (auto& relInfo : scanner.relInfos) { - if (dynamic_cast(relInfo.table) != nullptr) { - hasArrowTable = true; + if (dynamic_cast(relInfo.table) != nullptr) { + columnarTable = dynamic_cast(relInfo.table); break; } - if (dynamic_cast(relInfo.table) != nullptr) { - hasIceDiskTable = true; + + if (columnarTable != nullptr) { break; } } - if (hasArrowTable || hasIceDiskTable) { - break; - } } // IceDisk scan state extends the common rel scan state and Arrow stores its per-table state // there, so one scan state can now cover IceDisk, Arrow, and native rel tables. - if (hasIceDiskTable) { - scanState = - std::make_unique(*MemoryManager::Get(*clientContext), - boundNodeIDVector, outVectors, nbrNodeIDVector->state); - } else if (hasArrowTable) { - scanState = - std::make_unique(*MemoryManager::Get(*clientContext), - boundNodeIDVector, outVectors, nbrNodeIDVector->state); + if (columnarTable != nullptr) { + auto tableScanState = columnarTable->createScanState(boundNodeIDVector, outVectors, + MemoryManager::Get(*clientContext), nbrNodeIDVector->state); + scanState = std::unique_ptr( + dynamic_cast(tableScanState.release())); } else { scanState = std::make_unique(*MemoryManager::Get(*clientContext), boundNodeIDVector, outVectors, nbrNodeIDVector->state); diff --git a/src/processor/operator/scan/scan_node_table.cpp b/src/processor/operator/scan/scan_node_table.cpp index e9c952485..741671cdf 100644 --- a/src/processor/operator/scan/scan_node_table.cpp +++ b/src/processor/operator/scan/scan_node_table.cpp @@ -17,14 +17,11 @@ namespace processor { static std::unique_ptr createNodeTableScanState(NodeTable* table, ValueVector* nodeIDVector, const std::vector& outVectors, MemoryManager* memoryManager) { - if (dynamic_cast(table) != nullptr) { - return std::make_unique(*memoryManager, nodeIDVector, outVectors, - nodeIDVector->state); - } - if (dynamic_cast(table) != nullptr) { - return std::make_unique(*memoryManager, nodeIDVector, outVectors, - nodeIDVector->state); + if (dynamic_cast(table) != nullptr) { + return table->cast().createScanState(nodeIDVector, outVectors, + memoryManager); } + return std::make_unique(nodeIDVector, outVectors, nodeIDVector->state); } @@ -56,23 +53,8 @@ void ScanNodeTableSharedState::initialize(const transaction::Transaction* transa // Initialize table-specific scan coordination (e.g., for IceDiskNodeTable) table->initializeScanCoordination(transaction); - if (const auto iceDiskTable = dynamic_cast(table)) { - // For ice-disk tables, set numCommittedNodeGroups to number of row groups - std::vector columnSkips; - try { - auto context = transaction->getClientContext(); - auto resolvedPath = - common::VirtualFileSystem::resolvePath(context, iceDiskTable->getParquetFilePath()); - auto tempReader = - std::make_unique(resolvedPath, columnSkips, context); - this->numCommittedNodeGroups = tempReader->getNumRowGroups(); - } catch (const std::exception& e) { - this->numCommittedNodeGroups = 1; - } - } else if (const auto arrowTable = dynamic_cast(table)) { - // For Arrow tables, set numCommittedNodeGroups to number of morsels - this->numCommittedNodeGroups = - static_cast(arrowTable->getNumScanMorsels(transaction)); + if (const auto columnarNodeTable = dynamic_cast(table)) { + this->numCommittedNodeGroups = columnarNodeTable->getNumScanMorsels(transaction); } else { this->numCommittedNodeGroups = table->getNumCommittedNodeGroups(); } @@ -90,10 +72,8 @@ void ScanNodeTableSharedState::nextMorsel(TableScanState& scanState, ScanNodeTableProgressSharedState& progressSharedState) { std::unique_lock lck{mtx}; - // ColumnarNodeTables handle morsel assignment internally - // TODO: icebug-disk tables https://github.com/LadybugDB/ladybug/issues/245 - if (const auto arrowTable = dynamic_cast(this->table)) { - const auto tableSharedState = arrowTable->getTableScanSharedState(); + if (const auto columnarTable = dynamic_cast(this->table)) { + const auto tableSharedState = columnarTable->getTableScanSharedState(); if (tableSharedState->getNextMorsel(static_cast(&scanState))) { scanState.source = TableScanSource::COMMITTED; progressSharedState.numMorselsScanned++; @@ -149,9 +129,8 @@ void ScanNodeTable::initCurrentTable(ExecutionContext* context) { outVectors, MemoryManager::Get(*context->clientContext)); currentInfo.initScanState(*scanState, outVectors, context->clientContext); scanState->semiMask = sharedStates[currentTableIdx]->getSemiMask(); - // Call table->initScanState for IceDiskNodeTable or ArrowNodeTable - if (dynamic_cast(tableInfos[currentTableIdx].table) || - dynamic_cast(tableInfos[currentTableIdx].table)) { + // Call table->initScanState for ColumnarNodeTables + if (dynamic_cast(tableInfos[currentTableIdx].table)) { auto transaction = transaction::Transaction::Get(*context->clientContext); tableInfos[currentTableIdx].table->initScanState(transaction, *scanState); } diff --git a/src/processor/operator/scan/scan_rel_table.cpp b/src/processor/operator/scan/scan_rel_table.cpp index 8e7b5a7d1..aff9fe020 100644 --- a/src/processor/operator/scan/scan_rel_table.cpp +++ b/src/processor/operator/scan/scan_rel_table.cpp @@ -23,13 +23,9 @@ namespace processor { static std::unique_ptr createSourceNodeTableScanState(NodeTable* table, ValueVector* nodeIDVector, const std::vector& outVectors, MemoryManager* memoryManager) { - if (dynamic_cast(table) != nullptr) { - return std::make_unique(*memoryManager, nodeIDVector, outVectors, - nodeIDVector->state); - } - if (dynamic_cast(table) != nullptr) { - return std::make_unique(*memoryManager, nodeIDVector, outVectors, - nodeIDVector->state); + if (dynamic_cast(table) != nullptr) { + return table->cast().createScanState(nodeIDVector, outVectors, + memoryManager); } return std::make_unique(nodeIDVector, outVectors, nodeIDVector->state); } @@ -91,17 +87,13 @@ void ScanRelTable::initLocalStateInternal(ResultSet* resultSet, ExecutionContext auto boundNodeIDVector = resultSet->getValueVector(opInfo.nodeIDPos).get(); auto nbrNodeIDVector = outVectors[0]; // Check if this is an external rel table and create the corresponding scan state. - auto* arrowTable = dynamic_cast(tableInfo.table); - auto* iceDiskTable = dynamic_cast(tableInfo.table); + auto* columnarTable = dynamic_cast(tableInfo.table); auto* foreignTable = dynamic_cast(tableInfo.table); - if (arrowTable) { - scanState = - std::make_unique(*MemoryManager::Get(*clientContext), - boundNodeIDVector, outVectors, nbrNodeIDVector->state); - } else if (iceDiskTable) { - scanState = - std::make_unique(*MemoryManager::Get(*clientContext), - boundNodeIDVector, outVectors, nbrNodeIDVector->state); + if (columnarTable) { + auto tableScanState = columnarTable->createScanState(boundNodeIDVector, outVectors, + MemoryManager::Get(*clientContext), nbrNodeIDVector->state); + scanState = std::unique_ptr( + dynamic_cast(tableScanState.release())); } else if (foreignTable) { scanState = std::make_unique(*MemoryManager::Get(*clientContext), @@ -142,8 +134,7 @@ static void initSourceNodeScanState(ScanNodeTableInfo& sourceInfo, sourceScanState = createSourceNodeTableScanState(sourceInfo.table->ptrCast(), boundNodeIDVector, sourceNodeOutVectors, MemoryManager::Get(*context)); sourceInfo.initScanState(*sourceScanState, sourceNodeOutVectors, context); - if (dynamic_cast(sourceInfo.table) || - dynamic_cast(sourceInfo.table)) { + if (dynamic_cast(sourceInfo.table)) { sourceInfo.table->initScanState(transaction::Transaction::Get(*context), *sourceScanState); } } diff --git a/src/storage/table/arrow_node_table.cpp b/src/storage/table/arrow_node_table.cpp index a139e1464..ba84148c2 100644 --- a/src/storage/table/arrow_node_table.cpp +++ b/src/storage/table/arrow_node_table.cpp @@ -50,6 +50,12 @@ ArrowNodeTable::~ArrowNodeTable() { } } +std::unique_ptr ArrowNodeTable::createScanState(common::ValueVector* nodeIDVector, + const std::vector& outVectors, MemoryManager* memoryManager) const { + return std::make_unique(*memoryManager, nodeIDVector, outVectors, + nodeIDVector->state); +} + void ArrowNodeTable::initializeScanCoordination(const transaction::Transaction* transaction) { auto arrowScanSharedState = static_cast(tableScanSharedState.get()); @@ -135,11 +141,6 @@ bool ArrowNodeTable::scanInternal([[maybe_unused]] transaction::Transaction* tra return true; } -common::node_group_idx_t ArrowNodeTable::getNumBatches( - [[maybe_unused]] const transaction::Transaction* transaction) const { - return arrays.size(); -} - common::row_idx_t ArrowNodeTable::getTotalRowCount( [[maybe_unused]] const transaction::Transaction* transaction) const { return totalRows; @@ -156,9 +157,9 @@ std::vector ArrowNodeTable::getBatchSizes( return batchSizes; } -size_t ArrowNodeTable::getNumScanMorsels( +common::node_group_idx_t ArrowNodeTable::getNumScanMorsels( [[maybe_unused]] const transaction::Transaction* transaction) const { - size_t numMorsels = 0; + common::node_group_idx_t numMorsels = 0; for (const auto& array : arrays) { auto batchLength = getArrowBatchLength(array); numMorsels += (batchLength + scanMorselSize - 1) / scanMorselSize; diff --git a/src/storage/table/arrow_rel_table.cpp b/src/storage/table/arrow_rel_table.cpp index 327443eaf..868b06c76 100644 --- a/src/storage/table/arrow_rel_table.cpp +++ b/src/storage/table/arrow_rel_table.cpp @@ -38,6 +38,13 @@ static int64_t findColumnIdx(const ArrowSchemaWrapper& schema, const std::string return -1; } +std::unique_ptr ArrowRelTable::createScanState(common::ValueVector* nodeIDVector, + const std::vector& outVectors, MemoryManager* memoryManager, + std::shared_ptr outChunkState) const { + return std::make_unique(*memoryManager, nodeIDVector, + outVectors, outChunkState); +} + void ArrowRelTableScanState::setToTable(const transaction::Transaction* transaction, Table* table_, std::vector columnIDs_, std::vector columnPredicateSets_, RelDataDirection direction_) { diff --git a/src/storage/table/ice_disk_node_table.cpp b/src/storage/table/ice_disk_node_table.cpp index d61b02947..c2eeb8c47 100644 --- a/src/storage/table/ice_disk_node_table.cpp +++ b/src/storage/table/ice_disk_node_table.cpp @@ -40,11 +40,17 @@ IceDiskNodeTable::IceDiskNodeTable(const StorageManager* storageManager, parquetFilePath = resolvedPath; } +std::unique_ptr IceDiskNodeTable::createScanState(common::ValueVector* nodeIDVector, + const std::vector& outVectors, MemoryManager* memoryManager) const { + return std::make_unique(*memoryManager, nodeIDVector, outVectors, + nodeIDVector->state); +} + void IceDiskNodeTable::initializeScanCoordination(const transaction::Transaction* transaction) { auto iceDiskScanSharedState = static_cast(tableScanSharedState.get()); - auto numBatches = getNumBatches(transaction); - iceDiskScanSharedState->reset(numBatches); + auto numMorsels = getNumScanMorsels(transaction); + iceDiskScanSharedState->reset(numMorsels); } void IceDiskNodeTable::initScanState(Transaction* transaction, TableScanState& scanState, @@ -85,7 +91,8 @@ void IceDiskNodeTable::initScanState(Transaction* transaction, TableScanState& s initParquetScanForRowGroup(transaction, iceDiskScanState); } -common::node_group_idx_t IceDiskNodeTable::getNumBatches(const Transaction* transaction) const { +common::node_group_idx_t IceDiskNodeTable::getNumScanMorsels( + const transaction::Transaction* transaction) const { auto context = transaction->getClientContext(); if (!context) { return 1; @@ -94,7 +101,7 @@ common::node_group_idx_t IceDiskNodeTable::getNumBatches(const Transaction* tran std::vector columnSkips; try { auto tempReader = std::make_unique(parquetFilePath, columnSkips, context); - return tempReader->getNumRowGroups(); + return static_cast(tempReader->getNumRowGroups()); } catch (const std::exception& e) { return 1; // Fallback } diff --git a/src/storage/table/ice_disk_rel_table.cpp b/src/storage/table/ice_disk_rel_table.cpp index e662b7162..13d5778dd 100644 --- a/src/storage/table/ice_disk_rel_table.cpp +++ b/src/storage/table/ice_disk_rel_table.cpp @@ -22,6 +22,13 @@ using namespace lbug::transaction; namespace lbug { namespace storage { +std::unique_ptr IceDiskRelTable::createScanState(common::ValueVector* nodeIDVector, + const std::vector& outVectors, MemoryManager* memoryManager, + std::shared_ptr outChunkState) const { + return std::make_unique(*memoryManager, nodeIDVector, + outVectors, outChunkState); +} + void IceDiskRelTableScanState::setToTable(const Transaction* transaction, Table* table_, std::vector columnIDs_, std::vector columnPredicateSets_, RelDataDirection direction_) {