diff --git a/README.md b/README.md index ad8305d..3ba722d 100644 --- a/README.md +++ b/README.md @@ -141,14 +141,14 @@ ctest --test-dir build --output-on-failure --verbose - [x] **Phase 2**: In-memory query execution (SeqScan, Filter, Projection) - [x] **Phase 3**: Disk-based storage with buffer pool - [x] **Phase 4**: BTree indexes with query planner integration -- [ ] **Phase 5**: JOIN operations +- [x] **Phase 5**: JOIN operations - [x] Parse `INNER JOIN ... ON ...` with qualified column references - [x] Execute joins via `NestedLoopJoin` - [x] Add rule-based join algorithm choice (`NestedLoopJoin` vs `HashJoin`) - [x] Support `JOIN + WHERE` (single-table pushdown + post-join filter) - [x] Add correctness checks (ambiguous columns, swapped `ON` sides, type-mismatch safety) - - [ ] Add `EXPLAIN` command in REPL to print physical plan (`SeqScan`/`IndexScan`/`Join` path) - - [ ] Add join-condition index matching (index-aware join access path, e.g. index nested-loop opportunities) + - [x] Add `EXPLAIN` command in REPL to print physical plan (`SeqScan`/`IndexScan`/`Join` path) + - [x] Add join-condition index matching (`IndexNestedLoopJoin` when index exists on join column) - [ ] **Phase 6**: Transactions ## Architecture diff --git a/src/execution/CMakeLists.txt b/src/execution/CMakeLists.txt index 5e595b6..421cfde 100644 --- a/src/execution/CMakeLists.txt +++ b/src/execution/CMakeLists.txt @@ -6,6 +6,7 @@ add_library(execution index_scan.cpp nested_loop_join.cpp hash_join.cpp + index_nested_loop_join.cpp ) target_link_libraries(execution PUBLIC optimizer storage common) diff --git a/src/execution/executor.cpp b/src/execution/executor.cpp index b92fb99..f93ba00 100644 --- a/src/execution/executor.cpp +++ b/src/execution/executor.cpp @@ -258,6 +258,8 @@ namespace sql return ExecuteCreateIndex(static_cast(stmt)); case StatementType::DROP_TABLE: return ExecuteDropTable(static_cast(stmt)); + case StatementType::EXPLAIN_STMT: + return ExecuteExplain(static_cast(stmt)); default: return {false, "Unsupported statement type", {}, {}}; } @@ -388,6 +390,36 @@ namespace sql EnsureJoinContextTable(left, right); return std::make_unique(left, right, left_col, right_col, node->join_build_right); } + case PhysicalPlanType::INDEX_NESTED_LOOP_JOIN: + { + Table *left = catalog_->GetTable(node->table_name); + Table *right = catalog_->GetTable(node->right_table_name); + if (left == nullptr || right == nullptr) + throw std::runtime_error("JOIN table not found while building operator tree"); + + // Resolve which ON-clause column belongs to which table (handles swapped ON order). + const auto [left_col, right_col] = ResolveJoinColumns(node, left, right); + + // join_right_as_outer: true → right is outer, left is inner (has index) + // false → left is outer, right is inner (has index) + const bool right_is_outer = node->join_right_as_outer; + Table *outer_table = right_is_outer ? right : left; + Table *inner_table = right_is_outer ? left : right; + const std::string outer_col = right_is_outer ? right_col : left_col; + const std::string inner_col = right_is_outer ? left_col : right_col; + const std::string &inner_table_name = right_is_outer ? node->table_name : node->right_table_name; + + BTree *inner_index = catalog_->GetIndex(inner_table_name, inner_col); + if (inner_index == nullptr) + throw std::runtime_error("Expected index not found on " + inner_table_name + "." + inner_col); + + EnsureJoinContextTable(left, right); + // outer_is_left: left table is outer when right_is_outer=false + const bool outer_is_left = !right_is_outer; + return std::make_unique( + outer_table, inner_table, inner_index, + outer_col, inner_col, outer_is_left); + } case PhysicalPlanType::FILTER: { if (node->children.empty()) @@ -408,7 +440,8 @@ namespace sql auto child = BuildOperatorTree(node->children[0].get(), table, join_table); Table *projection_join_table = join_table; if (node->children[0]->type == PhysicalPlanType::NESTED_LOOP_JOIN || - node->children[0]->type == PhysicalPlanType::HASH_JOIN) + node->children[0]->type == PhysicalPlanType::HASH_JOIN || + node->children[0]->type == PhysicalPlanType::INDEX_NESTED_LOOP_JOIN) { projection_join_table = catalog_->GetTable(node->children[0]->right_table_name); } @@ -692,4 +725,24 @@ namespace sql return {true, "Table '" + drop->table + "' dropped.", {}, {}}; } + ExecutionResult Executor::ExecuteExplain(ExplainStatement *explain) + { + ExecutionResult result; + try + { + materialized_tables_.clear(); + join_context_table_.reset(); + Optimizer optimizer; + auto physical_plan = optimizer.BuildPhysicalPlan(explain->select.get(), catalog_); + result.success = true; + result.message = optimizer.ExplainPhysicalPlan(physical_plan.get()); + } + catch (const std::exception &e) + { + result.success = false; + result.message = e.what(); + } + return result; + } + } // namespace sql diff --git a/src/execution/executor.h b/src/execution/executor.h index eb70d14..039633e 100644 --- a/src/execution/executor.h +++ b/src/execution/executor.h @@ -7,6 +7,7 @@ #include "execution/index_scan.h" #include "execution/nested_loop_join.h" #include "execution/hash_join.h" +#include "execution/index_nested_loop_join.h" #include "parser/ast.h" #include "optimizer/optimizer.h" #include "catalog/catalog.h" @@ -44,6 +45,7 @@ namespace sql ExecutionResult ExecuteUpdate(UpdateStatement *update); ExecutionResult ExecuteCreateIndex(CreateIndexStatement *create); ExecutionResult ExecuteDropTable(DropTableStatement *drop); + ExecutionResult ExecuteExplain(ExplainStatement *explain); // Evaluate an expression (reused for INSERT values, UPDATE SET, etc.) Value EvaluateExpr(const Expression *expr, const Tuple *tuple = nullptr, Table *table = nullptr) const; diff --git a/src/execution/index_nested_loop_join.cpp b/src/execution/index_nested_loop_join.cpp new file mode 100644 index 0000000..e8b271a --- /dev/null +++ b/src/execution/index_nested_loop_join.cpp @@ -0,0 +1,76 @@ +#include "execution/index_nested_loop_join.h" +#include + +namespace sql +{ + + namespace + { + std::string StripQualifier(const std::string &name) + { + size_t dot = name.find('.'); + if (dot == std::string::npos) + return name; + return name.substr(dot + 1); + } + } // namespace + + void IndexNestedLoopJoin::Open() + { + if (outer_table_ == nullptr || inner_table_ == nullptr || inner_index_ == nullptr) + throw std::runtime_error("IndexNestedLoopJoin: null table or index"); + + outer_col_idx_ = outer_table_->GetColumnIndex(StripQualifier(outer_col_)); + if (outer_col_idx_ < 0) + throw std::runtime_error("IndexNestedLoopJoin: unknown outer column: " + outer_col_); + + outer_cursor_ = 0; + inner_matches_.clear(); + inner_cursor_ = 0; + } + + bool IndexNestedLoopJoin::Next(Tuple *tuple) + { + const auto &outer_rows = outer_table_->GetTuples(); + + while (true) + { + // Consume remaining inner matches for current outer row + while (inner_cursor_ < inner_matches_.size()) + { + size_t inner_row_idx = inner_matches_[inner_cursor_++]; + const Tuple &inner_row = inner_table_->GetTuple(inner_row_idx); + + std::vector joined; + const Tuple &left_row = outer_is_left_ ? current_outer_ : inner_row; + const Tuple &right_row = outer_is_left_ ? inner_row : current_outer_; + joined.reserve(left_row.GetValueCount() + right_row.GetValueCount()); + for (size_t i = 0; i < left_row.GetValueCount(); ++i) + joined.push_back(left_row.GetValue(i)); + for (size_t i = 0; i < right_row.GetValueCount(); ++i) + joined.push_back(right_row.GetValue(i)); + + *tuple = Tuple(std::move(joined)); + return true; + } + + // Advance to next outer row + if (outer_cursor_ >= outer_rows.size()) + return false; + + current_outer_ = outer_rows[outer_cursor_++]; + Value probe_key = current_outer_.GetValue(static_cast(outer_col_idx_)); + inner_matches_ = inner_index_->Search(probe_key); + inner_cursor_ = 0; + } + } + + void IndexNestedLoopJoin::Close() + { + outer_cursor_ = 0; + inner_matches_.clear(); + inner_cursor_ = 0; + outer_col_idx_ = -1; + } + +} // namespace sql diff --git a/src/execution/index_nested_loop_join.h b/src/execution/index_nested_loop_join.h new file mode 100644 index 0000000..bfcc57c --- /dev/null +++ b/src/execution/index_nested_loop_join.h @@ -0,0 +1,52 @@ +#pragma once + +#include "execution/operator.h" +#include "storage/table.h" +#include "storage/btree.h" +#include +#include + +namespace sql +{ + + // Index Nested-Loop Join: for each outer row, probes inner table via BTree index. + // Output column order is always (left_table_cols..., right_table_cols...) regardless + // of which side is outer; outer_is_left controls concatenation order. + class IndexNestedLoopJoin : public Operator + { + public: + // outer_table: table iterated row by row + // inner_table: table probed via index + // inner_index: BTree on inner_table.inner_col + // outer_col: column name in outer_table used as probe key + // inner_col: indexed column name in inner_table + // outer_is_left: if true, output is (outer || inner); else (inner || outer) + IndexNestedLoopJoin(Table *outer_table, Table *inner_table, + BTree *inner_index, + std::string outer_col, std::string inner_col, + bool outer_is_left) + : outer_table_(outer_table), inner_table_(inner_table), + inner_index_(inner_index), + outer_col_(std::move(outer_col)), inner_col_(std::move(inner_col)), + outer_is_left_(outer_is_left) {} + + void Open() override; + bool Next(Tuple *tuple) override; + void Close() override; + + private: + Table *outer_table_; + Table *inner_table_; + BTree *inner_index_; + std::string outer_col_; + std::string inner_col_; + bool outer_is_left_; + + int outer_col_idx_ = -1; + size_t outer_cursor_ = 0; + Tuple current_outer_; + std::vector inner_matches_; + size_t inner_cursor_ = 0; + }; + +} // namespace sql diff --git a/src/lexer/lexer.cpp b/src/lexer/lexer.cpp index 3227d2a..03b4878 100644 --- a/src/lexer/lexer.cpp +++ b/src/lexer/lexer.cpp @@ -268,6 +268,8 @@ namespace sql {"NOT", TokenType::NOT}, {"DROP", TokenType::DROP}, {"JOIN", TokenType::JOIN}, + {"INNER", TokenType::INNER}, + {"EXPLAIN", TokenType::EXPLAIN}, {"TRUE", TokenType::TRUE}, {"FALSE", TokenType::FALSE}}; diff --git a/src/lexer/token.h b/src/lexer/token.h index 43af3da..765b306 100644 --- a/src/lexer/token.h +++ b/src/lexer/token.h @@ -33,6 +33,8 @@ namespace sql NOT, DROP, JOIN, + EXPLAIN, + INNER, // Identifiers and Literals IDENTIFIER, diff --git a/src/optimizer/optimizer.cpp b/src/optimizer/optimizer.cpp index 11fc2c6..0075142 100644 --- a/src/optimizer/optimizer.cpp +++ b/src/optimizer/optimizer.cpp @@ -386,22 +386,69 @@ namespace sql const size_t left_count = table->GetTupleCount(); const size_t right_count = right_table->GetTupleCount(); - // Rule-based join choice: - // - HASH_JOIN for larger equi-joins - // - NESTED_LOOP_JOIN for small inputs - const size_t total_rows = left_count + right_count; - const bool use_hash_join = total_rows >= 16; - - auto join = std::make_unique( - use_hash_join ? PhysicalPlanType::HASH_JOIN : PhysicalPlanType::NESTED_LOOP_JOIN); + // Resolve which join column belongs to which table. + // The ON clause may have columns in either order (e.g., ON right.x = left.y), + // so we must check before looking up indexes. + const auto left_side = ResolveColumnSide(*select->join_left_column, + select->table, *select->join_table, + table, right_table); + const auto right_side = ResolveColumnSide(*select->join_right_column, + select->table, *select->join_table, + table, right_table); + + // Determine the actual column name for each table side. + std::string col_on_left_table; // column belonging to select->table + std::string col_on_right_table; // column belonging to select->join_table + if (left_side == PredicateTableSide::LEFT && right_side == PredicateTableSide::RIGHT) + { + col_on_left_table = StripQualifier(*select->join_left_column); + col_on_right_table = StripQualifier(*select->join_right_column); + } + else if (left_side == PredicateTableSide::RIGHT && right_side == PredicateTableSide::LEFT) + { + col_on_left_table = StripQualifier(*select->join_right_column); + col_on_right_table = StripQualifier(*select->join_left_column); + } + // else: ambiguous or unresolved — skip INLJ + + // Check if either join column has an index → prefer INDEX_NESTED_LOOP_JOIN. + std::unique_ptr join; + if (!col_on_left_table.empty() && !col_on_right_table.empty()) + { + BTree *left_index = catalog->GetIndex(select->table, col_on_left_table); + BTree *right_index = catalog->GetIndex(*select->join_table, col_on_right_table); + + if (right_index != nullptr) + { + // Right table has index → right is inner, left is outer. + join = std::make_unique(PhysicalPlanType::INDEX_NESTED_LOOP_JOIN); + join->join_right_as_outer = false; // left is outer + } + else if (left_index != nullptr) + { + // Left table has index → left is inner, right is outer. + join = std::make_unique(PhysicalPlanType::INDEX_NESTED_LOOP_JOIN); + join->join_right_as_outer = true; // right is outer + } + } + + if (join == nullptr) + { + // Rule-based choice between HASH_JOIN and NESTED_LOOP_JOIN. + const size_t total_rows = left_count + right_count; + const bool use_hash_join = total_rows >= 16; + join = std::make_unique( + use_hash_join ? PhysicalPlanType::HASH_JOIN : PhysicalPlanType::NESTED_LOOP_JOIN); + // Iterate smaller table in outer loop for nested loop. + join->join_right_as_outer = right_count < left_count; + // Build hash table on smaller side for hash join. + join->join_build_right = right_count <= left_count; + } + join->table_name = select->table; join->right_table_name = *select->join_table; join->join_left_column = *select->join_left_column; join->join_right_column = *select->join_right_column; - // Rule-based choice: iterate smaller table in outer loop for nested loop. - join->join_right_as_outer = right_count < left_count; - // Rule-based choice: build hash table on smaller side for hash join. - join->join_build_right = right_count <= left_count; join->children.push_back(std::move(left_input)); join->children.push_back(std::move(right_input)); current = std::move(join); @@ -484,6 +531,19 @@ namespace sql << ", build=" << (node->join_build_right ? "right" : "left") << ")"; break; + case PhysicalPlanType::INDEX_NESTED_LOOP_JOIN: + { + const std::string &outer = node->join_right_as_outer ? node->right_table_name : node->table_name; + const std::string &inner = node->join_right_as_outer ? node->table_name : node->right_table_name; + const std::string &inner_col = node->join_right_as_outer ? node->join_left_column : node->join_right_column; + out << pad << "IndexNestedLoopJoin(outer=" << outer + << ", inner=" << inner + << "(index on " << inner_col << ")" + << ", on=" << node->join_left_column + << " = " << node->join_right_column + << ")"; + break; + } case PhysicalPlanType::FILTER: out << pad << "Filter"; if (node->predicate != nullptr) diff --git a/src/optimizer/optimizer.h b/src/optimizer/optimizer.h index af87504..cc3f3e2 100644 --- a/src/optimizer/optimizer.h +++ b/src/optimizer/optimizer.h @@ -41,6 +41,7 @@ namespace sql INDEX_SCAN, NESTED_LOOP_JOIN, HASH_JOIN, + INDEX_NESTED_LOOP_JOIN, FILTER, PROJECTION }; @@ -65,10 +66,10 @@ namespace sql std::optional high_key; bool high_inclusive = true; - // NESTED_LOOP_JOIN + // NESTED_LOOP_JOIN / HASH_JOIN / INDEX_NESTED_LOOP_JOIN std::string join_left_column; std::string join_right_column; - bool join_right_as_outer = false; + bool join_right_as_outer = false; // NLJ: right is outer loop; INLJ: right is outer (left has index) // HASH_JOIN bool join_build_right = true; diff --git a/src/parser/ast.h b/src/parser/ast.h index 5f9fc32..0defab1 100644 --- a/src/parser/ast.h +++ b/src/parser/ast.h @@ -59,7 +59,8 @@ namespace sql INSERT, CREATE_INDEX, DELETE_STMT, - UPDATE_STMT + UPDATE_STMT, + EXPLAIN_STMT }; struct Statement @@ -130,6 +131,12 @@ namespace sql StatementType GetType() const override { return StatementType::UPDATE_STMT; } }; + struct ExplainStatement : public Statement + { + std::unique_ptr select; + StatementType GetType() const override { return StatementType::EXPLAIN_STMT; } + }; + std::string ExpressionTypeToString(ExpressionType type); std::string StatementTypeToString(StatementType type); std::string DumpExpression(const Expression *expr, int indent = 0); diff --git a/src/parser/parser.cpp b/src/parser/parser.cpp index 721d067..64f776a 100644 --- a/src/parser/parser.cpp +++ b/src/parser/parser.cpp @@ -53,6 +53,8 @@ namespace sql return ParseUpdate(); case TokenType::DROP: return ParseDropTable(); + case TokenType::EXPLAIN: + return ParseExplain(); default: throw std::runtime_error("Unexpected token at start of statement: " + current_token_.value); } @@ -79,7 +81,18 @@ namespace sql Token table = Expect(TokenType::IDENTIFIER); stmt->table = table.value; - if (Match(TokenType::JOIN)) + bool has_join = false; + if (Match(TokenType::INNER)) + { + Expect(TokenType::JOIN); + has_join = true; + } + else if (Match(TokenType::JOIN)) + { + has_join = true; + } + + if (has_join) { Token join_table = Expect(TokenType::IDENTIFIER); stmt->join_table = join_table.value; @@ -363,4 +376,12 @@ namespace sql return name; } + std::unique_ptr Parser::ParseExplain() + { + Expect(TokenType::EXPLAIN); + auto stmt = std::make_unique(); + stmt->select = ParseSelect(); + return stmt; + } + } // namespace sql diff --git a/src/parser/parser.h b/src/parser/parser.h index aa9e491..e43fe8b 100644 --- a/src/parser/parser.h +++ b/src/parser/parser.h @@ -22,6 +22,7 @@ namespace sql std::unique_ptr ParseCreate(); std::unique_ptr ParseSelect(); + std::unique_ptr ParseExplain(); std::unique_ptr ParseCreateTable(); std::unique_ptr ParseCreateIndex(); std::unique_ptr ParseDropTable(); diff --git a/test/integration/query_test.cpp b/test/integration/query_test.cpp index c8398f4..6e8b470 100644 --- a/test/integration/query_test.cpp +++ b/test/integration/query_test.cpp @@ -462,4 +462,101 @@ namespace sql EXPECT_EQ(result.tuples[0].GetValue(1).GetAsInt(), 99); } + // ── EXPLAIN ─────────────────────────────────────────────────────────────────── + + TEST(IntegrationTest, ExplainSeqScan) + { + Catalog catalog; + RunSQL(catalog, "CREATE TABLE t (id INTEGER, name VARCHAR(50));"); + auto result = RunSQL(catalog, "EXPLAIN SELECT * FROM t;"); + EXPECT_TRUE(result.success); + EXPECT_NE(result.message.find("SeqScan"), std::string::npos); + EXPECT_NE(result.message.find("Projection"), std::string::npos); + } + + TEST(IntegrationTest, ExplainIndexScan) + { + Catalog catalog; + RunSQL(catalog, "CREATE TABLE t (id INTEGER);"); + RunSQL(catalog, "INSERT INTO t VALUES (1), (2), (3);"); + RunSQL(catalog, "CREATE INDEX idx_id ON t (id);"); + auto result = RunSQL(catalog, "EXPLAIN SELECT * FROM t WHERE id = 2;"); + EXPECT_TRUE(result.success); + EXPECT_NE(result.message.find("IndexScan"), std::string::npos); + } + + TEST(IntegrationTest, ExplainJoin) + { + Catalog catalog; + RunSQL(catalog, "CREATE TABLE orders (id INTEGER, cid INTEGER);"); + RunSQL(catalog, "CREATE TABLE customers (id INTEGER, name VARCHAR(50));"); + auto result = RunSQL(catalog, "EXPLAIN SELECT * FROM orders JOIN customers ON orders.cid = customers.id;"); + EXPECT_TRUE(result.success); + // Without index: should be hash or nested loop join + EXPECT_TRUE(result.message.find("Join") != std::string::npos); + } + + // ── INDEX NESTED-LOOP JOIN ──────────────────────────────────────────────────── + + TEST(IntegrationTest, IndexNestedLoopJoinBasic) + { + Catalog catalog; + RunSQL(catalog, "CREATE TABLE orders (oid INTEGER, cid INTEGER);"); + RunSQL(catalog, "CREATE TABLE customers (id INTEGER, name VARCHAR(50));"); + RunSQL(catalog, "INSERT INTO orders VALUES (1, 10), (2, 20), (3, 10);"); + RunSQL(catalog, "INSERT INTO customers VALUES (10, 'Alice'), (20, 'Bob');"); + RunSQL(catalog, "CREATE INDEX idx_cid ON customers (id);"); + + auto result = RunSQL(catalog, "SELECT * FROM orders JOIN customers ON orders.cid = customers.id;"); + EXPECT_TRUE(result.success); + ASSERT_EQ(result.tuples.size(), 3); + } + + TEST(IntegrationTest, IndexNestedLoopJoinExplain) + { + Catalog catalog; + RunSQL(catalog, "CREATE TABLE orders (oid INTEGER, cid INTEGER);"); + RunSQL(catalog, "CREATE TABLE customers (id INTEGER, name VARCHAR(50));"); + RunSQL(catalog, "INSERT INTO orders VALUES (1, 10);"); + RunSQL(catalog, "INSERT INTO customers VALUES (10, 'Alice');"); + RunSQL(catalog, "CREATE INDEX idx_cid ON customers (id);"); + + auto explain = RunSQL(catalog, "EXPLAIN SELECT * FROM orders JOIN customers ON orders.cid = customers.id;"); + EXPECT_TRUE(explain.success); + EXPECT_NE(explain.message.find("IndexNestedLoopJoin"), std::string::npos); + } + + TEST(IntegrationTest, IndexNestedLoopJoinSwappedOnColumns) + { + Catalog catalog; + RunSQL(catalog, "CREATE TABLE orders (oid INTEGER, cid INTEGER);"); + RunSQL(catalog, "CREATE TABLE customers (id INTEGER, name VARCHAR(50));"); + RunSQL(catalog, "INSERT INTO orders VALUES (1, 10), (2, 20), (3, 10);"); + RunSQL(catalog, "INSERT INTO customers VALUES (10, 'Alice'), (20, 'Bob');"); + RunSQL(catalog, "CREATE INDEX idx_cid ON customers (id);"); + + // ON clause has columns in reverse order: right table column first. + auto result = RunSQL(catalog, "SELECT * FROM orders JOIN customers ON customers.id = orders.cid;"); + EXPECT_TRUE(result.success); + ASSERT_EQ(result.tuples.size(), 3); + + // EXPLAIN should still pick IndexNestedLoopJoin despite swapped ON order. + auto explain = RunSQL(catalog, "EXPLAIN SELECT * FROM orders JOIN customers ON customers.id = orders.cid;"); + EXPECT_TRUE(explain.success); + EXPECT_NE(explain.message.find("IndexNestedLoopJoin"), std::string::npos); + } + + TEST(IntegrationTest, InnerJoinSyntax) + { + Catalog catalog; + RunSQL(catalog, "CREATE TABLE a (id INTEGER, val INTEGER);"); + RunSQL(catalog, "CREATE TABLE b (aid INTEGER, name VARCHAR(50));"); + RunSQL(catalog, "INSERT INTO a VALUES (1, 100), (2, 200);"); + RunSQL(catalog, "INSERT INTO b VALUES (1, 'x'), (2, 'y');"); + + auto result = RunSQL(catalog, "SELECT * FROM a INNER JOIN b ON a.id = b.aid;"); + EXPECT_TRUE(result.success); + ASSERT_EQ(result.tuples.size(), 2); + } + } // namespace sql