From d4f555af0e83131701a840d23deacb69c18ae763 Mon Sep 17 00:00:00 2001 From: DuckDB Labs GitHub Bot Date: Thu, 18 Jun 2026 14:36:17 +0000 Subject: [PATCH 1/3] Update vendored DuckDB sources to 17ba7dd2b6 --- .../extension/parquet/column_reader.cpp | 10 +- .../extension/parquet/column_writer.cpp | 3 + .../catalog_entry/duck_table_entry.cpp | 12 + .../catalog_entry/table_catalog_entry.cpp | 20 +- .../src/catalog/default/default_types.cpp | 6 +- .../src/common/arrow/arrow_converter.cpp | 2 +- src/duckdb/src/common/enum_util.cpp | 32 +- .../common/enums/logical_operator_type.cpp | 2 + src/duckdb/src/common/types.cpp | 13 +- .../common/types/variant/variant_iterator.cpp | 104 ++-- .../common/vector_operations/vector_hash.cpp | 10 +- .../src/execution/expression_executor.cpp | 11 + .../src/execution/physical_plan_generator.cpp | 5 +- .../src/function/cast/variant/to_variant.cpp | 4 + src/duckdb/src/function/function_list.cpp | 1 + .../src/function/scalar/generic/invoke.cpp | 144 +++++ .../scalar/struct/struct_contains.cpp | 4 +- .../function/scalar/struct/struct_extract.cpp | 4 +- .../function/scalar/struct/struct_pack.cpp | 10 +- .../scalar/system/aggregate_export.cpp | 26 +- .../scalar/variant/variant_comparator.cpp | 116 +++- .../table/arrow/arrow_duck_schema.cpp | 4 - .../function/table/system/test_all_types.cpp | 7 + .../function/table/version/pragma_version.cpp | 6 +- .../catalog_entry/table_catalog_entry.hpp | 8 +- .../src/include/duckdb/common/enum_util.hpp | 8 + .../enums/debug_vector_verification.hpp | 3 +- .../common/enums/logical_operator_type.hpp | 1 + .../table_function_identifier_conversion.hpp | 21 + .../src/include/duckdb/common/types.hpp | 3 + .../duckdb/common/types/variant_iterator.hpp | 111 +++- .../duckdb/function/lambda_functions.hpp | 10 +- .../function/scalar/generic_functions.hpp | 10 + .../function/variant/variant_shredding.hpp | 6 +- .../include/duckdb/main/extension_entries.hpp | 28 +- .../src/include/duckdb/main/settings.hpp | 36 +- .../duckdb/parser/peg/inlined_grammar.hpp | 7 +- .../peg/transformer/peg_transformer.hpp | 453 ++++++++------- .../src/include/duckdb/planner/binder.hpp | 10 + .../include/duckdb/planner/operator/list.hpp | 1 + .../planner/operator/logical_trigger.hpp | 52 ++ .../subquery/flatten_dependent_join.hpp | 2 + .../storage/statistics/variant_stats.hpp | 5 + .../storage/table/variant_column_data.hpp | 4 + src/duckdb/src/main/config.cpp | 18 +- .../main/settings/autogenerated_settings.cpp | 7 + .../optimizer/common_subplan_optimizer.cpp | 237 ++++++-- .../optimizer/join_order/relation_manager.cpp | 10 +- .../src/optimizer/remove_unused_columns.cpp | 21 +- .../statistics/expression/propagate_cast.cpp | 4 + .../transformer/peg_transformer_factory.cpp | 3 - .../peg/transformer/transform_alter.cpp | 109 ++-- .../peg/transformer/transform_analyze.cpp | 16 +- .../peg/transformer/transform_attach.cpp | 28 +- .../peg/transformer/transform_checkpoint.cpp | 8 +- .../peg/transformer/transform_common.cpp | 101 ++-- .../peg/transformer/transform_connect.cpp | 58 +- .../parser/peg/transformer/transform_copy.cpp | 51 +- .../transformer/transform_create_index.cpp | 62 +- .../transformer/transform_create_macro.cpp | 21 +- .../transformer/transform_create_schema.cpp | 2 +- .../transformer/transform_create_secret.cpp | 12 +- .../transformer/transform_create_sequence.cpp | 27 +- .../transformer/transform_create_table.cpp | 215 +++---- .../transformer/transform_create_trigger.cpp | 29 +- .../peg/transformer/transform_create_type.cpp | 20 +- .../peg/transformer/transform_create_view.cpp | 16 +- .../peg/transformer/transform_deallocate.cpp | 2 +- .../peg/transformer/transform_delete.cpp | 31 +- .../peg/transformer/transform_describe.cpp | 20 +- .../peg/transformer/transform_detach.cpp | 3 +- .../parser/peg/transformer/transform_drop.cpp | 41 +- .../peg/transformer/transform_execute.cpp | 15 +- .../peg/transformer/transform_explain.cpp | 21 +- .../peg/transformer/transform_export.cpp | 26 +- .../peg/transformer/transform_generated.cpp | 530 +++++++++++------- .../transform_generic_copy_option.cpp | 16 +- .../peg/transformer/transform_insert.cpp | 69 ++- .../parser/peg/transformer/transform_load.cpp | 26 +- .../peg/transformer/transform_merge_into.cpp | 64 ++- .../peg/transformer/transform_pragma.cpp | 6 +- .../peg/transformer/transform_prepare.cpp | 8 +- .../parser/peg/transformer/transform_set.cpp | 9 +- .../peg/transformer/transform_transaction.cpp | 14 +- .../peg/transformer/transform_update.cpp | 36 +- .../parser/peg/transformer/transform_use.cpp | 10 +- .../peg/transformer/transform_vacuum.cpp | 32 +- .../expression/bind_function_expression.cpp | 29 +- .../binder/query_node/bind_select_node.cpp | 4 +- .../query_node/bind_trigger_expansion.cpp | 187 +++++- .../planner/binder/statement/bind_create.cpp | 70 ++- .../planner/binder/statement/bind_insert.cpp | 3 + .../expression/bound_function_expression.cpp | 10 +- .../table_function_binder.cpp | 19 + .../src/planner/operator/logical_trigger.cpp | 19 + src/duckdb/src/planner/planner.cpp | 32 ++ .../subquery/delim_join_cte_rewriter.cpp | 44 +- .../subquery/flatten_dependent_join.cpp | 32 ++ .../src/storage/statistics/variant_stats.cpp | 71 ++- .../src/storage/table/struct_column_data.cpp | 3 +- .../table/variant/variant_shredding.cpp | 45 +- .../table/variant/variant_unshredding.cpp | 5 +- .../src/transaction/duck_transaction.cpp | 36 +- .../src/transaction/transaction_context.cpp | 3 +- src/duckdb/ub_src_function_scalar_generic.cpp | 2 + src/duckdb/ub_src_planner_operator.cpp | 2 + 106 files changed, 2792 insertions(+), 1213 deletions(-) create mode 100644 src/duckdb/src/function/scalar/generic/invoke.cpp create mode 100644 src/duckdb/src/include/duckdb/common/enums/table_function_identifier_conversion.hpp create mode 100644 src/duckdb/src/include/duckdb/planner/operator/logical_trigger.hpp create mode 100644 src/duckdb/src/planner/operator/logical_trigger.cpp diff --git a/src/duckdb/extension/parquet/column_reader.cpp b/src/duckdb/extension/parquet/column_reader.cpp index e9386b59b..eedda2cad 100644 --- a/src/duckdb/extension/parquet/column_reader.cpp +++ b/src/duckdb/extension/parquet/column_reader.cpp @@ -408,7 +408,10 @@ void ColumnReader::PreparePageV2(PageHeader &page_hdr) { } if (chunk->meta_data.codec == CompressionCodec::UNCOMPRESSED) { if (page_hdr.compressed_page_size != page_hdr.uncompressed_page_size) { - throw InvalidInputException("Failed to read file \"%s\": Page size mismatch", Reader().GetFileName()); + const auto &file_name = Reader().GetFileName(); + throw InvalidInputException( + "Parquet file (%s) corrupted: uncompressed page size mismatch (expected %d, actual: %d)", file_name, + page_hdr.uncompressed_page_size, page_hdr.compressed_page_size); } uncompressed = true; } @@ -467,7 +470,10 @@ void ColumnReader::PreparePage(PageHeader &page_hdr) { if (chunk->meta_data.codec == CompressionCodec::UNCOMPRESSED) { if (compressed_page_size != NumericCast(page_hdr.uncompressed_page_size)) { - throw InternalException("Page size mismatch"); + const auto &file_name = Reader().GetFileName(); + throw InvalidInputException( + "Parquet file (%s) corrupted: uncompressed page size mismatch (expected %d, actual: %d)", file_name, + page_hdr.uncompressed_page_size, compressed_page_size); } ReadData(block->ptr, compressed_page_size, page_hdr.type); return; diff --git a/src/duckdb/extension/parquet/column_writer.cpp b/src/duckdb/extension/parquet/column_writer.cpp index 31ba7f0a3..aea0190bd 100644 --- a/src/duckdb/extension/parquet/column_writer.cpp +++ b/src/duckdb/extension/parquet/column_writer.cpp @@ -355,6 +355,9 @@ unique_ptr ColumnWriter::CreateWriterRecursive(ClientContext &cont } if (type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION) { + if (type.id() == LogicalTypeId::STRUCT && StructType::GetChildTypes(type).empty()) { + throw InvalidInputException("Empty STRUCT columns are not supported in the Parquet format"); + } auto struct_column = ParquetColumnSchema::FromLogicalType(name, type, max_define, max_repeat, 0, null_type, allow_geometry); if (field_id && field_id->set) { diff --git a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp index 94ba907e3..8f8c88ab1 100644 --- a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp @@ -62,6 +62,18 @@ static void CheckTypeIsSupported(const LogicalType &logical_type, AttachedDataba case LogicalTypeId::TYPE: { throw InvalidInputException("A table cannot be created with a 'TYPE' column"); } break; + case LogicalTypeId::STRUCT: { + const auto storage_version = db.GetStorageManager().GetStorageVersion(); + + if (storage_version < StorageVersion::V2_0_0 && StructType::GetChildCount(type) == 0) { + auto required = GetStorageVersionName(StorageVersion::V2_0_0, false); + auto current = GetStorageVersionName(storage_version, false); + + throw InvalidInputException("Empty STRUCT columns are not supported in storage versions prior to %s " + "(database \"%s\" is using storage version %s)", + required, db.GetName(), current); + } + } break; case LogicalTypeId::VARIANT: { const auto storage_version = db.GetStorageManager().GetStorageVersion(); diff --git a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp index 4e6143ea9..d42a388d0 100644 --- a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -392,14 +392,28 @@ void TableCatalogEntry::ScanTriggers(CatalogTransaction transaction, } vector> TableCatalogEntry::GetTriggersForEvent(CatalogTransaction transaction, - TriggerTiming timing, - TriggerEventType event_type) const { + TriggerEventType event_type, + TriggerForEach for_each) const { vector> result; // CatalogSet is backed by case_insensitive_tree_t (a map with case-insensitive comparator), // so ScanTriggers yields entries in alphabetical order by name ScanTriggers(transaction, [&](CatalogEntry &entry) { auto &trigger = entry.Cast(); - if (trigger.timing == timing && trigger.event_type == event_type) { + if (trigger.event_type == event_type && trigger.for_each == for_each) { + result.emplace_back(trigger); + } + }); + return result; +} + +vector> TableCatalogEntry::GetTriggersForEvent(CatalogTransaction transaction, + TriggerTiming timing, + TriggerEventType event_type, + TriggerForEach for_each) const { + vector> result; + ScanTriggers(transaction, [&](CatalogEntry &entry) { + auto &trigger = entry.Cast(); + if (trigger.event_type == event_type && trigger.for_each == for_each && trigger.timing == timing) { result.emplace_back(trigger); } }); diff --git a/src/duckdb/src/catalog/default/default_types.cpp b/src/duckdb/src/catalog/default/default_types.cpp index 16b60a0f9..f19c913e2 100644 --- a/src/duckdb/src/catalog/default/default_types.cpp +++ b/src/duckdb/src/catalog/default/default_types.cpp @@ -260,13 +260,9 @@ LogicalType BindArrayType(BindLogicalTypeInput &input) { //---------------------------------------------------------------------------------------------------------------------- LogicalType BindStructType(BindLogicalTypeInput &input) { auto &arguments = input.modifiers; - - if (arguments.empty()) { - throw BinderException("STRUCT type requires at least one child type"); - } - auto all_name = true; auto all_anon = true; + for (auto &arg : arguments) { if (arg.HasName()) { all_anon = false; diff --git a/src/duckdb/src/common/arrow/arrow_converter.cpp b/src/duckdb/src/common/arrow/arrow_converter.cpp index 044c57e12..e71083b2f 100644 --- a/src/duckdb/src/common/arrow/arrow_converter.cpp +++ b/src/duckdb/src/common/arrow/arrow_converter.cpp @@ -75,7 +75,7 @@ void SetArrowStructFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &chi for (idx_t type_idx = 0; type_idx < child_types.size(); type_idx++) { root_holder.nested_children_ptr.back()[type_idx] = &root_holder.nested_children.back()[type_idx]; } - child.children = &root_holder.nested_children_ptr.back()[0]; + child.children = child_types.empty() ? nullptr : &root_holder.nested_children_ptr.back()[0]; for (size_t type_idx = 0; type_idx < child_types.size(); type_idx++) { InitializeChild(*child.children[type_idx], root_holder); root_holder.owned_type_names.push_back(AddName(child_types[type_idx].first.GetIdentifierName())); diff --git a/src/duckdb/src/common/enum_util.cpp b/src/duckdb/src/common/enum_util.cpp index ff0a4fd37..683913391 100644 --- a/src/duckdb/src/common/enum_util.cpp +++ b/src/duckdb/src/common/enum_util.cpp @@ -71,6 +71,7 @@ #include "duckdb/common/enums/storage_block_prefetch.hpp" #include "duckdb/common/enums/stream_execution_result.hpp" #include "duckdb/common/enums/subquery_type.hpp" +#include "duckdb/common/enums/table_function_identifier_conversion.hpp" #include "duckdb/common/enums/tableref_type.hpp" #include "duckdb/common/enums/task_scheduler_type.hpp" #include "duckdb/common/enums/thread_pin_mode.hpp" @@ -1643,19 +1644,20 @@ const StringUtil::EnumStringLiteral *GetDebugVectorVerificationValues() { { static_cast(DebugVectorVerification::CONSTANT_OPERATOR), "CONSTANT_OPERATOR" }, { static_cast(DebugVectorVerification::SEQUENCE_OPERATOR), "SEQUENCE_OPERATOR" }, { static_cast(DebugVectorVerification::NESTED_SHUFFLE), "NESTED_SHUFFLE" }, - { static_cast(DebugVectorVerification::VARIANT_VECTOR), "VARIANT_VECTOR" } + { static_cast(DebugVectorVerification::VARIANT_VECTOR), "VARIANT_VECTOR" }, + { static_cast(DebugVectorVerification::SHREDDED_VECTOR), "SHREDDED_VECTOR" } }; return values; } template<> const char* EnumUtil::ToChars(DebugVectorVerification value) { - return StringUtil::EnumToString(GetDebugVectorVerificationValues(), 7, "DebugVectorVerification", static_cast(value)); + return StringUtil::EnumToString(GetDebugVectorVerificationValues(), 8, "DebugVectorVerification", static_cast(value)); } template<> DebugVectorVerification EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetDebugVectorVerificationValues(), 7, "DebugVectorVerification", value)); + return static_cast(StringUtil::StringToEnum(GetDebugVectorVerificationValues(), 8, "DebugVectorVerification", value)); } const StringUtil::EnumStringLiteral *GetDebugVerificationModeValues() { @@ -3175,6 +3177,7 @@ const StringUtil::EnumStringLiteral *GetLogicalOperatorTypeValues() { { static_cast(LogicalOperatorType::LOGICAL_DELETE), "LOGICAL_DELETE" }, { static_cast(LogicalOperatorType::LOGICAL_UPDATE), "LOGICAL_UPDATE" }, { static_cast(LogicalOperatorType::LOGICAL_MERGE_INTO), "LOGICAL_MERGE_INTO" }, + { static_cast(LogicalOperatorType::LOGICAL_TRIGGER), "LOGICAL_TRIGGER" }, { static_cast(LogicalOperatorType::LOGICAL_ALTER), "LOGICAL_ALTER" }, { static_cast(LogicalOperatorType::LOGICAL_CREATE_TABLE), "LOGICAL_CREATE_TABLE" }, { static_cast(LogicalOperatorType::LOGICAL_CREATE_INDEX), "LOGICAL_CREATE_INDEX" }, @@ -3208,12 +3211,12 @@ const StringUtil::EnumStringLiteral *GetLogicalOperatorTypeValues() { template<> const char* EnumUtil::ToChars(LogicalOperatorType value) { - return StringUtil::EnumToString(GetLogicalOperatorTypeValues(), 65, "LogicalOperatorType", static_cast(value)); + return StringUtil::EnumToString(GetLogicalOperatorTypeValues(), 66, "LogicalOperatorType", static_cast(value)); } template<> LogicalOperatorType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetLogicalOperatorTypeValues(), 65, "LogicalOperatorType", value)); + return static_cast(StringUtil::StringToEnum(GetLogicalOperatorTypeValues(), 66, "LogicalOperatorType", value)); } const StringUtil::EnumStringLiteral *GetLogicalTypeIdValues() { @@ -5532,6 +5535,25 @@ TableFilterType EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetTableFilterTypeValues(), 12, "TableFilterType", value)); } +const StringUtil::EnumStringLiteral *GetTableFunctionIdentifierConversionValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(TableFunctionIdentifierConversion::DEFAULT), "DEFAULT" }, + { static_cast(TableFunctionIdentifierConversion::ENABLE_IMPLICIT_STRING), "ENABLE_IMPLICIT_STRING" }, + { static_cast(TableFunctionIdentifierConversion::DISABLE_IMPLICIT_STRING), "DISABLE_IMPLICIT_STRING" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(TableFunctionIdentifierConversion value) { + return StringUtil::EnumToString(GetTableFunctionIdentifierConversionValues(), 3, "TableFunctionIdentifierConversion", static_cast(value)); +} + +template<> +TableFunctionIdentifierConversion EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetTableFunctionIdentifierConversionValues(), 3, "TableFunctionIdentifierConversion", value)); +} + const StringUtil::EnumStringLiteral *GetTableFunctionParallelismValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(TableFunctionParallelism::SELF_MANAGED_PARALLELISM), "SELF_MANAGED_PARALLELISM" }, diff --git a/src/duckdb/src/common/enums/logical_operator_type.cpp b/src/duckdb/src/common/enums/logical_operator_type.cpp index 772e5f502..446fd25c3 100644 --- a/src/duckdb/src/common/enums/logical_operator_type.cpp +++ b/src/duckdb/src/common/enums/logical_operator_type.cpp @@ -72,6 +72,8 @@ string LogicalOperatorToString(LogicalOperatorType type) { return "UPDATE"; case LogicalOperatorType::LOGICAL_MERGE_INTO: return "MERGE_INTO"; + case LogicalOperatorType::LOGICAL_TRIGGER: + return "TRIGGER"; case LogicalOperatorType::LOGICAL_PREPARE: return "PREPARE"; case LogicalOperatorType::LOGICAL_DUMMY_SCAN: diff --git a/src/duckdb/src/common/types.cpp b/src/duckdb/src/common/types.cpp index 6b1768af2..3ccc63062 100644 --- a/src/duckdb/src/common/types.cpp +++ b/src/duckdb/src/common/types.cpp @@ -411,8 +411,12 @@ string LogicalType::ToString() const { if (!type_info_) { return "STRUCT"; } - auto is_unnamed = StructType::IsUnnamed(*this); auto &child_types = StructType::GetChildTypes(*this); + if (child_types.empty()) { + return "STRUCT"; + } + + auto is_unnamed = StructType::IsUnnamed(*this); string ret = "STRUCT("; for (size_t i = 0; i < child_types.size(); i++) { if (is_unnamed) { @@ -688,7 +692,12 @@ bool LogicalType::IsComplete() const { D_ASSERT(type.AuxInfo()); switch (type.AuxInfo()->type) { case ExtraTypeInfoType::STRUCT_TYPE_INFO: - return type.AuxInfo()->Cast().child_types.empty(); // Cannot be empty + // empty STRUCTs are complete (children, if any, are checked by recursion) + // UNION/VARIANT (which also use STRUCT_TYPE_INFO) cannot be empty + if (type.id() == LogicalTypeId::STRUCT) { + return false; + } + return type.AuxInfo()->Cast().child_types.empty(); case ExtraTypeInfoType::DECIMAL_TYPE_INFO: return DecimalType::GetWidth(type) >= 1 && DecimalType::GetWidth(type) <= Decimal::MAX_WIDTH_DECIMAL && DecimalType::GetScale(type) <= DecimalType::GetWidth(type); diff --git a/src/duckdb/src/common/types/variant/variant_iterator.cpp b/src/duckdb/src/common/types/variant/variant_iterator.cpp index bd52b81f8..5214142b2 100644 --- a/src/duckdb/src/common/types/variant/variant_iterator.cpp +++ b/src/duckdb/src/common/types/variant/variant_iterator.cpp @@ -58,9 +58,9 @@ bool ShreddedIsValid(const ShreddedVariantIterator &node, idx_t index) { } // namespace //===--------------------------------------------------------------------===// -// VariantIteratorState +// VariantIterator //===--------------------------------------------------------------------===// -VariantIteratorState::VariantIteratorState(const Vector &variant) +VariantIterator::VariantIterator(const Vector &variant) //! The unshredded ("core") source is the variant itself, or the unshredded component of a shredded vector : unshredded(variant.GetVectorType() == VectorType::SHREDDED_VECTOR ? ShreddedVector::GetUnshreddedVector(variant) : variant) { @@ -155,27 +155,27 @@ uint32_t UnshreddedVariantIterator::GetValuesIndex(idx_t row, idx_t child_index) //===--------------------------------------------------------------------===// // Root / row validity //===--------------------------------------------------------------------===// -VariantIterator VariantIteratorState::Root(idx_t row) const { +VariantNode VariantIterator::Root(idx_t row) const { if (is_shredded) { //! The shredded component's top-level validity is the authoritative row validity (a SQL-NULL row //! has the whole shredded struct set to NULL). This must be checked separately because //! ResolveShredded only inspects the typed_value / untyped_value_index of a wrapper, never the //! wrapper's own struct validity. if (!ShreddedIsValid(shredded_format, row)) { - return VariantIterator::MakeNull(*this); + return VariantNode::MakeNull(*this); } - auto root = VariantIterator::ResolveShredded(*this, shredded_format, row, row); + auto root = VariantNode::ResolveShredded(*this, shredded_format, row, row); //! a root value is never "missing" - treat any such case as a SQL NULL - return root.IsMissing() ? VariantIterator::MakeNull(*this) : root; + return root.IsMissing() ? VariantNode::MakeNull(*this) : root; } if (!unshredded.RowIsValid(row)) { - return VariantIterator::MakeNull(*this); + return VariantNode::MakeNull(*this); } //! The unshredded root value lives at values[0] - return VariantIterator::MakeUnshredded(*this, row, 0); + return VariantNode::MakeUnshredded(*this, row, 0); } -bool VariantIteratorState::RowIsValid(idx_t row) const { +bool VariantIterator::RowIsValid(idx_t row) const { //! A VARIANT is never NULL at the root via a VARIANT_NULL value (that is reserved for nested values) - //! a root that resolves to NULL is a genuine SQL NULL. This matches the semantics of unshredding, //! where a shredded value whose typed leaf is NULL (with no unshredded leftover) becomes a SQL NULL. @@ -183,24 +183,24 @@ bool VariantIteratorState::RowIsValid(idx_t row) const { } //===--------------------------------------------------------------------===// -// VariantIterator - factory helpers +// VariantNode - factory helpers //===--------------------------------------------------------------------===// -VariantIterator VariantIterator::MakeNull(const VariantIteratorState &state) { - VariantIterator result; +VariantNode VariantNode::MakeNull(const VariantIterator &state) { + VariantNode result; result.state = &state; result.kind = Kind::NULL_VALUE; return result; } -VariantIterator VariantIterator::MakeMissing(const VariantIteratorState &state) { - VariantIterator result; +VariantNode VariantNode::MakeMissing(const VariantIterator &state) { + VariantNode result; result.state = &state; result.kind = Kind::MISSING; return result; } -VariantIterator VariantIterator::MakeUnshredded(const VariantIteratorState &state, idx_t row, uint32_t value_index) { - VariantIterator result; +VariantNode VariantNode::MakeUnshredded(const VariantIterator &state, idx_t row, uint32_t value_index) { + VariantNode result; result.state = &state; result.kind = Kind::UNSHREDDED; result.row = row; @@ -208,9 +208,9 @@ VariantIterator VariantIterator::MakeUnshredded(const VariantIteratorState &stat return result; } -VariantIterator VariantIterator::MakeShredded(const VariantIteratorState &state, const ShreddedVariantIterator &content, - idx_t index, idx_t row, uint32_t overlay_value_index) { - VariantIterator result; +VariantNode VariantNode::MakeShredded(const VariantIterator &state, const ShreddedVariantIterator &content, idx_t index, + idx_t row, uint32_t overlay_value_index) { + VariantNode result; result.state = &state; result.kind = Kind::SHREDDED; result.row = row; @@ -223,8 +223,8 @@ VariantIterator VariantIterator::MakeShredded(const VariantIteratorState &state, //===--------------------------------------------------------------------===// // Shredded resolution //===--------------------------------------------------------------------===// -VariantIterator VariantIterator::ResolveShredded(const VariantIteratorState &state, const ShreddedVariantIterator &node, - idx_t index, idx_t row) { +VariantNode VariantNode::ResolveShredded(const VariantIterator &state, const ShreddedVariantIterator &node, idx_t index, + idx_t row) { if (node.logical_type.id() != LogicalTypeId::STRUCT) { //! A flattened (fully-consistent) primitive - a NULL here represents a VARIANT_NULL value if (!ShreddedIsValid(node, index)) { @@ -348,12 +348,11 @@ static VariantLogicalType ShreddedTypeId(const ShreddedVariantIterator &content, case LogicalTypeId::GEOMETRY: return VariantLogicalType::GEOMETRY; default: - throw NotImplementedException("Shredded VARIANT type '%s' is not supported by VariantIterator", - type.ToString()); + throw NotImplementedException("Shredded VARIANT type '%s' is not supported by VariantNode", type.ToString()); } } -VariantLogicalType VariantIterator::GetTypeId() const { +VariantLogicalType VariantNode::GetTypeId() const { switch (kind) { case Kind::NULL_VALUE: return VariantLogicalType::VARIANT_NULL; @@ -362,14 +361,14 @@ VariantLogicalType VariantIterator::GetTypeId() const { case Kind::SHREDDED: return ShreddedTypeId(*shredded_format, shredded_index); default: - throw InternalException("VariantIterator::GetTypeId called on a MISSING value"); + throw InternalException("VariantNode::GetTypeId called on a MISSING value"); } } //===--------------------------------------------------------------------===// // Primitive accessors //===--------------------------------------------------------------------===// -const_data_ptr_t VariantIterator::GetDataPointer() const { +const_data_ptr_t VariantNode::GetDataPointer() const { if (kind == Kind::UNSHREDDED) { auto &blob = state->unshredded.GetBlob(row); return const_data_ptr_cast(blob.GetData()) + state->unshredded.GetByteOffset(row, value_index); @@ -380,7 +379,7 @@ const_data_ptr_t VariantIterator::GetDataPointer() const { return content.unified.data + content.unified.sel->get_index(shredded_index) * type_size; } -string_t VariantIterator::GetString() const { +string_t VariantNode::GetString() const { if (kind == Kind::UNSHREDDED) { return DecodeStringData(state->unshredded.GetBlob(row), state->unshredded.GetByteOffset(row, value_index)); } @@ -389,7 +388,7 @@ string_t VariantIterator::GetString() const { return content.unified.GetData()[content.unified.sel->get_index(shredded_index)]; } -VariantDecimalData VariantIterator::GetDecimal() const { +VariantDecimalData VariantNode::GetDecimal() const { if (kind == Kind::UNSHREDDED) { return DecodeDecimalData(state->unshredded.GetBlob(row), state->unshredded.GetByteOffset(row, value_index)); } @@ -406,22 +405,22 @@ VariantDecimalData VariantIterator::GetDecimal() const { //===--------------------------------------------------------------------===// // Nested accessors //===--------------------------------------------------------------------===// -VariantObjectIterator VariantIterator::GetObjectChildren(VariantIterationOrder order) const { +VariantObjectIterator VariantNode::GetObjectChildren(VariantIterationOrder order) const { return VariantObjectIterator(*this, order); } -VariantArrayIterator VariantIterator::GetArrayChildren() const { +VariantArrayIterator VariantNode::GetArrayChildren() const { return VariantArrayIterator(*this); } //===--------------------------------------------------------------------===// // VariantArrayIterator //===--------------------------------------------------------------------===// -VariantArrayIterator::VariantArrayIterator(const VariantIterator &array) - : state(array.state), row(array.row), shredded(array.kind == VariantIterator::Kind::SHREDDED) { +VariantArrayIterator::VariantArrayIterator(const VariantNode &array) + : state(*array.state), row(array.row), shredded(array.kind == VariantNode::Kind::SHREDDED) { if (!shredded) { - auto nested = - DecodeNestedData(state->unshredded.GetBlob(row), state->unshredded.GetByteOffset(row, array.value_index)); + auto &unshredded = state.get().unshredded; + auto nested = DecodeNestedData(unshredded.GetBlob(row), unshredded.GetByteOffset(row, array.value_index)); base = nested.children_idx; length = nested.child_count; return; @@ -435,21 +434,22 @@ VariantArrayIterator::VariantArrayIterator(const VariantIterator &array) element_node = content.children[0]; } -VariantIterator VariantArrayIterator::operator[](idx_t i) const { +VariantNode VariantArrayIterator::operator[](idx_t i) const { + auto &state_ref = state.get(); if (shredded) { - return VariantIterator::ResolveShredded(*state, *element_node, base + i, row); + return VariantNode::ResolveShredded(state_ref, *element_node, base + i, row); } - return VariantIterator::MakeUnshredded(*state, row, state->unshredded.GetValuesIndex(row, base + i)); + return VariantNode::MakeUnshredded(state_ref, row, state_ref.unshredded.GetValuesIndex(row, base + i)); } //===--------------------------------------------------------------------===// // VariantObjectIterator //===--------------------------------------------------------------------===// -VariantObjectIterator::VariantObjectIterator(const VariantIterator &object, VariantIterationOrder order) - : state(object.state), row(object.row), order(order), shredded(object.kind == VariantIterator::Kind::SHREDDED) { +VariantObjectIterator::VariantObjectIterator(const VariantNode &object, VariantIterationOrder order) + : state(*object.state), row(object.row), order(order), shredded(object.kind == VariantNode::Kind::SHREDDED) { + auto &unshredded = state.get().unshredded; if (!shredded) { - auto nested = - DecodeNestedData(state->unshredded.GetBlob(row), state->unshredded.GetByteOffset(row, object.value_index)); + auto nested = DecodeNestedData(unshredded.GetBlob(row), unshredded.GetByteOffset(row, object.value_index)); base = nested.children_idx; raw_count = nested.child_count; } else { @@ -460,8 +460,7 @@ VariantObjectIterator::VariantObjectIterator(const VariantIterator &object, Vari raw_count = typed_field_count; if (object.overlay_value_index != 0) { auto overlay_value_index = object.overlay_value_index - 1; - auto nested = DecodeNestedData(state->unshredded.GetBlob(row), - state->unshredded.GetByteOffset(row, overlay_value_index)); + auto nested = DecodeNestedData(unshredded.GetBlob(row), unshredded.GetByteOffset(row, overlay_value_index)); overlay_base = nested.children_idx; raw_count += nested.child_count; } @@ -485,26 +484,27 @@ VariantObjectIterator::VariantObjectIterator(const VariantIterator &object, Vari } VariantObjectEntry VariantObjectIterator::RawEntry(idx_t raw_pos) const { + auto &state_ref = state.get(); + auto &unshredded = state_ref.unshredded; if (!shredded) { auto child_idx = base + raw_pos; - auto key_idx = state->unshredded.GetKeysIndex(row, child_idx); - auto value_idx = state->unshredded.GetValuesIndex(row, child_idx); - return VariantObjectEntry {state->unshredded.GetKey(row, key_idx), - VariantIterator::MakeUnshredded(*state, row, value_idx)}; + auto key_idx = unshredded.GetKeysIndex(row, child_idx); + auto value_idx = unshredded.GetValuesIndex(row, child_idx); + return VariantObjectEntry {unshredded.GetKey(row, key_idx), + VariantNode::MakeUnshredded(state_ref, row, value_idx)}; } if (raw_pos < typed_field_count) { //! A shredded (typed) field - struct fields preserve the (logical) index of the parent auto &name = StructType::GetChildTypes(content->logical_type)[raw_pos].first; return VariantObjectEntry { string_t(name.c_str(), NumericCast(name.size())), - VariantIterator::ResolveShredded(*state, content->children[raw_pos], shredded_index, row)}; + VariantNode::ResolveShredded(state_ref, content->children[raw_pos], shredded_index, row)}; } //! A leftover field from the overlay (unshredded) object auto child_idx = overlay_base + (raw_pos - typed_field_count); - auto key_idx = state->unshredded.GetKeysIndex(row, child_idx); - auto value_idx = state->unshredded.GetValuesIndex(row, child_idx); - return VariantObjectEntry {state->unshredded.GetKey(row, key_idx), - VariantIterator::MakeUnshredded(*state, row, value_idx)}; + auto key_idx = unshredded.GetKeysIndex(row, child_idx); + auto value_idx = unshredded.GetValuesIndex(row, child_idx); + return VariantObjectEntry {unshredded.GetKey(row, key_idx), VariantNode::MakeUnshredded(state_ref, row, value_idx)}; } void VariantObjectIterator::Iterator::Load() { diff --git a/src/duckdb/src/common/vector_operations/vector_hash.cpp b/src/duckdb/src/common/vector_operations/vector_hash.cpp index e7d6c77df..f25e55e41 100644 --- a/src/duckdb/src/common/vector_operations/vector_hash.cpp +++ b/src/duckdb/src/common/vector_operations/vector_hash.cpp @@ -100,7 +100,15 @@ template void StructLoopHash(const Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { const auto &children = StructVector::GetEntries(input); - D_ASSERT(!children.empty()); + if (children.empty()) { + // an empty struct has no fields to hash: for the first hash every row gets the same constant value. + if (FIRST_HASH) { + hashes.SetVectorType(VectorType::CONSTANT_VECTOR); + FlatVector::SetSize(hashes, count_t(count)); + *ConstantVector::GetData(hashes) = 0x9e3779b97f4a7c15ULL; // some (arbitrary) constant + } + return; + } idx_t col_no = 0; if (HAS_RSEL) { if (FIRST_HASH) { diff --git a/src/duckdb/src/execution/expression_executor.cpp b/src/duckdb/src/execution/expression_executor.cpp index e8a49fade..037d65732 100644 --- a/src/duckdb/src/execution/expression_executor.cpp +++ b/src/duckdb/src/execution/expression_executor.cpp @@ -8,6 +8,7 @@ #include "duckdb/main/settings.hpp" #include "duckdb/function/cast/cast_function_set.hpp" #include "duckdb/common/type_visitor.hpp" +#include "duckdb/storage/table/variant_column_data.hpp" namespace duckdb { @@ -208,6 +209,16 @@ void ExpressionExecutor::Verify(const Expression &expr, Vector &vector, idx_t co vector.Reference(result); vector.Verify(); } + if (debug_vector_verification == DebugVectorVerification::SHREDDED_VECTOR) { + //! Shred (top-level) VARIANT vectors based on the schema of their first value, so downstream + //! operators are exercised against shredded (and partially-shredded) variant vectors. + //! A SHREDDED_VECTOR is never a constant vector - skip constant vectors so we don't break callers + //! that require a constant result (e.g. scalar expression folding in EvaluateScalar). + if (vector.GetType().id() == LogicalTypeId::VARIANT && vector.GetVectorType() != VectorType::CONSTANT_VECTOR) { + VariantColumnData::DebugShred(vector, count); + vector.Verify(); + } + } } unique_ptr ExpressionExecutor::InitializeState(const Expression &expr, diff --git a/src/duckdb/src/execution/physical_plan_generator.cpp b/src/duckdb/src/execution/physical_plan_generator.cpp index 95b0d84fb..1c28fc901 100644 --- a/src/duckdb/src/execution/physical_plan_generator.cpp +++ b/src/duckdb/src/execution/physical_plan_generator.cpp @@ -63,7 +63,8 @@ unique_ptr PhysicalPlanGenerator::PlanInternal(LogicalOperator &op auto debug_verify_vector = Settings::Get(context); if (debug_verify_vector != DebugVectorVerification::NONE) { if (debug_verify_vector != DebugVectorVerification::DICTIONARY_EXPRESSION && - debug_verify_vector != DebugVectorVerification::VARIANT_VECTOR) { + debug_verify_vector != DebugVectorVerification::VARIANT_VECTOR && + debug_verify_vector != DebugVectorVerification::SHREDDED_VECTOR) { physical_plan->SetRoot(Make(physical_plan->Root(), debug_verify_vector)); } } @@ -126,6 +127,8 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalOperator &op) { return CreatePlan(op.Cast()); case LogicalOperatorType::LOGICAL_MERGE_INTO: return CreatePlan(op.Cast()); + case LogicalOperatorType::LOGICAL_TRIGGER: + throw InternalException("LogicalTrigger must be rewritten before physical planning"); case LogicalOperatorType::LOGICAL_CREATE_TABLE: return CreatePlan(op.Cast()); case LogicalOperatorType::LOGICAL_CREATE_INDEX: diff --git a/src/duckdb/src/function/cast/variant/to_variant.cpp b/src/duckdb/src/function/cast/variant/to_variant.cpp index 4255ac751..9da98fa7e 100644 --- a/src/duckdb/src/function/cast/variant/to_variant.cpp +++ b/src/duckdb/src/function/cast/variant/to_variant.cpp @@ -154,6 +154,10 @@ static bool SupportsShreddedCast(const LogicalType &type) { if (type.id() == LogicalTypeId::STRUCT) { // for struct types recurse into the child types auto &child_types = StructType::GetChildTypes(type); + if (child_types.empty()) { + // an empty struct has no typed_value to shred into + return false; + } for (auto &entry : child_types) { if (!SupportsShreddedCast(entry.second)) { return false; diff --git a/src/duckdb/src/function/function_list.cpp b/src/duckdb/src/function/function_list.cpp index b094cf3a8..a5f9c40fb 100644 --- a/src/duckdb/src/function/function_list.cpp +++ b/src/duckdb/src/function/function_list.cpp @@ -170,6 +170,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_WINDOW_FUNCTION(FirstValueFun), DUCKDB_SCALAR_FUNCTION(GetVariableFun), DUCKDB_SCALAR_FUNCTION(IlikeEscapeFun), + DUCKDB_SCALAR_FUNCTION(InvokeFun), DUCKDB_WINDOW_FUNCTION(LagFun), DUCKDB_AGGREGATE_FUNCTION_SET(LastFun), DUCKDB_WINDOW_FUNCTION(LastValueFun), diff --git a/src/duckdb/src/function/scalar/generic/invoke.cpp b/src/duckdb/src/function/scalar/generic/invoke.cpp new file mode 100644 index 000000000..f198b5a24 --- /dev/null +++ b/src/duckdb/src/function/scalar/generic/invoke.cpp @@ -0,0 +1,144 @@ +#include "duckdb/function/scalar/generic_functions.hpp" +#include "duckdb/function/lambda_functions.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/planner/expression/bound_lambda_expression.hpp" + +namespace duckdb { + +namespace { + +struct LambdaInvokeData final : public LambdaFunctionData { + unique_ptr lambda_expr; + + explicit LambdaInvokeData(unique_ptr lambda_expr_p) : lambda_expr(std::move(lambda_expr_p)) { + } + + unique_ptr Copy() const override { + auto lambda_expr_copy = lambda_expr ? lambda_expr->Copy() : nullptr; + return make_uniq(std::move(lambda_expr_copy)); + } + + bool Equals(const FunctionData &other_p) const override { + auto &other = other_p.Cast(); + return Expression::Equals(lambda_expr, other.lambda_expr); + } + + //! Serializes a lambda function's bind data + static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, + const BoundScalarFunction &function) { + auto &bind_data = bind_data_p->Cast(); + serializer.WritePropertyWithDefault(101, "lambda_expr", bind_data.lambda_expr, unique_ptr()); + } + + //! Deserializes a lambda function's bind data + static unique_ptr Deserialize(Deserializer &deserializer, BoundScalarFunction &) { + auto lambda_expr = deserializer.ReadPropertyWithExplicitDefault>( + 101, "lambda_expr", unique_ptr()); + return make_uniq(std::move(lambda_expr)); + } + + optional_ptr GetLambdaExpression() const override { + if (!lambda_expr) { + return nullptr; + } + auto &bound_lambda_expr = lambda_expr->Cast(); + return bound_lambda_expr.LambdaExpr().get(); + } +}; + +struct LambdaInvokeState final : public FunctionLocalState { + unique_ptr executor; + DataChunk input_chunk; + idx_t parameter_count; + + LambdaInvokeState(unique_ptr executor_p, const vector &input_types, + const idx_t parameter_count_p) + : executor(std::move(executor_p)), parameter_count(parameter_count_p) { + input_chunk.InitializeEmpty(input_types); + } + + static unique_ptr Init(ExpressionState &state, const BoundFunctionExpression &expr, + FunctionData *bind_data) { + auto &bdata = bind_data->Cast(); + if (!bdata.lambda_expr) { + throw InternalException("Invoke function is missing its bound lambda expression"); + } + auto &bound_lambda_expr = bdata.lambda_expr->Cast(); + const auto parameter_count = bound_lambda_expr.ParameterCount(); + D_ASSERT(parameter_count <= expr.GetChildren().size()); + + vector input_types; + input_types.reserve(expr.GetChildren().size()); + for (idx_t i = 0; i < parameter_count; i++) { + input_types.push_back(expr.GetChildren()[parameter_count - i - 1]->GetReturnType()); + } + for (idx_t i = parameter_count; i < expr.GetChildren().size(); i++) { + input_types.push_back(expr.GetChildren()[i]->GetReturnType()); + } + + auto executor = make_uniq(state.GetContext(), *bound_lambda_expr.LambdaExpr()); + return make_uniq(std::move(executor), input_types, parameter_count); + } +}; + +void LambdaInvokeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); + for (idx_t i = 0; i < lstate.parameter_count; i++) { + lstate.input_chunk.data[i].Reference(args.data[lstate.parameter_count - i - 1]); + } + for (idx_t i = lstate.parameter_count; i < args.ColumnCount(); i++) { + lstate.input_chunk.data[i].Reference(args.data[i]); + } + lstate.input_chunk.SetChildCardinality(args.size()); + lstate.executor->ExecuteExpression(lstate.input_chunk, result); +} + +unique_ptr LambdaInvokeBind(BindScalarFunctionInput &input) { + auto &bound_function = input.GetBoundFunction(); + auto &arguments = input.GetArguments(); + // the list column and the bound lambda expression + if (arguments[0]->GetExpressionClass() != ExpressionClass::BOUND_LAMBDA) { + throw BinderException("Invalid lambda expression passed to 'invoke' function."); + } + + auto &bound_lambda_expr = arguments[0]->Cast(); + if (bound_lambda_expr.ParameterCount() != arguments.size() - 1) { + throw BinderException("The number of lambda parameters does not match the number of arguments passed to the " + "'invoke' function, expected %d, got %d.", + bound_lambda_expr.ParameterCount(), arguments.size() - 1); + } + + bound_function.SetReturnType(bound_lambda_expr.LambdaExpr()->GetReturnType()); + + return make_uniq(bound_lambda_expr.Copy()); +} + +LogicalType LambdaInvokeBindParameters(ClientContext &context, const vector &function_child_types, + const idx_t parameter_idx, optional_ptr bind_lambda_context) { + // The first parameter is always the lambda + auto child_idx = parameter_idx + 1; + if (child_idx >= function_child_types.size()) { + throw BinderException("The number of lambda parameters does not match the number of arguments passed to the " + "'invoke' function, expected at least %d, got %d.", + parameter_idx + 1, function_child_types.size() - 1); + } + return function_child_types[child_idx]; +} + +} // namespace + +ScalarFunction InvokeFun::GetFunction() { + ScalarFunction fun("invoke", {LogicalType::LAMBDA, LogicalType::ANY}, LogicalType::ANY, LambdaInvokeFunction); + fun.SetBindCallback(LambdaInvokeBind); + fun.SetVarArgs(LogicalType::ANY); + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + fun.SetBindLambdaCallback(LambdaInvokeBindParameters); + fun.SetInitStateCallback(LambdaInvokeState::Init); + fun.SetSerializeCallback(LambdaInvokeData::Serialize); + fun.SetDeserializeCallback(LambdaInvokeData::Deserialize); + return fun; +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/struct/struct_contains.cpp b/src/duckdb/src/function/scalar/struct/struct_contains.cpp index ad2271536..35385d144 100644 --- a/src/duckdb/src/function/scalar/struct/struct_contains.cpp +++ b/src/duckdb/src/function/scalar/struct/struct_contains.cpp @@ -210,7 +210,9 @@ static unique_ptr StructContainsBind(BindScalarFunctionInput &inpu auto &struct_children = StructType::GetChildTypes(arguments[0]->GetReturnType()); if (struct_children.empty()) { - throw InternalException("Can't check for containment in an empty struct"); + // an empty struct contains nothing, the search always returns false (or position 0) + bound_function.GetArguments()[0] = child_type; + return nullptr; } if (!StructType::IsUnnamed(child_type)) { throw BinderException("%s can only be used on unnamed structs", bound_function.GetName()); diff --git a/src/duckdb/src/function/scalar/struct/struct_extract.cpp b/src/duckdb/src/function/scalar/struct/struct_extract.cpp index 4c1d47beb..74f3ac5e2 100644 --- a/src/duckdb/src/function/scalar/struct/struct_extract.cpp +++ b/src/duckdb/src/function/scalar/struct/struct_extract.cpp @@ -38,7 +38,7 @@ static unique_ptr StructExtractBind(BindScalarFunctionInput &input D_ASSERT(LogicalTypeId::STRUCT == child_type.id()); auto &struct_children = StructType::GetChildTypes(child_type); if (struct_children.empty()) { - throw InternalException("Can't extract something from an empty struct"); + throw BinderException("Can't extract something from an empty struct"); } if (StructType::IsUnnamed(child_type)) { throw BinderException( @@ -102,7 +102,7 @@ static unique_ptr StructExtractBindInternal(ClientContext &context D_ASSERT(LogicalTypeId::STRUCT == child_type.id()); auto &struct_children = StructType::GetChildTypes(child_type); if (struct_children.empty()) { - throw InternalException("Can't extract something from an empty struct"); + throw BinderException("Can't extract something from an empty struct"); } if (struct_extract && !StructType::IsUnnamed(child_type)) { throw BinderException( diff --git a/src/duckdb/src/function/scalar/struct/struct_pack.cpp b/src/duckdb/src/function/scalar/struct/struct_pack.cpp index 01c3c61b3..372ca3ee9 100644 --- a/src/duckdb/src/function/scalar/struct/struct_pack.cpp +++ b/src/duckdb/src/function/scalar/struct/struct_pack.cpp @@ -18,6 +18,12 @@ static void StructPackFunction(DataChunk &args, ExpressionState &state, Vector & // this should never happen if the binder below is sane D_ASSERT(args.ColumnCount() == StructType::GetChildTypes(info.stype).size()); #endif + if (args.ColumnCount() == 0) { + // empty struct: no children to reference, the value is a single non-null constant + result.SetVectorType(VectorType::CONSTANT_VECTOR); + ConstantVector::SetNull(result, false); + return; + } bool all_const = true; auto &child_entries = StructVector::GetEntries(result); idx_t children_size = 0; @@ -45,9 +51,7 @@ static unique_ptr StructPackBind(BindScalarFunctionInput &input) { identifier_set_t name_collision_set; // collect names and deconflict, construct return type - if (arguments.empty()) { - throw InvalidInputException("Can't pack nothing into a struct"); - } + // note: zero arguments is allowed, producing an empty struct child_list_t struct_children; for (idx_t i = 0; i < arguments.size(); i++) { auto &child = arguments[i]; diff --git a/src/duckdb/src/function/scalar/system/aggregate_export.cpp b/src/duckdb/src/function/scalar/system/aggregate_export.cpp index d832cad16..f84bbdddd 100644 --- a/src/duckdb/src/function/scalar/system/aggregate_export.cpp +++ b/src/duckdb/src/function/scalar/system/aggregate_export.cpp @@ -1,6 +1,7 @@ #include "duckdb/common/vector/flat_vector.hpp" #include "duckdb/common/vector/list_vector.hpp" #include "duckdb/common/vector/struct_vector.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/common/types/list_segment.hpp" #include "duckdb/common/types/variant_value.hpp" #include "duckdb/function/aggregate_state_layout.hpp" @@ -665,7 +666,6 @@ unique_ptr BindAggregateState(BindScalarFunctionInput &input) { void ExportAggregateFinalize(Vector &state, AggregateFinalizeInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { - D_ASSERT(offset == 0); const data_ptr_t *addresses_ptrs; if (state.GetVectorType() == VectorType::CONSTANT_VECTOR) { if (count != 1) { @@ -678,8 +678,15 @@ void ExportAggregateFinalize(Vector &state, AggregateFinalizeInputData &aggr_inp auto layout = GetLayout(aggr_input_data.function, aggr_input_data.bind_data); - result.Flatten(); - SerializeState(layout, result, count, addresses_ptrs); + if (offset == 0) { + SerializeState(layout, result, count, addresses_ptrs); + return; + } + // finalizing at a non-zero offset (e.g. ordered aggregates) - serialize into a temporary vector and copy the + // result into place so the rest of the result vector is left untouched + Vector temp(result.GetType(), count); + SerializeState(layout, temp, count, addresses_ptrs); + VectorOperations::Copy(temp, result, count, 0, offset); } // the executor invokes this callback with combine_aggr's own bind data (ExportAggregateBindData) - the underlying @@ -760,14 +767,21 @@ void CombineAggrUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx void CombineAggrFinalize(Vector &state, AggregateFinalizeInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { - D_ASSERT(offset == 0); auto &bind_data = aggr_input_data.bind_data->Cast(); auto &underlying_aggr = bind_data.aggr; auto layout = GetLayout(underlying_aggr, bind_data.bind_data.get()); - result.Flatten(); - SerializeState(underlying_aggr, bind_data.bind_data.get(), layout, state, count, result, aggr_input_data.allocator); + if (offset == 0) { + SerializeState(underlying_aggr, bind_data.bind_data.get(), layout, state, count, result, + aggr_input_data.allocator); + return; + } + // finalizing at a non-zero offset (e.g. ordered aggregates) - serialize into a temporary vector and copy the + // result into place so the rest of the result vector is left untouched + Vector temp(result.GetType(), count); + SerializeState(underlying_aggr, bind_data.bind_data.get(), layout, state, count, temp, aggr_input_data.allocator); + VectorOperations::Copy(temp, result, count, 0, offset); } // constructs the AGGREGATE_STATE type for the given bound aggregate function diff --git a/src/duckdb/src/function/scalar/variant/variant_comparator.cpp b/src/duckdb/src/function/scalar/variant/variant_comparator.cpp index c3661c575..bd4dd04da 100644 --- a/src/duckdb/src/function/scalar/variant/variant_comparator.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_comparator.cpp @@ -6,10 +6,13 @@ #include "duckdb/common/types/uhugeint.hpp" #include "duckdb/common/types/bignum.hpp" #include "duckdb/common/types/bit.hpp" +#include "duckdb/common/types/value.hpp" #include "duckdb/common/types/variant.hpp" #include "duckdb/common/enum_util.hpp" #include "duckdb/common/vector/flat_vector.hpp" #include "duckdb/common/types/variant_iterator.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" namespace duckdb { @@ -257,7 +260,7 @@ VariantNumberKey IntegerNumberKey(T value) { } //! Compute the number key for any value in the NUMBER rank (integer, decimal or bignum) -VariantNumberKey VariantGetNumberKey(VariantLogicalType type_id, const VariantIterator &it) { +VariantNumberKey VariantGetNumberKey(VariantLogicalType type_id, const VariantNode &it) { switch (type_id) { case VariantLogicalType::DECIMAL: { auto decimal_data = it.GetDecimal(); @@ -360,7 +363,7 @@ void VariantEncodeString(SINK &sink, const string_t &str, bool is_varchar) { } template -void EncodeVariantValue(const VariantIterator &it, SINK &sink) { +void EncodeVariantValue(const VariantNode &it, SINK &sink) { auto type_id = it.GetTypeId(); // write the type rank - this guarantees values are ordered by type first sink.Write(GetVariantTypeRank(type_id)); @@ -484,7 +487,7 @@ void EncodeVariantValue(const VariantIterator &it, SINK &sink) { //! encode the *logical* value of the variant - this encoding is intentionally not reversible (e.g. //! all integer widths fold together). NULLs are propagated into the result validity. void CreateVariantComparator(const Vector &input, idx_t count, Vector &result) { - VariantIteratorState variant(input); + auto variant = input.Values(); result.SetVectorType(VectorType::FLAT_VECTOR); auto writer = FlatVector::Writer(result, count); @@ -492,7 +495,7 @@ void CreateVariantComparator(const Vector &input, idx_t count, Vector &result) { //! reused growable buffer - the key is encoded once and then copied into the result vector string buffer; for (idx_t r = 0; r < count; r++) { - auto root = variant.Root(r); + auto root = variant[r]; // a VARIANT is only NULL at the root via a genuine SQL NULL (never a VARIANT_NULL value) if (root.IsNull()) { // propagate NULL so that NULL = NULL stays NULL and ORDER BY ... NULLS FIRST/LAST is honored @@ -519,11 +522,114 @@ void VariantComparatorFunction(DataChunk &input, ExpressionState &state, Vector CreateVariantComparator(input.data[0], input.size(), result); } +//===--------------------------------------------------------------------===// +// Statistics Propagation +//===--------------------------------------------------------------------===// +// When the input VARIANT is fully shredded onto a single primitive type, every (non-NULL) value lives +// in the same comparator "bucket" - all of its sort keys share the type-rank prefix - and the values +// are bounded by the typed min/max of the shredded column. Because the comparator encoding is +// order-preserving, encoding the typed min/max with the exact same encoding yields valid min/max sort +// keys for the BLOB output: min/max are derived by running the real comparator (CreateVariantComparator) +// over a cast of the bound to VARIANT, so the derived bounds can never diverge from the runtime encoding. +// A root that resolves to NULL is always a genuine SQL NULL (VARIANT_NULL is reserved for nested values), +// so its NULL is propagated via validity and never lands in the primitive bucket. + +//! Encode a single bound value into its comparator sort key by reusing the exact comparator encoding. +bool TryEncodeBoundKey(ClientContext &context, const Value &bound, string &result_key) { + if (bound.IsNull()) { + return false; + } + Value variant_value; + if (!bound.TryCastAs(context, LogicalType::VARIANT(), variant_value, nullptr)) { + return false; + } + if (variant_value.IsNull()) { + return false; + } + Vector input(LogicalType::VARIANT(), 1); + input.SetValue(0, variant_value); + Vector sort_key(LogicalType::BLOB, 1); + CreateVariantComparator(input, 1, sort_key); + auto key = sort_key.GetValue(0); + if (key.IsNull()) { + return false; + } + result_key = StringValue::Get(key); + return true; +} + +//! Extract the lower/upper bound of the (primitive) typed stats as Values - leaves a bound as NULL when +//! it is not available. +void GetTypedBounds(const BaseStatistics &typed_stats, Value &min_bound, Value &max_bound) { + switch (typed_stats.GetStatsType()) { + case StatisticsType::NUMERIC_STATS: + if (NumericStats::HasMin(typed_stats)) { + min_bound = NumericStats::Min(typed_stats); + } + if (NumericStats::HasMax(typed_stats)) { + max_bound = NumericStats::Max(typed_stats); + } + break; + case StatisticsType::STRING_STATS: + // only VARCHAR has UTF-8 ordered min/max we can turn into valid lower/upper bounds + if (typed_stats.GetType().id() == LogicalTypeId::VARCHAR) { + if (StringStats::HasMin(typed_stats)) { + min_bound = StringStats::TryGetValidMin(typed_stats); + } + if (StringStats::HasMax(typed_stats)) { + max_bound = StringStats::TryGetValidMax(typed_stats); + } + } + break; + default: + break; + } +} + +unique_ptr VariantComparatorStats(ClientContext &context, FunctionStatisticsInput &input) { + auto &variant_stats = input.child_stats[0]; + // Keep the (loosest) validity baseline - the BLOB sort key is NULL only when the input variant is a + // SQL NULL, but the input validity is not propagated here: this code only narrows the value bounds, + // and over-claiming "cannot be NULL" off of incomplete child stats would prune valid rows. + auto result = BaseStatistics::CreateUnknown(input.expr.GetReturnType()); + + if (variant_stats.GetStatsType() != StatisticsType::VARIANT_STATS || !VariantStats::IsShredded(variant_stats)) { + return result.ToUnique(); + } + auto &shredded_stats = VariantStats::GetShreddedStats(variant_stats); + if (!VariantShreddedStats::IsFullyShredded(shredded_stats)) { + // values can live in the unshredded component with arbitrary type ranks - no bucket + return result.ToUnique(); + } + auto &typed_stats = VariantStats::GetTypedStats(shredded_stats); + if (typed_stats.GetType().IsNested()) { + // only a primitive shredding maps every value into a single comparator bucket + return result.ToUnique(); + } + + Value min_bound, max_bound; + GetTypedBounds(typed_stats, min_bound, max_bound); + + string min_key; + if (TryEncodeBoundKey(context, min_bound, min_key)) { + StringStats::SetMin(result, string_t(min_key.data(), NumericCast(min_key.size())), + StringStatsType::TRUNCATED_STATS); + } + string max_key; + if (TryEncodeBoundKey(context, max_bound, max_key)) { + StringStats::SetMax(result, string_t(max_key.data(), NumericCast(max_key.size())), + StringStatsType::TRUNCATED_STATS); + } + return result.ToUnique(); +} + } // namespace ScalarFunction VariantComparatorFun::GetFunction() { auto variant_type = LogicalType::VARIANT(); - return ScalarFunction("variant_comparator", {variant_type}, LogicalType::BLOB, VariantComparatorFunction); + ScalarFunction function("variant_comparator", {variant_type}, LogicalType::BLOB, VariantComparatorFunction); + function.SetStatisticsCallback(VariantComparatorStats); + return function; } } // namespace duckdb diff --git a/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp b/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp index 6fcdc81ab..f85a93c52 100644 --- a/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp +++ b/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp @@ -246,10 +246,6 @@ unique_ptr ArrowType::GetTypeFromFormat(ClientContext &context, Arrow } else if (format == "+s") { child_list_t child_types; vector> children; - if (schema.n_children == 0) { - throw InvalidInputException( - "Attempted to convert a STRUCT with no fields to DuckDB which is not supported"); - } for (idx_t type_idx = 0; type_idx < static_cast(schema.n_children); type_idx++) { children.emplace_back(GetArrowLogicalType(context, *schema.children[type_idx])); child_types.emplace_back(schema.children[type_idx]->name, children.back()->GetDuckType()); diff --git a/src/duckdb/src/function/table/system/test_all_types.cpp b/src/duckdb/src/function/table/system/test_all_types.cpp index 9357a9c6d..357bef7fe 100644 --- a/src/duckdb/src/function/table/system/test_all_types.cpp +++ b/src/duckdb/src/function/table/system/test_all_types.cpp @@ -201,6 +201,13 @@ vector TestAllTypesFun::GetTestTypes(const bool use_large_enum, const result.emplace_back(struct_type, "struct", min_struct_val, max_struct_val); + // Empty struct + child_list_t empty_struct_values; + child_list_t empty_struct_types; + auto empty_struct_type = LogicalType::STRUCT(empty_struct_types); + auto empty_struct_val = Value::STRUCT(empty_struct_values); + result.emplace_back(empty_struct_type, "empty_struct", empty_struct_val, empty_struct_val); + // structs with lists child_list_t struct_list_type_list; struct_list_type_list.emplace_back(make_pair("a", int_list_type)); diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index a81baf279..423ef4229 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev8946" +#define DUCKDB_PATCH_VERSION "0-dev9045" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 6 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.6.0-dev8946" +#define DUCKDB_VERSION "v1.6.0-dev9045" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "18b593788d" +#define DUCKDB_SOURCE_ID "17ba7dd2b6" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp index 10dc1a26e..352f96ee6 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp @@ -154,9 +154,13 @@ class TableCatalogEntry : public StandardEntry { //! Scan all triggers on this table (default: no-op - non-DuckDB tables have no triggers) virtual void ScanTriggers(CatalogTransaction transaction, const std::function &callback) const; - //! Collect triggers matching the given timing and event type + //! Collect triggers matching the given event type and for_each granularity, regardless of timing vector> - GetTriggersForEvent(CatalogTransaction transaction, TriggerTiming timing, TriggerEventType event_type) const; + GetTriggersForEvent(CatalogTransaction transaction, TriggerEventType event_type, TriggerForEach for_each) const; + //! Collect triggers matching the given timing, event type, and for_each granularity + vector> GetTriggersForEvent(CatalogTransaction transaction, + TriggerTiming timing, TriggerEventType event_type, + TriggerForEach for_each) const; protected: //! A list of columns that are part of this table diff --git a/src/duckdb/src/include/duckdb/common/enum_util.hpp b/src/duckdb/src/include/duckdb/common/enum_util.hpp index d2e0e6b01..f28d89e1c 100644 --- a/src/duckdb/src/include/duckdb/common/enum_util.hpp +++ b/src/duckdb/src/include/duckdb/common/enum_util.hpp @@ -492,6 +492,8 @@ enum class TableColumnType : uint8_t; enum class TableFilterType : uint8_t; +enum class TableFunctionIdentifierConversion : uint8_t; + enum class TableFunctionParallelism : uint8_t; enum class TablePartitionInfo : uint8_t; @@ -1251,6 +1253,9 @@ const char* EnumUtil::ToChars(TableColumnType value); template<> const char* EnumUtil::ToChars(TableFilterType value); +template<> +const char* EnumUtil::ToChars(TableFunctionIdentifierConversion value); + template<> const char* EnumUtil::ToChars(TableFunctionParallelism value); @@ -2044,6 +2049,9 @@ TableColumnType EnumUtil::FromString(const char *value); template<> TableFilterType EnumUtil::FromString(const char *value); +template<> +TableFunctionIdentifierConversion EnumUtil::FromString(const char *value); + template<> TableFunctionParallelism EnumUtil::FromString(const char *value); diff --git a/src/duckdb/src/include/duckdb/common/enums/debug_vector_verification.hpp b/src/duckdb/src/include/duckdb/common/enums/debug_vector_verification.hpp index 7a48243b1..e9370f4fa 100644 --- a/src/duckdb/src/include/duckdb/common/enums/debug_vector_verification.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/debug_vector_verification.hpp @@ -19,7 +19,8 @@ enum class DebugVectorVerification : uint8_t { CONSTANT_OPERATOR, SEQUENCE_OPERATOR, NESTED_SHUFFLE, - VARIANT_VECTOR + VARIANT_VECTOR, + SHREDDED_VECTOR }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/logical_operator_type.hpp b/src/duckdb/src/include/duckdb/common/enums/logical_operator_type.hpp index 6dd62ce67..54f6539ff 100644 --- a/src/duckdb/src/include/duckdb/common/enums/logical_operator_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/logical_operator_type.hpp @@ -68,6 +68,7 @@ enum class LogicalOperatorType : uint8_t { LOGICAL_DELETE = 101, LOGICAL_UPDATE = 102, LOGICAL_MERGE_INTO = 103, + LOGICAL_TRIGGER = 104, // ----------------------------- // Schema diff --git a/src/duckdb/src/include/duckdb/common/enums/table_function_identifier_conversion.hpp b/src/duckdb/src/include/duckdb/common/enums/table_function_identifier_conversion.hpp new file mode 100644 index 000000000..bbc6b012b --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/table_function_identifier_conversion.hpp @@ -0,0 +1,21 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/table_function_identifier_conversion.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class TableFunctionIdentifierConversion : uint8_t { + DEFAULT = 0, + ENABLE_IMPLICIT_STRING = 1, + DISABLE_IMPLICIT_STRING = 2 +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types.hpp b/src/duckdb/src/include/duckdb/common/types.hpp index ebdca4d00..297a6fb76 100644 --- a/src/duckdb/src/include/duckdb/common/types.hpp +++ b/src/duckdb/src/include/duckdb/common/types.hpp @@ -38,6 +38,9 @@ struct VectorStructType; template struct VectorListType; +//! Marker type selecting the VARIANT specialization of VectorIterator (see variant_iterator.hpp) +struct VectorVariantType; + template buffer_ptr make_buffer(ARGS &&...args) { // NOLINT: mimic std casing return make_shared_ptr(std::forward(args)...); diff --git a/src/duckdb/src/include/duckdb/common/types/variant_iterator.hpp b/src/duckdb/src/include/duckdb/common/types/variant_iterator.hpp index 69cac2e30..744d971ff 100644 --- a/src/duckdb/src/include/duckdb/common/types/variant_iterator.hpp +++ b/src/duckdb/src/include/duckdb/common/types/variant_iterator.hpp @@ -31,7 +31,8 @@ enum class VariantIterationOrder { //===--------------------------------------------------------------------===// // VariantIterator //===--------------------------------------------------------------------===// -// VariantIterator iterates over the logical values of a VARIANT vector *without* unshredding it. +// VariantIterator iterates over the logical values of a VARIANT vector *without* unshredding it, +// handing out a VariantNode cursor per row (each cursor points at a single logical value/node). // // A VARIANT vector is either stored in its canonical "unshredded" layout: // STRUCT( @@ -95,19 +96,19 @@ struct ShreddedVariantIterator { LogicalType logical_type; }; -class VariantIterator; +class VariantNode; //! Shared state required to iterate a single VARIANT vector. Owns the vector iterators / flattened -//! vectors that the individual VariantIterator cursors point into - so it must outlive any cursor. -class VariantIteratorState { +//! vectors that the individual VariantNode cursors point into - so it must outlive any cursor. +class VariantIterator { public: - explicit VariantIteratorState(const Vector &variant); + explicit VariantIterator(const Vector &variant); public: //! Whether the row is a (SQL) NULL variant bool RowIsValid(idx_t row) const; //! Returns a cursor pointing at the root value of the given row - VariantIterator Root(idx_t row) const; + VariantNode Root(idx_t row) const; private: //! The "core": the unshredded component reader (plain vector iterators) @@ -118,7 +119,7 @@ class VariantIteratorState { //! The shredded component - the (recursive) view of the root of the shredded tree ShreddedVariantIterator shredded_format; - friend class VariantIterator; + friend class VariantNode; friend class VariantArrayIterator; friend class VariantObjectIterator; }; @@ -127,7 +128,7 @@ class VariantArrayIterator; class VariantObjectIterator; //! A lightweight cursor pointing at a single logical VARIANT value. -class VariantIterator { +class VariantNode { public: enum class Kind { NULL_VALUE, //! a (SQL/variant) NULL value @@ -137,7 +138,7 @@ class VariantIterator { }; public: - VariantIterator() : state(nullptr), kind(Kind::NULL_VALUE) { + VariantNode() : state(nullptr), kind(Kind::NULL_VALUE) { } public: @@ -172,17 +173,18 @@ class VariantIterator { private: //! Resolve the shredded node (a "STRUCT(typed_value, [untyped_value_index])" wrapper, or a //! flattened primitive) at the given index into a concrete cursor - static VariantIterator ResolveShredded(const VariantIteratorState &state, const ShreddedVariantIterator &node, - idx_t index, idx_t row); + static VariantNode ResolveShredded(const VariantIterator &state, const ShreddedVariantIterator &node, idx_t index, + idx_t row); - static VariantIterator MakeUnshredded(const VariantIteratorState &state, idx_t row, uint32_t value_index); - static VariantIterator MakeShredded(const VariantIteratorState &state, const ShreddedVariantIterator &content, - idx_t index, idx_t row, uint32_t overlay_value_index); - static VariantIterator MakeNull(const VariantIteratorState &state); - static VariantIterator MakeMissing(const VariantIteratorState &state); + static VariantNode MakeUnshredded(const VariantIterator &state, idx_t row, uint32_t value_index); + static VariantNode MakeShredded(const VariantIterator &state, const ShreddedVariantIterator &content, idx_t index, + idx_t row, uint32_t overlay_value_index); + static VariantNode MakeNull(const VariantIterator &state); + static VariantNode MakeMissing(const VariantIterator &state); private: - const VariantIteratorState *state; + //! The owning iterator this value lives in (null only for a default-constructed cursor) + optional_ptr state; Kind kind; //! The row this value belongs to (used for the unshredded component / overlay lookups) @@ -199,28 +201,28 @@ class VariantIterator { //! (0 means there is no leftover object to merge) uint32_t overlay_value_index = 0; - friend class VariantIteratorState; + friend class VariantIterator; friend class VariantArrayIterator; friend class VariantObjectIterator; }; -//! Lazily iterates the element values of an ARRAY VariantIterator. Random-access: no child cursor is +//! Lazily iterates the element values of an ARRAY VariantNode. Random-access: no child cursor is //! materialized until it is dereferenced. class VariantArrayIterator { public: - explicit VariantArrayIterator(const VariantIterator &array); + explicit VariantArrayIterator(const VariantNode &array); public: idx_t size() const { return length; } - VariantIterator operator[](idx_t i) const; + VariantNode operator[](idx_t i) const; class Iterator { public: Iterator(const VariantArrayIterator &parent, idx_t pos) : parent(parent), pos(pos) { } - VariantIterator operator*() const { + VariantNode operator*() const { return parent[pos]; } Iterator &operator++() { // NOLINT: match stl API @@ -243,7 +245,8 @@ class VariantArrayIterator { } private: - const VariantIteratorState *state; + //! The owning iterator this array's elements live in (always non-null for a real ARRAY node) + reference state; idx_t row; idx_t length; bool shredded; @@ -256,16 +259,16 @@ class VariantArrayIterator { //! A single (key, value) entry of an OBJECT struct VariantObjectEntry { string_t key; - VariantIterator value; + VariantNode value; }; -//! Iterates the (key, value) children of an OBJECT VariantIterator, merging the shredded (typed) +//! Iterates the (key, value) children of an OBJECT VariantNode, merging the shredded (typed) //! fields with the leftover unshredded fields. There are two backing modes: //! - INTERNAL order: lazy forward iteration over the raw entries (skipping missing fields) //! - LEXICOGRAPHIC order: iterates the materialized + sorted 'ordered_entries' class VariantObjectIterator { public: - VariantObjectIterator(const VariantIterator &object, VariantIterationOrder order); + VariantObjectIterator(const VariantNode &object, VariantIterationOrder order); public: //! Forward iterator over the object entries. All modes are position-based, so the only difference is @@ -315,7 +318,8 @@ class VariantObjectIterator { } private: - const VariantIteratorState *state; + //! The owning iterator this object's values live in (always non-null for a real OBJECT node) + reference state; idx_t row; VariantIterationOrder order; bool shredded; @@ -338,4 +342,57 @@ class VariantObjectIterator { friend class Iterator; }; +//! Specialization of VectorIterator for VectorVariantType. +//! Iterates over a VARIANT vector, handing out a VariantNode cursor per row via operator[]. +//! The cursors point back into the owned VariantIterator, so this iterator must outlive them. +template <> +class VectorIterator { +public: + explicit VectorIterator(const Vector &vector) : state(vector), count(vector.size()) { + } + +public: + //! Returns a cursor pointing at the root VARIANT value of the given row + VariantNode operator[](idx_t row) const { + return state.Root(row); + } + //! Whether the row is a valid (non-NULL) variant + bool RowIsValid(idx_t row) const { + return state.RowIsValid(row); + } + idx_t size() const { + return count; + } + + class Iterator { + public: + Iterator(const VectorIterator &parent, idx_t index) : parent(parent), index(index) { + } + VariantNode operator*() const { + return parent[index]; + } + Iterator &operator++() { // NOLINT: match stl API + ++index; + return *this; + } + bool operator!=(const Iterator &other) const { + return index != other.index; + } + + private: + const VectorIterator &parent; + idx_t index; + }; + Iterator begin() const { // NOLINT: match stl API + return Iterator(*this, 0); + } + Iterator end() const { // NOLINT: match stl API + return Iterator(*this, count); + } + +private: + VariantIterator state; + idx_t count; +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/lambda_functions.hpp b/src/duckdb/src/include/duckdb/function/lambda_functions.hpp index 7fc9753eb..baa6b62db 100644 --- a/src/duckdb/src/include/duckdb/function/lambda_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/lambda_functions.hpp @@ -18,7 +18,11 @@ namespace duckdb { -struct ListLambdaBindData final : public FunctionData { +struct LambdaFunctionData : public FunctionData { + DUCKDB_API virtual optional_ptr GetLambdaExpression() const = 0; +}; + +struct ListLambdaBindData final : public LambdaFunctionData { public: ListLambdaBindData(const LogicalType &return_type, unique_ptr lambda_expr, const bool has_index = false, const bool has_initial = false) @@ -50,6 +54,10 @@ struct ListLambdaBindData final : public FunctionData { const BoundScalarFunction &function); //! Deserializes a lambda function's bind data static unique_ptr Deserialize(Deserializer &deserializer, BoundScalarFunction &); + + optional_ptr GetLambdaExpression() const override { + return lambda_expr.get(); + } }; class LambdaFunctions { diff --git a/src/duckdb/src/include/duckdb/function/scalar/generic_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/generic_functions.hpp index f502b4e82..f476ff431 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/generic_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/generic_functions.hpp @@ -55,4 +55,14 @@ struct CreateSortKeyFun { static ScalarFunction GetFunction(); }; +struct InvokeFun { + static constexpr const char *Name = "invoke"; + static constexpr const char *Parameters = "lambda,arg1,arg2,..."; + static constexpr const char *Description = "Invokes a lambda function with the given arguments"; + static constexpr const char *Example = "invoke(x -> x + 1, 5)"; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp b/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp index 74ec66fd8..f7470757d 100644 --- a/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp +++ b/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp @@ -48,11 +48,13 @@ struct VariantShreddingStats { public: void Update(const Vector &input, idx_t count); - LogicalType GetShreddedType() const; + //! If force_partial is set, every level keeps its 'untyped_value_index' (overlay) column even when the + //! sampled values are fully consistent - allowing later inconsistent values to be partially shredded. + LogicalType GetShreddedType(bool force_partial = false) const; private: bool GetShreddedTypeInternal(const VariantColumnStatsData &column, LogicalType &out_type, - optional_idx parent_count = optional_idx()) const; + optional_idx parent_count = optional_idx(), bool force_partial = false) const; private: //! Nested type analysis diff --git a/src/duckdb/src/include/duckdb/main/extension_entries.hpp b/src/duckdb/src/include/duckdb/main/extension_entries.hpp index d3ac41792..8ac05ab1f 100644 --- a/src/duckdb/src/include/duckdb/main/extension_entries.hpp +++ b/src/duckdb/src/include/duckdb/main/extension_entries.hpp @@ -866,19 +866,19 @@ static constexpr ExtensionFunctionOverloadEntry EXTENSION_FUNCTION_OVERLOADS[] = {"date_diff", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIME,TIME]>BIGINT"}, {"date_diff", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIMESTAMP,TIMESTAMP]>BIGINT"}, {"date_diff", "icu", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIMESTAMPTZ,TIMESTAMPTZ]>BIGINT"}, - {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',DATE]>STRUCT()"}, - {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',INTERVAL]>STRUCT()"}, - {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIME]>STRUCT()"}, - {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIMESTAMP]>STRUCT()"}, - {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIMETZ]>STRUCT()"}, - {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIME_NS]>STRUCT()"}, + {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',DATE]>STRUCT"}, + {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',INTERVAL]>STRUCT"}, + {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIME]>STRUCT"}, + {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIMESTAMP]>STRUCT"}, + {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIMETZ]>STRUCT"}, + {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIME_NS]>STRUCT"}, {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,DATE]>BIGINT"}, {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,INTERVAL]>BIGINT"}, {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIME]>BIGINT"}, {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIMESTAMP]>BIGINT"}, {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIMETZ]>BIGINT"}, {"date_part", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIME_NS]>BIGINT"}, - {"date_part", "icu", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIMESTAMPTZ]>STRUCT()"}, + {"date_part", "icu", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIMESTAMPTZ]>STRUCT"}, {"date_part", "icu", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIMESTAMPTZ]>BIGINT"}, {"date_sub", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,DATE,DATE]>BIGINT"}, {"date_sub", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIME,TIME]>BIGINT"}, @@ -892,19 +892,19 @@ static constexpr ExtensionFunctionOverloadEntry EXTENSION_FUNCTION_OVERLOADS[] = {"datediff", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIME,TIME]>BIGINT"}, {"datediff", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIMESTAMP,TIMESTAMP]>BIGINT"}, {"datediff", "icu", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIMESTAMPTZ,TIMESTAMPTZ]>BIGINT"}, - {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',DATE]>STRUCT()"}, - {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',INTERVAL]>STRUCT()"}, - {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIME]>STRUCT()"}, - {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIMESTAMP]>STRUCT()"}, - {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIMETZ]>STRUCT()"}, - {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIME_NS]>STRUCT()"}, + {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',DATE]>STRUCT"}, + {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',INTERVAL]>STRUCT"}, + {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIME]>STRUCT"}, + {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIMESTAMP]>STRUCT"}, + {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIMETZ]>STRUCT"}, + {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIME_NS]>STRUCT"}, {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,DATE]>BIGINT"}, {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,INTERVAL]>BIGINT"}, {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIME]>BIGINT"}, {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIMESTAMP]>BIGINT"}, {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIMETZ]>BIGINT"}, {"datepart", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIME_NS]>BIGINT"}, - {"datepart", "icu", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIMESTAMPTZ]>STRUCT()"}, + {"datepart", "icu", CatalogType::SCALAR_FUNCTION_ENTRY, "['VARCHAR[]',TIMESTAMPTZ]>STRUCT"}, {"datepart", "icu", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIMESTAMPTZ]>BIGINT"}, {"datesub", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,DATE,DATE]>BIGINT"}, {"datesub", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY, "[VARCHAR,TIME,TIME]>BIGINT"}, diff --git a/src/duckdb/src/include/duckdb/main/settings.hpp b/src/duckdb/src/include/duckdb/main/settings.hpp index 518f46b45..c97bde0f2 100644 --- a/src/duckdb/src/include/duckdb/main/settings.hpp +++ b/src/duckdb/src/include/duckdb/main/settings.hpp @@ -531,6 +531,28 @@ struct DebugEvictionQueueSleepMicroSecondsSetting { static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); }; +struct DebugForceCommitFailureSetting { + using RETURN_TYPE = bool; + static constexpr const char *Name = "debug_force_commit_failure"; + static constexpr const char *Description = "DEBUG SETTING: force transaction commit to fail after the undo buffer " + "has been committed, used for testing commit error recovery"; + static constexpr const char *InputType = "BOOLEAN"; + static constexpr const char *DefaultValue = "false"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::GLOBAL_DEFAULT; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); +}; + +struct DebugForceCommitRevertFailureSetting { + using RETURN_TYPE = bool; + static constexpr const char *Name = "debug_force_commit_revert_failure"; + static constexpr const char *Description = + "DEBUG SETTING: force RevertCommit to fail while recovering from a commit failure, used for testing"; + static constexpr const char *InputType = "BOOLEAN"; + static constexpr const char *DefaultValue = "false"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::GLOBAL_DEFAULT; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); +}; + struct DebugForceExternalSetting { using RETURN_TYPE = bool; static constexpr const char *Name = "debug_force_external"; @@ -787,7 +809,7 @@ struct DelimJoinAsCteSetting { static constexpr const char *Description = "Rewrite delim joins to materialized CTEs during dependent join flattening"; static constexpr const char *InputType = "BOOLEAN"; - static constexpr const char *DefaultValue = "false"; + static constexpr const char *DefaultValue = "true"; static constexpr SettingScopeTarget Scope = SettingScopeTarget::LOCAL_DEFAULT; static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); }; @@ -1848,6 +1870,18 @@ struct StreamingBufferSizeSetting { static Value GetSetting(const ClientContext &context); }; +struct TableFunctionIdentifierConversionSetting { + using RETURN_TYPE = TableFunctionIdentifierConversion; + static constexpr const char *Name = "table_function_identifier_conversion"; + static constexpr const char *Description = "Configures the use of deprecated implicit conversion of unbound " + "identifiers to strings in table function arguments."; + static constexpr const char *InputType = "VARCHAR"; + static constexpr const char *DefaultValue = "DEFAULT"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::LOCAL_DEFAULT; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); + static void OnSet(SettingCallbackInfo &info, Value &input); +}; + struct TempDirectorySetting { using RETURN_TYPE = string; static constexpr const char *Name = "temp_directory"; diff --git a/src/duckdb/src/include/duckdb/parser/peg/inlined_grammar.hpp b/src/duckdb/src/include/duckdb/parser/peg/inlined_grammar.hpp index 274464d23..e460a2a94 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/inlined_grammar.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/inlined_grammar.hpp @@ -154,7 +154,7 @@ const char INLINED_PEG_GRAMMAR[] = { "CatalogReservedSchemaTypeName <- CatalogQualification ReservedSchemaQualification ReservedTypeName\n" "SchemaReservedTypeName <- SchemaQualification ReservedTypeName\n" "TypeModifiers <- Parens(List(Expression)?)\n" - "RowType <- RowOrStruct ColIdTypeList\n" + "RowType <- RowOrStruct ColIdTypeList?\n" "SetofType <- 'SETOF' Type\n" "UnionType <- 'UNION' ColIdTypeList\n" "ColIdTypeList <- Parens(List(ColIdType))\n" @@ -1462,7 +1462,10 @@ const char INLINED_PEG_GRAMMAR[] = { "DeleteUsingClause <- 'USING' List(TableRef)\n" "ConnectStatement <- 'CONNECT' SessionTarget?\n" "DisconnectStatement <- 'DISCONNECT'\n" - "SessionTarget <- 'LOCAL' / StringLiteral / CatalogName\n" + "SessionTarget <- LocalSessionTarget / StringSessionTarget / CatalogSessionTarget\n" + "LocalSessionTarget <- 'LOCAL'\n" + "StringSessionTarget <- StringLiteral\n" + "CatalogSessionTarget <- CatalogName\n" "CreateTypeStmt <- 'TYPE' IfNotExists? QualifiedName 'AS' CreateType\n" "CreateType <- EnumSelectType / EnumStringLiteralList / CreateTypeFromType\n" "CreateTypeFromType <- Type\n" diff --git a/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer.hpp b/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer.hpp index a90fc6d71..ad5793360 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer.hpp @@ -44,6 +44,7 @@ #include "duckdb/parser/expression/function_expression.hpp" #include "duckdb/parser/expression/parameter_expression.hpp" #include "duckdb/parser/expression/window_expression.hpp" +#include "duckdb/parser/parsed_data/connect_info.hpp" #include "duckdb/parser/parsed_data/create_type_info.hpp" #include "duckdb/parser/parsed_data/transaction_info.hpp" #include "duckdb/parser/parsed_data/vacuum_info.hpp" @@ -373,9 +374,6 @@ class PEGTransformerFactory { static unique_ptr TransformStatement(PEGTransformer &, ParseResult &list); - // connect.gram — both rules have optional sub-clauses, so the generator skips them and we - // hand-write the (PEGTransformer&, ParseResult&) entry points. - static unique_ptr TransformConnectStatement(PEGTransformer &transformer, ParseResult &parse_result); // comment.gram static Value TransformCommentValue(PEGTransformer &transformer, ParseResult &parse_result); @@ -424,12 +422,12 @@ class PEGTransformerFactory { ParseResult &parse_result); static unique_ptr TransformAlterTableStmtInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterTableStmt(PEGTransformer &transformer, const bool &if_exists, + static unique_ptr TransformAlterTableStmt(PEGTransformer &transformer, const optional &if_exists, unique_ptr base_table_name, vector> alter_table_options); static unique_ptr TransformAlterSchemaStmtInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterSchemaStmt(PEGTransformer &transformer, const bool &if_exists, + static unique_ptr TransformAlterSchemaStmt(PEGTransformer &transformer, const optional &if_exists, const QualifiedName &qualified_name, unique_ptr rename_alter); static unique_ptr TransformAlterTableOptionsInternal(PEGTransformer &transformer, @@ -440,32 +438,35 @@ class PEGTransformerFactory { unique_ptr top_level_constraint); static unique_ptr TransformAddColumnInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAddColumn(PEGTransformer &transformer, const bool &if_not_exists, + static unique_ptr TransformAddColumn(PEGTransformer &transformer, const bool &has_result, + const optional &if_not_exists, AddColumnEntry add_column_entry); static unique_ptr TransformAddColumnEntryInternal(PEGTransformer &transformer, ParseResult &parse_result); static AddColumnEntry TransformAddColumnEntry(PEGTransformer &transformer, const vector &dotted_identifier, - const LogicalType &type, GeneratedColumnDefinition generated_column, - vector column_constraint); + const optional &type, + optional generated_column, + optional> column_constraint); static unique_ptr TransformDropColumnInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropColumn(PEGTransformer &transformer, const bool &if_exists, + static unique_ptr TransformDropColumn(PEGTransformer &transformer, const bool &has_result, + const optional &if_exists, unique_ptr nested_column_name, - const bool &drop_behavior); + const optional &drop_behavior); static unique_ptr TransformAlterColumnInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterColumn(PEGTransformer &transformer, + static unique_ptr TransformAlterColumn(PEGTransformer &transformer, const bool &has_result, unique_ptr nested_column_name, unique_ptr alter_column_entry); static unique_ptr TransformRenameColumnInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformRenameColumn(PEGTransformer &transformer, + static unique_ptr TransformRenameColumn(PEGTransformer &transformer, const bool &has_result, unique_ptr nested_column_name, const Identifier &identifier); static unique_ptr TransformNestedColumnNameInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformNestedColumnName(PEGTransformer &transformer, - const vector &identifier_dot, + const optional> &identifier_dot, const Identifier &column_name); static unique_ptr TransformIdentifierDotInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -522,27 +523,29 @@ class PEGTransformerFactory { static string TransformSetNullability(PEGTransformer &transformer); static unique_ptr TransformAlterTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterType(PEGTransformer &transformer, const LogicalType &type, - unique_ptr using_expression); + static unique_ptr TransformAlterType(PEGTransformer &transformer, const bool &has_result, + const optional &type, + optional> using_expression); static unique_ptr TransformUsingExpressionInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformUsingExpression(PEGTransformer &transformer, unique_ptr expression); static unique_ptr TransformAlterViewStmtInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterViewStmt(PEGTransformer &transformer, const bool &if_exists, + static unique_ptr TransformAlterViewStmt(PEGTransformer &transformer, const optional &if_exists, unique_ptr base_table_name, unique_ptr rename_alter); static unique_ptr TransformAlterSequenceStmtInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterSequenceStmt(PEGTransformer &transformer, const bool &if_exists, + static unique_ptr TransformAlterSequenceStmt(PEGTransformer &transformer, + const optional &if_exists, const QualifiedName &qualified_sequence_name, unique_ptr alter_sequence_options); static unique_ptr TransformQualifiedSequenceNameInternal(PEGTransformer &transformer, ParseResult &parse_result); static QualifiedName TransformQualifiedSequenceName(PEGTransformer &transformer, - const Identifier &catalog_qualification, - const Identifier &schema_qualification, + const optional &catalog_qualification, + const optional &schema_qualification, const Identifier &sequence_name); static unique_ptr TransformAlterSequenceOptionsInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -554,27 +557,29 @@ class PEGTransformerFactory { vector>> sequence_option); static unique_ptr TransformAlterDatabaseStmtInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAlterDatabaseStmt(PEGTransformer &transformer, const bool &if_exists, + static unique_ptr TransformAlterDatabaseStmt(PEGTransformer &transformer, + const optional &if_exists, const Identifier &identifier, const Identifier &identifier_1); static unique_ptr TransformAnalyzeStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAnalyzeStatement(PEGTransformer &transformer, const bool &analyze_verbose, - AnalyzeTarget analyze_target); + static unique_ptr TransformAnalyzeStatement(PEGTransformer &transformer, + const optional &analyze_verbose, + optional analyze_target); static unique_ptr TransformAnalyzeTargetInternal(PEGTransformer &transformer, ParseResult &parse_result); static AnalyzeTarget TransformAnalyzeTarget(PEGTransformer &transformer, unique_ptr base_table_name, - const vector &name_list); + const optional> &name_list); static unique_ptr TransformAnalyzeVerboseInternal(PEGTransformer &transformer, ParseResult &parse_result); static bool TransformAnalyzeVerbose(PEGTransformer &transformer); static unique_ptr TransformAttachStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformAttachStatement(PEGTransformer &transformer, const bool &or_replace, - const bool &if_not_exists, - unique_ptr database_path, - const Identifier &attach_alias, - const vector &attach_options); + static unique_ptr + TransformAttachStatement(PEGTransformer &transformer, const optional &or_replace, + const optional &if_not_exists, const bool &has_result, + unique_ptr database_path, const optional &attach_alias, + const optional> &attach_options); static unique_ptr TransformDatabasePathInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformDatabasePath(PEGTransformer &transformer, @@ -594,8 +599,8 @@ class PEGTransformerFactory { static unique_ptr TransformCheckpointStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformCheckpointStatement(PEGTransformer &transformer, - const bool &checkpoint_force, - const Identifier &catalog_name); + const optional &checkpoint_force, + const optional &catalog_name); static unique_ptr TransformCheckpointForceInternal(PEGTransformer &transformer, ParseResult &parse_result); static bool TransformCheckpointForce(PEGTransformer &transformer); @@ -655,7 +660,7 @@ class PEGTransformerFactory { static unique_ptr TransformTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); static LogicalType TransformType(PEGTransformer &transformer, unique_ptr type_variations, - const vector &array_bounds); + const optional> &array_bounds); static unique_ptr TransformTypeVariationsInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformSimpleTypeInternal(PEGTransformer &transformer, @@ -663,16 +668,13 @@ class PEGTransformerFactory { static unique_ptr TransformCharacterSimpleTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr - TransformCharacterSimpleType(PEGTransformer &transformer, const string &character_type, - vector> type_modifiers); + TransformCharacterSimpleType(PEGTransformer &transformer, + optional>> type_modifiers); static unique_ptr TransformQualifiedSimpleTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformQualifiedSimpleType(PEGTransformer &transformer, const QualifiedName &qualified_type_name, - vector> type_modifiers); - static unique_ptr TransformCharacterTypeInternal(PEGTransformer &transformer, - ParseResult &parse_result); - static string TransformCharacterType(PEGTransformer &transformer); + optional>> type_modifiers); static unique_ptr TransformIntervalTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformIntervalIntervalInternal(PEGTransformer &transformer, @@ -765,12 +767,12 @@ class PEGTransformerFactory { const DatePartSpecifier &second_keyword); static unique_ptr TransformBitTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformBitType(PEGTransformer &transformer, - vector> expression); + static unique_ptr TransformBitType(PEGTransformer &transformer, const bool &has_result, + optional>> expression); static unique_ptr TransformGeometryTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformGeometryType(PEGTransformer &transformer, - unique_ptr expression); + optional> expression); static unique_ptr TransformVariantTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformVariantType(PEGTransformer &transformer); @@ -805,19 +807,19 @@ class PEGTransformerFactory { static unique_ptr TransformFloatTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformFloatType(PEGTransformer &transformer, - unique_ptr number_literal); + optional> number_literal); static unique_ptr TransformDecimalTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDecimalType(PEGTransformer &transformer, - vector> type_modifiers); + static unique_ptr + TransformDecimalType(PEGTransformer &transformer, optional>> type_modifiers); static unique_ptr TransformDecTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformDecType(PEGTransformer &transformer, - vector> type_modifiers); + optional>> type_modifiers); static unique_ptr TransformNumericModTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformNumericModType(PEGTransformer &transformer, - vector> type_modifiers); + static unique_ptr + TransformNumericModType(PEGTransformer &transformer, optional>> type_modifiers); static unique_ptr TransformQualifiedTypeNameInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformTypeNameAsQualifiedNameInternal(PEGTransformer &transformer, @@ -836,8 +838,8 @@ class PEGTransformerFactory { const Identifier &reserved_type_name); static unique_ptr TransformTypeModifiersInternal(PEGTransformer &transformer, ParseResult &parse_result); - static vector> TransformTypeModifiers(PEGTransformer &transformer, - vector> expression); + static vector> + TransformTypeModifiers(PEGTransformer &transformer, optional>> expression); static unique_ptr TransformRowTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformRowType(PEGTransformer &transformer, @@ -867,13 +869,14 @@ class PEGTransformerFactory { static int64_t TransformArrayKeyword(PEGTransformer &transformer); static unique_ptr TransformSquareBracketsArrayInternal(PEGTransformer &transformer, ParseResult &parse_result); - static int64_t TransformSquareBracketsArray(PEGTransformer &transformer, unique_ptr expression); + static int64_t TransformSquareBracketsArray(PEGTransformer &transformer, + optional> expression); static unique_ptr TransformTimeTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformTimeType(PEGTransformer &transformer, const LogicalTypeId &time_or_timestamp, - vector> type_modifiers, - const bool &time_zone); + optional>> type_modifiers, + const optional &time_zone); static unique_ptr TransformTimeOrTimestampInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformTimeTypeIdInternal(PEGTransformer &transformer, @@ -893,9 +896,26 @@ class PEGTransformerFactory { static unique_ptr TransformWithoutRuleInternal(PEGTransformer &transformer, ParseResult &parse_result); static bool TransformWithoutRule(PEGTransformer &transformer); + static unique_ptr TransformConnectStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformConnectStatement(PEGTransformer &transformer, + optional> session_target); static unique_ptr TransformDisconnectStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformDisconnectStatement(PEGTransformer &transformer); + static unique_ptr TransformSessionTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformLocalSessionTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformLocalSessionTarget(PEGTransformer &transformer); + static unique_ptr TransformStringSessionTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformStringSessionTarget(PEGTransformer &transformer, + const string &string_literal); + static unique_ptr TransformCatalogSessionTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformCatalogSessionTarget(PEGTransformer &transformer, + const Identifier &catalog_name); static unique_ptr TransformCopyStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformCopyStatement(PEGTransformer &transformer, @@ -906,9 +926,10 @@ class PEGTransformerFactory { ParseResult &parse_result); static unique_ptr TransformCopyTable(PEGTransformer &transformer, unique_ptr base_table_name, - const vector &insert_column_list, const bool &from_or_to, + const optional> &insert_column_list, + const bool &from_or_to, unique_ptr copy_file_name, - const vector ©_options); + const optional> ©_options); static unique_ptr TransformFromOrToInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformCopyFromInternal(PEGTransformer &transformer, @@ -922,7 +943,7 @@ class PEGTransformerFactory { static unique_ptr TransformCopySelect(PEGTransformer &transformer, unique_ptr select_statement_internal, unique_ptr copy_file_name, - const vector ©_options); + const optional> ©_options); static unique_ptr TransformCopyFileNameInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformCopyFileNameExpressionInternal(PEGTransformer &transformer, @@ -945,14 +966,15 @@ class PEGTransformerFactory { const Identifier &col_id); static unique_ptr TransformCopyOptionsInternal(PEGTransformer &transformer, ParseResult &parse_result); - static vector TransformCopyOptions(PEGTransformer &transformer, + static vector TransformCopyOptions(PEGTransformer &transformer, const bool &has_result, const vector ©_option_list); static unique_ptr TransformCopyOptionListInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformSpecializedOptionListInternal(PEGTransformer &transformer, ParseResult &parse_result); static vector - TransformSpecializedOptionList(PEGTransformer &transformer, const vector &specialized_option); + TransformSpecializedOptionList(PEGTransformer &transformer, + const optional> &specialized_option); static unique_ptr TransformSpecializedOptionInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformSingleOptionInternal(PEGTransformer &transformer, @@ -974,22 +996,26 @@ class PEGTransformerFactory { static GenericCopyOption TransformHeaderOption(PEGTransformer &transformer); static unique_ptr TransformNullAsOptionInternal(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformNullAsOption(PEGTransformer &transformer, const string &string_literal); + static GenericCopyOption TransformNullAsOption(PEGTransformer &transformer, const bool &has_result, + const string &string_literal); static unique_ptr TransformDelimiterAsOptionInternal(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformDelimiterAsOption(PEGTransformer &transformer, const string &string_literal); + static GenericCopyOption TransformDelimiterAsOption(PEGTransformer &transformer, const bool &has_result, + const string &string_literal); static unique_ptr TransformQuoteAsOptionInternal(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformQuoteAsOption(PEGTransformer &transformer, const string &string_literal); + static GenericCopyOption TransformQuoteAsOption(PEGTransformer &transformer, const bool &has_result, + const string &string_literal); static unique_ptr TransformEscapeAsOptionInternal(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformEscapeAsOption(PEGTransformer &transformer, const string &string_literal); + static GenericCopyOption TransformEscapeAsOption(PEGTransformer &transformer, const bool &has_result, + const string &string_literal); static unique_ptr TransformEncodingOptionInternal(PEGTransformer &transformer, ParseResult &parse_result); static GenericCopyOption TransformEncodingOption(PEGTransformer &transformer, const string &string_literal); static unique_ptr TransformForceQuoteOptionInternal(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformForceQuoteOption(PEGTransformer &transformer, const bool &force_quote, + static GenericCopyOption TransformForceQuoteOption(PEGTransformer &transformer, const optional &force_quote, const vector &star_symbol_column_list); static unique_ptr TransformStarSymbolColumnListInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -1002,7 +1028,7 @@ class PEGTransformerFactory { const vector &star_symbol_column_list); static unique_ptr TransformForceNullOptionInternal(PEGTransformer &transformer, ParseResult &parse_result); - static GenericCopyOption TransformForceNullOption(PEGTransformer &transformer, const bool &force_not_null, + static GenericCopyOption TransformForceNullOption(PEGTransformer &transformer, const optional &force_not_null, const vector &column_list); static unique_ptr TransformForceNotNullInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -1014,7 +1040,7 @@ class PEGTransformerFactory { static unique_ptr TransformGenericCopyOptionInternal(PEGTransformer &transformer, ParseResult &parse_result); static GenericCopyOption TransformGenericCopyOption(PEGTransformer &transformer, const Identifier ©_option_name, - GenericCopyOptionValue generic_copy_option_value); + optional generic_copy_option_value); static unique_ptr TransformGenericCopyOptionValueInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformGenericCopyOptionOrderListInternal(PEGTransformer &transformer, @@ -1059,11 +1085,14 @@ class PEGTransformerFactory { static CopyDatabaseType TransformCopyData(PEGTransformer &transformer); static unique_ptr TransformCreateIndexStmtInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformCreateIndexStmt( - PEGTransformer &transformer, const bool &unique_index, const bool &if_not_exists, const Identifier &index_name, - unique_ptr base_table_name, const vector &insert_column_list, - const Identifier &index_type, vector> index_element, - case_insensitive_map_t> with_list, unique_ptr where_clause); + static unique_ptr + TransformCreateIndexStmt(PEGTransformer &transformer, const optional &unique_index, + const optional &if_not_exists, const optional &index_name, + unique_ptr base_table_name, + const optional> &insert_column_list, const optional &index_type, + optional>> index_element, + optional>> with_list, + optional> where_clause); static unique_ptr TransformWithListInternal(PEGTransformer &transformer, ParseResult &parse_result); static case_insensitive_map_t> @@ -1092,8 +1121,8 @@ class PEGTransformerFactory { ParseResult &parse_result); static unique_ptr TransformIndexElement(PEGTransformer &transformer, unique_ptr expression, - const OrderType &desc_or_asc, - const OrderByNullType &nulls_first_or_last); + const optional &desc_or_asc, + const optional &nulls_first_or_last); static unique_ptr TransformUniqueIndexInternal(PEGTransformer &transformer, ParseResult &parse_result); static bool TransformUniqueIndex(PEGTransformer &transformer); @@ -1104,7 +1133,7 @@ class PEGTransformerFactory { ParseResult &parse_result); static pair> TransformRelOption(PEGTransformer &transformer, const Identifier &rel_option_name, - unique_ptr rel_option_argument_opt); + optional> rel_option_argument_opt); static unique_ptr TransformRelOptionNameInternal(PEGTransformer &transformer, ParseResult &parse_result); static Identifier TransformRelOptionName(PEGTransformer &transformer, const string &child); @@ -1133,9 +1162,11 @@ class PEGTransformerFactory { static unique_ptr TransformNoneLiteral(PEGTransformer &transformer); static unique_ptr TransformCreateMacroStmtInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr - TransformCreateMacroStmt(PEGTransformer &transformer, const bool ¯o_or_function, const bool &if_not_exists, - const QualifiedName &qualified_name, vector> macro_definition); + static unique_ptr TransformCreateMacroStmt(PEGTransformer &transformer, + const bool ¯o_or_function, + const optional &if_not_exists, + const QualifiedName &qualified_name, + vector> macro_definition); static unique_ptr TransformMacroOrFunctionInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformMacroKeywordInternal(PEGTransformer &transformer, @@ -1147,7 +1178,7 @@ class PEGTransformerFactory { static unique_ptr TransformMacroDefinitionInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformMacroDefinition(PEGTransformer &transformer, - vector macro_parameters, + optional> macro_parameters, unique_ptr macro_definition_body); static unique_ptr TransformMacroDefinitionBodyInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -1160,7 +1191,7 @@ class PEGTransformerFactory { static unique_ptr TransformSimpleParameterInternal(PEGTransformer &transformer, ParseResult &parse_result); static MacroParameter TransformSimpleParameter(PEGTransformer &transformer, const Identifier &type_func_name, - const LogicalType &type); + const optional &type); static unique_ptr TransformScalarMacroDefinitionInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformScalarMacroDefinition(PEGTransformer &transformer, @@ -1171,13 +1202,15 @@ class PEGTransformerFactory { TransformTableMacroDefinition(PEGTransformer &transformer, unique_ptr select_statement_internal); static unique_ptr TransformCreateSchemaStmtInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformCreateSchemaStmt(PEGTransformer &transformer, const bool &if_not_exists, + static unique_ptr TransformCreateSchemaStmt(PEGTransformer &transformer, + const optional &if_not_exists, const QualifiedName &qualified_name); static unique_ptr TransformCreateSecretStmtInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr - TransformCreateSecretStmt(PEGTransformer &transformer, const bool &if_not_exists, const Identifier &secret_name, - const Identifier &secret_storage_specifier, + TransformCreateSecretStmt(PEGTransformer &transformer, const optional &if_not_exists, + const optional &secret_name, + const optional &secret_storage_specifier, const vector &generic_copy_option_list); static unique_ptr TransformSecretStorageSpecifierInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -1188,9 +1221,9 @@ class PEGTransformerFactory { static unique_ptr TransformCreateSequenceStmtInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr - TransformCreateSequenceStmt(PEGTransformer &transformer, const bool &if_not_exists, + TransformCreateSequenceStmt(PEGTransformer &transformer, const optional &if_not_exists, const QualifiedName &qualified_name, - vector>> sequence_option); + optional>>> sequence_option); static unique_ptr TransformSequenceOptionInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformSeqSetCycleInternal(PEGTransformer &transformer, @@ -1204,6 +1237,7 @@ class PEGTransformerFactory { static unique_ptr TransformSeqSetIncrementInternal(PEGTransformer &transformer, ParseResult &parse_result); static pair> TransformSeqSetIncrement(PEGTransformer &transformer, + const bool &has_result, unique_ptr expression); static unique_ptr TransformSeqSetMinMaxInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -1216,8 +1250,8 @@ class PEGTransformerFactory { const string &seq_min_or_max); static unique_ptr TransformSeqStartWithInternal(PEGTransformer &transformer, ParseResult &parse_result); - static pair> TransformSeqStartWith(PEGTransformer &transformer, - unique_ptr expression); + static pair> + TransformSeqStartWith(PEGTransformer &transformer, const bool &has_result, unique_ptr expression); static unique_ptr TransformSeqOwnedByInternal(PEGTransformer &transformer, ParseResult &parse_result); static pair> TransformSeqOwnedBy(PEGTransformer &transformer, @@ -1232,8 +1266,9 @@ class PEGTransformerFactory { static string TransformMaxValue(PEGTransformer &transformer); static unique_ptr TransformCreateStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformCreateStatement(PEGTransformer &transformer, const bool &or_replace, - const SecretPersistType &temporary, + static unique_ptr TransformCreateStatement(PEGTransformer &transformer, + const optional &or_replace, + const optional &temporary, unique_ptr create_statement_variation); static unique_ptr TransformCreateStatementVariationInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -1253,18 +1288,20 @@ class PEGTransformerFactory { static SecretPersistType TransformTemporaryPersistent(PEGTransformer &transformer); static unique_ptr TransformCreateTableStmtInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformCreateTableStmt(PEGTransformer &transformer, const bool &if_not_exists, + static unique_ptr TransformCreateTableStmt(PEGTransformer &transformer, + const optional &if_not_exists, const QualifiedName &qualified_name, CreateTableDefinition create_table_definition, - const bool &commit_action); + const optional &commit_action); static unique_ptr TransformCreateTableDefinitionInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformCreateTableAsInternal(PEGTransformer &transformer, ParseResult &parse_result); - static CreateTableDefinition TransformCreateTableAs(PEGTransformer &transformer, ColumnList identifier_list, - PartitionSortedOptions partition_sorted_options, - case_insensitive_map_t> with_list, - unique_ptr statement, const bool &with_data); + static CreateTableDefinition + TransformCreateTableAs(PEGTransformer &transformer, optional identifier_list, + optional partition_sorted_options, + optional>> with_list, + unique_ptr statement, const optional &with_data); static unique_ptr TransformPartitionSortedOptionsInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformPartitionOptSortedOptionsInternal(PEGTransformer &transformer, @@ -1272,12 +1309,12 @@ class PEGTransformerFactory { static PartitionSortedOptions TransformPartitionOptSortedOptions(PEGTransformer &transformer, vector> partition_options, - vector> sorted_options); + optional>> sorted_options); static unique_ptr TransformSortedOptPartitionOptionsInternal(PEGTransformer &transformer, ParseResult &parse_result); static PartitionSortedOptions TransformSortedOptPartitionOptions(PEGTransformer &transformer, vector> sorted_options, - vector> partition_options); + optional>> partition_options); static unique_ptr TransformPartitionOptionsInternal(PEGTransformer &transformer, ParseResult &parse_result); static vector> @@ -1300,9 +1337,9 @@ class PEGTransformerFactory { static unique_ptr TransformCreateColumnListInternal(PEGTransformer &transformer, ParseResult &parse_result); static CreateTableDefinition - TransformCreateColumnList(PEGTransformer &transformer, ColumnElements create_table_column_list, - PartitionSortedOptions partition_sorted_options, - case_insensitive_map_t> with_list); + TransformCreateColumnList(PEGTransformer &transformer, optional create_table_column_list, + optional partition_sorted_options, + optional>> with_list); static unique_ptr TransformIfNotExistsInternal(PEGTransformer &transformer, ParseResult &parse_result); static bool TransformIfNotExists(PEGTransformer &transformer); @@ -1356,11 +1393,10 @@ class PEGTransformerFactory { unique_ptr top_level_constraint); static unique_ptr TransformColumnDefinitionInternal(PEGTransformer &transformer, ParseResult &parse_result); - static ConstraintColumnDefinition TransformColumnDefinition(PEGTransformer &transformer, - const vector &dotted_identifier, - const LogicalType &type, - GeneratedColumnDefinition generated_column, - vector column_constraint); + static ConstraintColumnDefinition + TransformColumnDefinition(PEGTransformer &transformer, const vector &dotted_identifier, + const optional &type, optional generated_column, + const bool &has_result, optional> column_constraint); static unique_ptr TransformColumnConstraintInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformNotNullConstraintInternal(PEGTransformer &transformer, @@ -1390,7 +1426,7 @@ class PEGTransformerFactory { ParseResult &parse_result); static ColumnConstraintEntry TransformForeignKeyConstraint(PEGTransformer &transformer, unique_ptr base_table_name, - const vector &column_list, + const optional> &column_list, const KeyActions &key_actions); static unique_ptr TransformColumnCollationInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -1402,8 +1438,8 @@ class PEGTransformerFactory { const Identifier &col_id_or_string); static unique_ptr TransformKeyActionsInternal(PEGTransformer &transformer, ParseResult &parse_result); - static KeyActions TransformKeyActions(PEGTransformer &transformer, const string &update_action, - const string &delete_action); + static KeyActions TransformKeyActions(PEGTransformer &transformer, const optional &update_action, + const optional &delete_action); static unique_ptr TransformUpdateActionInternal(PEGTransformer &transformer, ParseResult &parse_result); static string TransformUpdateAction(PEGTransformer &transformer, const string &key_action); @@ -1429,7 +1465,7 @@ class PEGTransformerFactory { static string TransformSetDefaultKeyAction(PEGTransformer &transformer); static unique_ptr TransformTopLevelConstraintInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformTopLevelConstraint(PEGTransformer &transformer, + static unique_ptr TransformTopLevelConstraint(PEGTransformer &transformer, const bool &has_result, unique_ptr top_level_constraint_list); static unique_ptr TransformTopLevelConstraintListInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -1454,7 +1490,7 @@ class PEGTransformerFactory { static unique_ptr TransformDottedIdentifierInternal(PEGTransformer &transformer, ParseResult &parse_result); static vector TransformDottedIdentifier(PEGTransformer &transformer, const Identifier &identifier, - const vector &dot_col_label); + const optional> &dot_col_label); static unique_ptr TransformDotColLabelInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformColIdInternal(PEGTransformer &transformer, @@ -1466,9 +1502,9 @@ class PEGTransformerFactory { ParseResult &parse_result); static unique_ptr TransformGeneratedColumnInternal(PEGTransformer &transformer, ParseResult &parse_result); - static GeneratedColumnDefinition TransformGeneratedColumn(PEGTransformer &transformer, + static GeneratedColumnDefinition TransformGeneratedColumn(PEGTransformer &transformer, const bool &has_result, unique_ptr expression, - const bool &generated_column_type); + const optional &generated_column_type); static unique_ptr TransformGeneratedColumnTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformCommitActionInternal(PEGTransformer &transformer, @@ -1491,11 +1527,11 @@ class PEGTransformerFactory { static unique_ptr TransformCreateTriggerStmtInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr - TransformCreateTriggerStmt(PEGTransformer &transformer, const bool &if_not_exists, const Identifier &trigger_name, - const TriggerTiming &trigger_timing, const TriggerEventInfo &trigger_event, - unique_ptr base_table_name, - const TriggerTableReferencingInfo &referencing_clause, - const TriggerForEach &for_each_clause, unique_ptr trigger_body); + TransformCreateTriggerStmt(PEGTransformer &transformer, const optional &if_not_exists, + const Identifier &trigger_name, const TriggerTiming &trigger_timing, + const TriggerEventInfo &trigger_event, unique_ptr base_table_name, + const optional &referencing_clause, + const optional &for_each_clause, unique_ptr trigger_body); static unique_ptr TransformTriggerBodyInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformTriggerNameInternal(PEGTransformer &transformer, @@ -1505,7 +1541,7 @@ class PEGTransformerFactory { ParseResult &parse_result); static TriggerTableReferencingInfo TransformReferencingClause(PEGTransformer &transformer, const TriggerTableReferencingInfo &referencing_item, - const TriggerTableReferencingInfo &referencing_item_1); + const optional &referencing_item_1); static unique_ptr TransformReferencingItemInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformReferencingNewTableAsInternal(PEGTransformer &transformer, @@ -1555,7 +1591,8 @@ class PEGTransformerFactory { static TriggerForEach TransformForEachStatement(PEGTransformer &transformer); static unique_ptr TransformCreateTypeStmtInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformCreateTypeStmt(PEGTransformer &transformer, const bool &if_not_exists, + static unique_ptr TransformCreateTypeStmt(PEGTransformer &transformer, + const optional &if_not_exists, const QualifiedName &qualified_name, unique_ptr create_type); static unique_ptr TransformCreateTypeInternal(PEGTransformer &transformer, @@ -1570,13 +1607,14 @@ class PEGTransformerFactory { static unique_ptr TransformEnumStringLiteralListInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformEnumStringLiteralList(PEGTransformer &transformer, - const vector &string_literal); + const optional> &string_literal); static unique_ptr TransformCreateViewStmtInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr - TransformCreateViewStmt(PEGTransformer &transformer, const bool &create_recursive, const bool &if_not_exists, - const QualifiedName &qualified_name, const vector &insert_column_list, - case_insensitive_map_t> with_list, + TransformCreateViewStmt(PEGTransformer &transformer, const optional &create_recursive, + const optional &if_not_exists, const QualifiedName &qualified_name, + const optional> &insert_column_list, + optional>> with_list, unique_ptr select_statement_internal); static unique_ptr TransformCreateRecursiveInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -1584,28 +1622,28 @@ class PEGTransformerFactory { static unique_ptr TransformDeallocateStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformDeallocateStatement(PEGTransformer &transformer, - const bool &deallocate_prepare, + const optional &deallocate_prepare, const Identifier &identifier); static unique_ptr TransformDeallocatePrepareInternal(PEGTransformer &transformer, ParseResult &parse_result); static bool TransformDeallocatePrepare(PEGTransformer &transformer); static unique_ptr TransformDeleteStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDeleteStatement(PEGTransformer &transformer, - CommonTableExpressionMap with_clause, - unique_ptr target_opt_alias, - vector> delete_using_clause, - unique_ptr where_clause, - vector> returning_clause); + static unique_ptr + TransformDeleteStatement(PEGTransformer &transformer, optional with_clause, + unique_ptr target_opt_alias, + optional>> delete_using_clause, + optional> where_clause, + optional>> returning_clause); static unique_ptr TransformTruncateStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformTruncateStatement(PEGTransformer &transformer, + static unique_ptr TransformTruncateStatement(PEGTransformer &transformer, const bool &has_result, unique_ptr base_table_name); static unique_ptr TransformTargetOptAliasInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformTargetOptAlias(PEGTransformer &transformer, unique_ptr base_table_name, - const Identifier &col_id); + const bool &has_result, const optional &col_id); static unique_ptr TransformDeleteUsingClauseInternal(PEGTransformer &transformer, ParseResult &parse_result); static vector> TransformDeleteUsingClause(PEGTransformer &transformer, @@ -1626,7 +1664,7 @@ class PEGTransformerFactory { ParseResult &parse_result); static unique_ptr TransformShowQualifiedName(PEGTransformer &transformer, const ShowType &show_or_describe_or_summarize, - DescribeTarget describe_target); + optional describe_target); static unique_ptr TransformShowTablesInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformShowTables(PEGTransformer &transformer, const ShowType &show_or_describe, @@ -1662,43 +1700,44 @@ class PEGTransformerFactory { static ShowType TransformDescRule(PEGTransformer &transformer); static unique_ptr TransformDetachStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDetachStatement(PEGTransformer &transformer, const bool &if_exists, + static unique_ptr TransformDetachStatement(PEGTransformer &transformer, const bool &has_result, + const optional &if_exists, const Identifier &catalog_name); static unique_ptr TransformDropStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformDropStatement(PEGTransformer &transformer, unique_ptr drop_entries, - const bool &drop_behavior); + const optional &drop_behavior); static unique_ptr TransformDropEntriesInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformDropTriggerInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropTrigger(PEGTransformer &transformer, const bool &if_exists, + static unique_ptr TransformDropTrigger(PEGTransformer &transformer, const optional &if_exists, const Identifier &trigger_name, unique_ptr base_table_name); static unique_ptr TransformDropTableInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformDropTable(PEGTransformer &transformer, const CatalogType &table_or_view, - const bool &if_exists, + const optional &if_exists, vector> base_table_name); static unique_ptr TransformDropTableFunctionInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformDropTableFunction(PEGTransformer &transformer, const CatalogType &comment_macro_table, - const bool &if_exists, + const optional &if_exists, const vector &table_function_name); static unique_ptr TransformDropFunctionInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformDropFunction(PEGTransformer &transformer, const bool &function_type_macro, - const bool &if_exists, + const optional &if_exists, const vector &function_identifier); static unique_ptr TransformDropSchemaInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropSchema(PEGTransformer &transformer, const bool &if_exists, + static unique_ptr TransformDropSchema(PEGTransformer &transformer, const optional &if_exists, const vector &qualified_schema_name); static unique_ptr TransformDropIndexInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropIndex(PEGTransformer &transformer, const bool &if_exists, + static unique_ptr TransformDropIndex(PEGTransformer &transformer, const optional &if_exists, const vector &qualified_index_name); static unique_ptr TransformQualifiedIndexNameInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -1718,22 +1757,23 @@ class PEGTransformerFactory { const Identifier &reserved_index_name); static unique_ptr TransformDropSequenceInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropSequence(PEGTransformer &transformer, const bool &if_exists, + static unique_ptr TransformDropSequence(PEGTransformer &transformer, const optional &if_exists, const vector &qualified_sequence_name); static unique_ptr TransformDropCollationInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropCollation(PEGTransformer &transformer, const bool &if_exists, + static unique_ptr TransformDropCollation(PEGTransformer &transformer, + const optional &if_exists, const vector &collation_name); static unique_ptr TransformDropTypeInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformDropType(PEGTransformer &transformer, const bool &if_exists, + static unique_ptr TransformDropType(PEGTransformer &transformer, const optional &if_exists, const vector &qualified_type_name); static unique_ptr TransformDropSecretInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformDropSecret(PEGTransformer &transformer, - const SecretPersistType &temporary, const bool &if_exists, - const Identifier &secret_name, - const Identifier &drop_secret_storage); + const optional &temporary, + const optional &if_exists, const Identifier &secret_name, + const optional &drop_secret_storage); static unique_ptr TransformTableOrViewInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformMaterializedViewEntryInternal(PEGTransformer &transformer, @@ -1773,13 +1813,15 @@ class PEGTransformerFactory { static Identifier TransformDropSecretStorage(PEGTransformer &transformer, const Identifier &identifier); static unique_ptr TransformExecuteStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformExecuteStatement(PEGTransformer &transformer, const Identifier &identifier, - vector table_function_arguments); + static unique_ptr + TransformExecuteStatement(PEGTransformer &transformer, const Identifier &identifier, + optional> table_function_arguments); static unique_ptr TransformExplainStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformExplainStatement(PEGTransformer &transformer, const bool &explain_analyze, - const vector &explain_option_list, - unique_ptr explainable_statements); + static unique_ptr + TransformExplainStatement(PEGTransformer &transformer, const optional &explain_analyze, + const optional> &explain_option_list, + unique_ptr explainable_statements); static unique_ptr TransformExplainAnalyzeInternal(PEGTransformer &transformer, ParseResult &parse_result); static bool TransformExplainAnalyze(PEGTransformer &transformer); @@ -1790,7 +1832,7 @@ class PEGTransformerFactory { static unique_ptr TransformExplainOptionInternal(PEGTransformer &transformer, ParseResult &parse_result); static GenericCopyOption TransformExplainOption(PEGTransformer &transformer, const Identifier &explain_option_name, - unique_ptr expression); + optional> expression); static unique_ptr TransformExplainSelectStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr @@ -1799,9 +1841,10 @@ class PEGTransformerFactory { ParseResult &parse_result); static unique_ptr TransformExportStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformExportStatement(PEGTransformer &transformer, const string &export_source, - const string &string_literal, - const vector &generic_copy_option_list); + static unique_ptr + TransformExportStatement(PEGTransformer &transformer, const optional &export_source, + const string &string_literal, + const optional> &generic_copy_option_list); static unique_ptr TransformExportSourceInternal(PEGTransformer &transformer, ParseResult &parse_result); static string TransformExportSource(PEGTransformer &transformer, const Identifier &catalog_name); @@ -2757,11 +2800,12 @@ class PEGTransformerFactory { static unique_ptr TransformInsertStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr - TransformInsertStatement(PEGTransformer &transformer, CommonTableExpressionMap with_clause, - const OnConflictAction &or_action, unique_ptr insert_target, - const InsertColumnOrder &by_name_or_position, const vector &insert_column_list, - InsertValues insert_values, unique_ptr on_conflict_clause, - vector> returning_clause); + TransformInsertStatement(PEGTransformer &transformer, optional with_clause, + const optional &or_action, unique_ptr insert_target, + const optional &by_name_or_position, + const optional> &insert_column_list, InsertValues insert_values, + optional> on_conflict_clause, + optional>> returning_clause); static unique_ptr TransformOrActionInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformInsertOrReplaceInternal(PEGTransformer &transformer, @@ -2786,7 +2830,7 @@ class PEGTransformerFactory { ParseResult &parse_result); static unique_ptr TransformInsertTarget(PEGTransformer &transformer, unique_ptr base_table_name, - const Identifier &insert_alias); + const optional &insert_alias); static unique_ptr TransformInsertAliasInternal(PEGTransformer &transformer, ParseResult &parse_result); static Identifier TransformInsertAlias(PEGTransformer &transformer, const Identifier &identifier); @@ -2808,15 +2852,15 @@ class PEGTransformerFactory { static unique_ptr TransformOnConflictClauseInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformOnConflictClause(PEGTransformer &transformer, - OnConflictExpressionTarget on_conflict_target, + optional on_conflict_target, unique_ptr on_conflict_action); static unique_ptr TransformOnConflictTargetInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformOnConflictExpressionTargetInternal(PEGTransformer &transformer, ParseResult &parse_result); - static OnConflictExpressionTarget TransformOnConflictExpressionTarget(PEGTransformer &transformer, - const vector &column_id_list, - unique_ptr where_clause); + static OnConflictExpressionTarget + TransformOnConflictExpressionTarget(PEGTransformer &transformer, const vector &column_id_list, + optional> where_clause); static unique_ptr TransformOnConflictIndexTargetInternal(PEGTransformer &transformer, ParseResult &parse_result); static OnConflictExpressionTarget TransformOnConflictIndexTarget(PEGTransformer &transformer, @@ -2827,7 +2871,7 @@ class PEGTransformerFactory { ParseResult &parse_result); static unique_ptr TransformOnConflictUpdate(PEGTransformer &transformer, unique_ptr update_set_clause, - unique_ptr where_clause); + optional> where_clause); static unique_ptr TransformOnConflictNothingInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformOnConflictNothing(PEGTransformer &transformer); @@ -2839,20 +2883,20 @@ class PEGTransformerFactory { ParseResult &parse_result); static unique_ptr TransformLoadStatement(PEGTransformer &transformer, const Identifier &col_id_or_string, - const Identifier &extension_alias); + const optional &extension_alias); static unique_ptr TransformExtensionAliasInternal(PEGTransformer &transformer, ParseResult &parse_result); static Identifier TransformExtensionAlias(PEGTransformer &transformer, const Identifier &identifier); static unique_ptr TransformInstallStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformInstallStatement(PEGTransformer &transformer, + static unique_ptr TransformInstallStatement(PEGTransformer &transformer, const bool &has_result, const QualifiedName &identifier_or_string_literal, - const ExtensionRepositoryInfo &from_source, - const string &version_number); + const optional &from_source, + const optional &version_number); static unique_ptr TransformUpdateExtensionsStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformUpdateExtensionsStatement(PEGTransformer &transformer, - const vector &identifier); + const optional> &identifier); static unique_ptr TransformFromSourceInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformFromSourceIdentifierInternal(PEGTransformer &transformer, @@ -2869,11 +2913,11 @@ class PEGTransformerFactory { static unique_ptr TransformMergeIntoStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr - TransformMergeIntoStatement(PEGTransformer &transformer, CommonTableExpressionMap with_clause, + TransformMergeIntoStatement(PEGTransformer &transformer, optional with_clause, unique_ptr target_opt_alias, unique_ptr merge_into_using_clause, JoinQualifier join_qualifier, vector>> merge_match, - vector> returning_clause); + optional>> returning_clause); static unique_ptr TransformMergeIntoUsingClauseInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformMergeIntoUsingClause(PEGTransformer &transformer, @@ -2883,14 +2927,14 @@ class PEGTransformerFactory { static unique_ptr TransformMatchedClauseInternal(PEGTransformer &transformer, ParseResult &parse_result); static pair> - TransformMatchedClause(PEGTransformer &transformer, unique_ptr and_expression, + TransformMatchedClause(PEGTransformer &transformer, optional> and_expression, unique_ptr matched_clause_action); static unique_ptr TransformMatchedClauseActionInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformUpdateMatchClauseInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformUpdateMatchClause(PEGTransformer &transformer, - unique_ptr update_match_info); + static unique_ptr + TransformUpdateMatchClause(PEGTransformer &transformer, optional> update_match_info); static unique_ptr TransformUpdateMatchInfoInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformUpdateMatchSetActionInternal(PEGTransformer &transformer, @@ -2906,8 +2950,8 @@ class PEGTransformerFactory { static unique_ptr TransformDeleteMatchClause(PEGTransformer &transformer); static unique_ptr TransformInsertMatchClauseInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformInsertMatchClause(PEGTransformer &transformer, - unique_ptr insert_match_info); + static unique_ptr + TransformInsertMatchClause(PEGTransformer &transformer, optional> insert_match_info); static unique_ptr TransformInsertMatchInfoInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformInsertDefaultValuesInternal(PEGTransformer &transformer, @@ -2915,12 +2959,13 @@ class PEGTransformerFactory { static unique_ptr TransformInsertDefaultValues(PEGTransformer &transformer); static unique_ptr TransformInsertByNameOrPositionInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformInsertByNameOrPosition(PEGTransformer &transformer, - const InsertColumnOrder &by_name_or_position); + static unique_ptr + TransformInsertByNameOrPosition(PEGTransformer &transformer, const optional &by_name_or_position, + const bool &has_result); static unique_ptr TransformInsertValuesListInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformInsertValuesList(PEGTransformer &transformer, - const vector &insert_column_list, + const optional> &insert_column_list, vector> expression); static unique_ptr TransformDoNothingMatchClauseInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -2928,7 +2973,7 @@ class PEGTransformerFactory { static unique_ptr TransformErrorMatchClauseInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformErrorMatchClause(PEGTransformer &transformer, - unique_ptr expression); + optional> expression); static unique_ptr TransformUpdateMatchSetClauseInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformUpdateMatchSetInfoInternal(PEGTransformer &transformer, @@ -2940,8 +2985,8 @@ class PEGTransformerFactory { static unique_ptr TransformNotMatchedClauseInternal(PEGTransformer &transformer, ParseResult &parse_result); static pair> - TransformNotMatchedClause(PEGTransformer &transformer, const MergeActionCondition &by_source_or_target, - unique_ptr and_expression, + TransformNotMatchedClause(PEGTransformer &transformer, const optional &by_source_or_target, + optional> and_expression, unique_ptr matched_clause_action); static unique_ptr TransformBySourceOrTargetInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -3006,8 +3051,9 @@ class PEGTransformerFactory { vector> variable_list); static unique_ptr TransformPragmaFunctionInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformPragmaFunction(PEGTransformer &transformer, const Identifier &pragma_name, - vector> pragma_parameters); + static unique_ptr + TransformPragmaFunction(PEGTransformer &transformer, const Identifier &pragma_name, + optional>> pragma_parameters); static unique_ptr TransformPragmaParametersInternal(PEGTransformer &transformer, ParseResult &parse_result); static vector> @@ -3015,7 +3061,7 @@ class PEGTransformerFactory { static unique_ptr TransformPrepareStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformPrepareStatement(PEGTransformer &transformer, const Identifier &identifier, - const vector &type_list, + const optional> &type_list, unique_ptr statement); static unique_ptr TransformTypeListInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -3612,7 +3658,7 @@ class PEGTransformerFactory { ParseResult &parse_result); static unique_ptr TransformZoneIntervalWithInterval(PEGTransformer &transformer, const string &string_literal, - const DatePartSpecifier &interval); + const optional &interval); static unique_ptr TransformZoneIntervalWithPrecisionInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformZoneIntervalWithPrecision(PEGTransformer &transformer, @@ -3620,7 +3666,7 @@ class PEGTransformerFactory { const string &string_literal); static unique_ptr TransformSetSettingInternal(PEGTransformer &transformer, ParseResult &parse_result); - static SettingInfo TransformSetSetting(PEGTransformer &transformer, const SetScope &setting_scope, + static SettingInfo TransformSetSetting(PEGTransformer &transformer, const optional &setting_scope, const Identifier &setting_name); static unique_ptr TransformSetVariableInternal(PEGTransformer &transformer, ParseResult &parse_result); @@ -3652,14 +3698,14 @@ class PEGTransformerFactory { ParseResult &parse_result); static unique_ptr TransformBeginTransactionInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformBeginTransaction(PEGTransformer &transformer, - const TransactionModifierType &read_or_write); + static unique_ptr TransformBeginTransaction(PEGTransformer &transformer, const bool &has_result, + const optional &read_or_write); static unique_ptr TransformRollbackTransactionInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformRollbackTransaction(PEGTransformer &transformer); + static unique_ptr TransformRollbackTransaction(PEGTransformer &transformer, const bool &has_result); static unique_ptr TransformCommitTransactionInternal(PEGTransformer &transformer, ParseResult &parse_result); - static unique_ptr TransformCommitTransaction(PEGTransformer &transformer); + static unique_ptr TransformCommitTransaction(PEGTransformer &transformer, const bool &has_result); static unique_ptr TransformReadOrWriteInternal(PEGTransformer &transformer, ParseResult &parse_result); static TransactionModifierType TransformReadOrWrite(PEGTransformer &transformer, @@ -3675,10 +3721,11 @@ class PEGTransformerFactory { static unique_ptr TransformUpdateStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr - TransformUpdateStatement(PEGTransformer &transformer, CommonTableExpressionMap with_clause, + TransformUpdateStatement(PEGTransformer &transformer, optional with_clause, unique_ptr update_target, unique_ptr update_set_clause, - unique_ptr from_clause, unique_ptr where_clause, - vector> returning_clause); + optional> from_clause, + optional> where_clause, + optional>> returning_clause); static unique_ptr TransformUpdateTargetInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformBaseTableSetInternal(PEGTransformer &transformer, @@ -3689,10 +3736,11 @@ class PEGTransformerFactory { ParseResult &parse_result); static unique_ptr TransformBaseTableAliasSet(PEGTransformer &transformer, unique_ptr base_table_name, - const Identifier &update_alias); + const optional &update_alias); static unique_ptr TransformUpdateAliasInternal(PEGTransformer &transformer, ParseResult &parse_result); - static Identifier TransformUpdateAlias(PEGTransformer &transformer, const Identifier &col_id); + static Identifier TransformUpdateAlias(PEGTransformer &transformer, const bool &has_result, + const Identifier &col_id); static unique_ptr TransformUpdateSetClauseInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformUpdateSetTupleInternal(PEGTransformer &transformer, @@ -3713,7 +3761,7 @@ class PEGTransformerFactory { static unique_ptr TransformUpdateSetColumnTargetInternal(PEGTransformer &transformer, ParseResult &parse_result); static string TransformUpdateSetColumnTarget(PEGTransformer &transformer, const Identifier &column_name, - const vector &dot_identifier); + const optional> &dot_identifier); static unique_ptr TransformUseStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformUseStatement(PEGTransformer &transformer, const QualifiedName &use_target); @@ -3729,15 +3777,15 @@ class PEGTransformerFactory { ParseResult &parse_result); static QualifiedName TransformUseTargetCatalogSchema(PEGTransformer &transformer, const Identifier &catalog_name, const Identifier &reserved_schema_name, - const vector &dot_identifier); + const optional> &dot_identifier); static unique_ptr TransformDotIdentifierInternal(PEGTransformer &transformer, ParseResult &parse_result); static Identifier TransformDotIdentifier(PEGTransformer &transformer, const Identifier &identifier); static unique_ptr TransformVacuumStatementInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformVacuumStatement(PEGTransformer &transformer, - const VacuumOptions &vacuum_options, - AnalyzeTarget analyze_target); + const optional &vacuum_options, + optional analyze_target); static unique_ptr TransformVacuumOptionsInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformVacuumParensOptionsInternal(PEGTransformer &transformer, @@ -3745,9 +3793,10 @@ class PEGTransformerFactory { static VacuumOptions TransformVacuumParensOptions(PEGTransformer &transformer, const vector &vacuum_option); static unique_ptr TransformVacuumLegacyOptionsInternal(PEGTransformer &transformer, ParseResult &parse_result); - static VacuumOptions TransformVacuumLegacyOptions(PEGTransformer &transformer, const string &opt_full, - const string &opt_freeze, const string &opt_verbose, - const string &opt_analyze); + static VacuumOptions TransformVacuumLegacyOptions(PEGTransformer &transformer, const optional &opt_full, + const optional &opt_freeze, + const optional &opt_verbose, + const optional &opt_analyze); static unique_ptr TransformVacuumOptionInternal(PEGTransformer &transformer, ParseResult &parse_result); static unique_ptr TransformOptAnalyzeInternal(PEGTransformer &transformer, diff --git a/src/duckdb/src/include/duckdb/planner/binder.hpp b/src/duckdb/src/include/duckdb/planner/binder.hpp index 2fa157e18..1cf246e28 100644 --- a/src/duckdb/src/include/duckdb/planner/binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/binder.hpp @@ -472,6 +472,16 @@ class Binder : public enable_shared_from_this { BoundStatement ExpandTriggers(QueryNode &node, TableCatalogEntry &table, const vector> &before_triggers, const vector> &after_triggers); + unique_ptr TryExpandRowTriggers(QueryNode &node, + vector> &returning_list, + TableCatalogEntry &table, TriggerEventType event_type); + BoundStatement ExpandRowTriggers(QueryNode &node, vector> &returning_list, + const TableCatalogEntry &table, + const vector> &triggers); + //! Registers NEW as a generic binding so child binders resolve NEW.col at depth=1. The returned binder is + //! pushed onto GetActiveBinders(). the caller must keep it alive until the matching pop_back(). + unique_ptr SetupNewRowScope(TableIndex table_index, const vector &col_names, + const vector &col_types); BoundStatement BindNode(UpdateQueryNode &node); BoundStatement BindNode(DeleteQueryNode &node); BoundStatement BindNode(MergeQueryNode &node); diff --git a/src/duckdb/src/include/duckdb/planner/operator/list.hpp b/src/duckdb/src/include/duckdb/planner/operator/list.hpp index 82f3d6c08..004cd853f 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/list.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/list.hpp @@ -42,6 +42,7 @@ #include "duckdb/planner/operator/logical_set_operation.hpp" #include "duckdb/planner/operator/logical_simple.hpp" #include "duckdb/planner/operator/logical_top_n.hpp" +#include "duckdb/planner/operator/logical_trigger.hpp" #include "duckdb/planner/operator/logical_unnest.hpp" #include "duckdb/planner/operator/logical_update.hpp" #include "duckdb/planner/operator/logical_vacuum.hpp" diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_trigger.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_trigger.hpp new file mode 100644 index 000000000..604080260 --- /dev/null +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_trigger.hpp @@ -0,0 +1,52 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/planner/operator/logical_trigger.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/common/enums/trigger_type.hpp" + +namespace duckdb { + +//! LogicalTrigger represents a FOR EACH ROW trigger attached to a DML statement. +//! +//! child[0] = affected rows source (CTE scan of the fired DML's returning output) +//! child[1] = trigger body (correlated subplan NEW.col refs are BoundColumnRef at depth=1) +//! +//! This node is transient: a pre-decorrelation rewrite pass converts it into a LogicalDependentJoin +//! before FlattenDependentJoins::DecorrelateIndependent runs. +//! +//! Known limitations of this set-based model (from firing every row as one decorrelated batch +//! against a single snapshot, rather than iterating rows sequentially): +//! - Visibility: a firing cannot see rows written by an earlier firing in the same statement +//! (self-accumulating bodies; cascades into another row-triggered table are rejected at bind time). +//! - Order: rows fire in an unspecified order, and with multiple triggers each trigger fires for the whole +//! batch before the next (not PostgreSQL's per-row A,B / A,B interleave). Observable only with +//! order-sensitive side effects (sequences, clock_timestamp(), or reading prior firings' writes). +//! Correct semantics for all of these require per-row iterative execution (not yet implemented). +class LogicalTrigger : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_TRIGGER; + +public: + LogicalTrigger(string trigger_name, TriggerTiming timing, TriggerEventType event_type, + CorrelatedColumns correlated_columns); + + string trigger_name; + TriggerTiming timing; + TriggerEventType event_type; + //! The NEW.col references from child[1] that correlate with child[0] + CorrelatedColumns correlated_columns; + +protected: + vector GetColumnBindings() override; + void ResolveTypes() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp b/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp index 0cbc47e76..2dc09829a 100644 --- a/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp +++ b/src/duckdb/src/include/duckdb/planner/subquery/flatten_dependent_join.hpp @@ -98,6 +98,8 @@ class FlattenDependentJoins { vector PushDownDistinct(unique_ptr &plan, vector state); vector PushDownExpressionGet(unique_ptr &plan, bool propagate_null_values, vector state); + vector PushDownDML(unique_ptr &plan, bool propagate_null_values, + vector state); vector PushDownGet(unique_ptr &plan, vector state); vector PushDownCTE(unique_ptr &plan, bool propagate_null_values, vector state); diff --git a/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp index 1b777fbba..7072f9e75 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp @@ -47,6 +47,11 @@ struct VariantStats { DUCKDB_API static BaseStatistics CreateUnknown(LogicalType type); DUCKDB_API static BaseStatistics CreateEmpty(LogicalType type); DUCKDB_API static BaseStatistics CreateShredded(const LogicalType &shredded_type); + //! Propagate statistics through a cast to VARIANT - builds fully-shredded VARIANT statistics describing + //! a (possibly nested) non-variant value of `source_type` with statistics `child_stats`. + //! Returns nullptr when the type can not be represented as a single consistent shredding. + DUCKDB_API static unique_ptr StatisticsPropagateToVariant(const LogicalType &source_type, + const BaseStatistics &child_stats); public: //! Stats related to the 'unshredded' column, which holds all data that doesn't fit in the structure of the shredded diff --git a/src/duckdb/src/include/duckdb/storage/table/variant_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/variant_column_data.hpp index 078b6a59c..d4890fa7a 100644 --- a/src/duckdb/src/include/duckdb/storage/table/variant_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/variant_column_data.hpp @@ -78,6 +78,10 @@ class VariantColumnData : public ColumnData { void Verify(RowGroup &parent) override; static void ShredVariantData(const Vector &input, Vector &output, idx_t count); + //! Debug/verification helper: shred a (top-level) VARIANT vector in place, deriving the shredding + //! schema from the first value (so inconsistent values become partially shredded). No-op if the + //! vector is already shredded or the first value yields no shreddable type. + static void DebugShred(Vector &variant, idx_t count); void SetValidityData(shared_ptr validity_p); void SetChildData(vector> child_data); diff --git a/src/duckdb/src/main/config.cpp b/src/duckdb/src/main/config.cpp index 0d24bfc6c..d1fab8b1a 100644 --- a/src/duckdb/src/main/config.cpp +++ b/src/duckdb/src/main/config.cpp @@ -104,6 +104,8 @@ static const ConfigurationOption internal_options[] = { DUCKDB_SETTING(DebugCheckpointSleepMsSetting), DUCKDB_SETTING(DebugDisableOptimizerSetting), DUCKDB_SETTING(DebugEvictionQueueSleepMicroSecondsSetting), + DUCKDB_SETTING(DebugForceCommitFailureSetting), + DUCKDB_SETTING(DebugForceCommitRevertFailureSetting), DUCKDB_SETTING(DebugForceExternalSetting), DUCKDB_SETTING(DebugForceFetchRowSetting), DUCKDB_SETTING(DebugForceNoCrossProductSetting), @@ -225,6 +227,7 @@ static const ConfigurationOption internal_options[] = { DUCKDB_SETTING_CALLBACK(StorageBlockPrefetchSetting), DUCKDB_GLOBAL(StorageCompatibilityVersionSetting), DUCKDB_LOCAL(StreamingBufferSizeSetting), + DUCKDB_SETTING_CALLBACK(TableFunctionIdentifierConversionSetting), DUCKDB_GLOBAL(TempDirectorySetting), DUCKDB_SETTING_CALLBACK(TempFileEncryptionSetting), DUCKDB_GLOBAL(ThreadsSetting), @@ -242,12 +245,12 @@ static const ConfigurationOption internal_options[] = { static const ConfigurationAlias setting_aliases[] = {DUCKDB_SETTING_ALIAS("configure_metrics", 29), DUCKDB_SETTING_ALIAS("custom_profiling_settings", 29), - DUCKDB_SETTING_ALIAS("memory_limit", 124), - DUCKDB_SETTING_ALIAS("null_order", 57), - DUCKDB_SETTING_ALIAS("profile_output", 147), - DUCKDB_SETTING_ALIAS("user", 163), + DUCKDB_SETTING_ALIAS("memory_limit", 126), + DUCKDB_SETTING_ALIAS("null_order", 59), + DUCKDB_SETTING_ALIAS("profile_output", 149), + DUCKDB_SETTING_ALIAS("user", 166), DUCKDB_SETTING_ALIAS("wal_autocheckpoint", 28), - DUCKDB_SETTING_ALIAS("worker_threads", 161), + DUCKDB_SETTING_ALIAS("worker_threads", 164), FINAL_ALIAS}; vector DBConfig::GetOptions() { @@ -481,6 +484,11 @@ LogicalType DBConfig::ParseLogicalType(const string &type) { return LogicalType::UNION(union_members); } + if (type == "STRUCT") { + // empty struct + return LogicalType::STRUCT({}); + } + if (StringUtil::StartsWith(type, "STRUCT(") && StringUtil::EndsWith(type, ")")) { // struct - recurse string struct_members_str = type.substr(7, type.size() - 8); diff --git a/src/duckdb/src/main/settings/autogenerated_settings.cpp b/src/duckdb/src/main/settings/autogenerated_settings.cpp index 4b9b8c79d..8c650a89a 100644 --- a/src/duckdb/src/main/settings/autogenerated_settings.cpp +++ b/src/duckdb/src/main/settings/autogenerated_settings.cpp @@ -216,6 +216,13 @@ void StorageBlockPrefetchSetting::OnSet(SettingCallbackInfo &info, Value ¶me EnumUtil::FromString(StringValue::Get(parameter)); } +//===----------------------------------------------------------------------===// +// Table Function Identifier Conversion +//===----------------------------------------------------------------------===// +void TableFunctionIdentifierConversionSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + EnumUtil::FromString(StringValue::Get(parameter)); +} + //===----------------------------------------------------------------------===// // Validate External File Cache //===----------------------------------------------------------------------===// diff --git a/src/duckdb/src/optimizer/common_subplan_optimizer.cpp b/src/duckdb/src/optimizer/common_subplan_optimizer.cpp index 3b73f6843..4adeaf0cd 100644 --- a/src/duckdb/src/optimizer/common_subplan_optimizer.cpp +++ b/src/duckdb/src/optimizer/common_subplan_optimizer.cpp @@ -11,6 +11,8 @@ #include "duckdb/planner/filter/expression_filter.hpp" #include "duckdb/planner/column_binding_map.hpp" +#include + namespace duckdb { //===--------------------------------------------------------------------===// @@ -722,46 +724,59 @@ class CommonSubplanFinder { } } - // Collect all subplan bindings, and figure out which subplan has the most outgoing bindings - idx_t max_subplan_idx = 0; - for (idx_t subplan_idx = 0; subplan_idx < subplan_info.subplans.size(); subplan_idx++) { - const auto &subplan_bindings = subplan_info.subplans[subplan_idx].canonical_bindings; - const auto &max_subplan_bindings = subplan_info.subplans[max_subplan_idx].canonical_bindings; - if (subplan_bindings.size() > max_subplan_bindings.size()) { - max_subplan_idx = subplan_idx; - } - } - - // Move the "maximum subplan" to the front - std::swap(subplan_info.subplans[0], subplan_info.subplans[max_subplan_idx]); - // We can bail on a subplan for various reasons (some of which could potentially be fixed) bool bail = false; - // Insert the bindings of the subplan with the most bindings into a set - column_binding_set_t max_subplan_column_binding_set; - for (auto &cb : subplan_info.subplans[0].canonical_bindings) { - if (max_subplan_column_binding_set.find(cb) != max_subplan_column_binding_set.end()) { - bail = true; // Subplan contains duplicate column bindings, i.e., another nested duplicate subplan + column_binding_set_t required_bindings; + for (auto &subplan : subplan_info.subplans) { + column_binding_set_t subplan_bindings; + for (auto &cb : subplan.canonical_bindings) { + if (subplan_bindings.find(cb) != subplan_bindings.end()) { + bail = + true; // Subplan contains duplicate column bindings, i.e., another nested duplicate subplan + break; + } + subplan_bindings.insert(cb); + required_bindings.insert(cb); + } + if (bail) { break; } - max_subplan_column_binding_set.insert(cb); } - // Check if the maximum subplan fully contains the column bindings of the other subplans - for (idx_t subplan_idx = 1; subplan_idx < subplan_info.subplans.size() && !bail; subplan_idx++) { - const auto &subplan_bindings = subplan_info.subplans[subplan_idx].canonical_bindings; - for (auto &cb : subplan_bindings) { - if (max_subplan_column_binding_set.find(cb) == max_subplan_column_binding_set.end()) { - bail = true; // Subplan does not fully contain the other subplans + idx_t primary_subplan_idx = subplan_info.subplans.size(); + idx_t primary_subplan_binding_count = 0; + for (idx_t subplan_idx = 0; subplan_idx < subplan_info.subplans.size() && !bail; subplan_idx++) { + const auto expanded_bindings = + GetExpandedCanonicalBindings(*subplan_info.subplans[subplan_idx].op.get()); + bool contains_required_bindings = true; + for (auto &cb : required_bindings) { + if (std::find(expanded_bindings.begin(), expanded_bindings.end(), cb) == expanded_bindings.end()) { + contains_required_bindings = false; break; } } + if (!contains_required_bindings) { + continue; + } + auto &subplan = subplan_info.subplans[subplan_idx]; + if (primary_subplan_idx == subplan_info.subplans.size() || + subplan.canonical_bindings.size() > primary_subplan_binding_count) { + primary_subplan_idx = subplan_idx; + primary_subplan_binding_count = subplan.canonical_bindings.size(); + } + } + if (primary_subplan_idx == subplan_info.subplans.size()) { + bail = true; // None of the subplans can expose every binding required by the duplicate occurrences } if (bail) { to_remove.push_back(signature); + continue; } + + // Move the primary subplan to the front + std::swap(subplan_info.subplans[0], subplan_info.subplans[primary_subplan_idx]); } // Only remove them all at the end so the logic above doesn't get affected @@ -788,18 +803,61 @@ class CommonSubplanFinder { // Resolve types to be used for creating the materialized CTE and refs op->ResolveOperatorTypes(); - // Get types and names const auto &primary_subplan = subplan_info.subplans[0]; - const auto &types = primary_subplan.op.get()->types; + + vector> old_bindings(subplan_info.subplans.size()); + vector> old_types(subplan_info.subplans.size()); + arena_vector required_canonical_bindings(state.allocator); + column_binding_set_t required_canonical_binding_set; + for (idx_t subplan_idx = 0; subplan_idx < subplan_info.subplans.size(); subplan_idx++) { + auto &subplan = subplan_info.subplans[subplan_idx]; + old_bindings[subplan_idx] = subplan.op.get()->GetColumnBindings(); + old_types[subplan_idx] = subplan.op.get()->types; + for (auto &cb : subplan.canonical_bindings) { + if (required_canonical_binding_set.find(cb) != required_canonical_binding_set.end()) { + continue; + } + required_canonical_binding_set.insert(cb); + required_canonical_bindings.push_back(cb); + } + } + + if (required_canonical_bindings.size() != primary_subplan.canonical_bindings.size()) { + // The signature ignores projection maps. If duplicate subplans expose different + // subsets of the same work, widen the materialized producer just enough to make + // every referenced output available, then project each reader back to its original + // schema below. + ClearProjectionMaps(*primary_subplan.op.get()); + primary_subplan.op.get()->ResolveOperatorTypes(); + } + + const auto materialized_bindings = primary_subplan.op.get()->GetColumnBindings(); + const auto materialized_canonical_bindings = GetCanonicalBindings(*primary_subplan.op.get()); + column_binding_map_t materialized_binding_index; + for (idx_t i = 0; i < materialized_canonical_bindings.size(); i++) { + materialized_binding_index.emplace(materialized_canonical_bindings[i], i); + } + + vector types; + vector materialized_output_bindings; + types.reserve(required_canonical_bindings.size()); + materialized_output_bindings.reserve(required_canonical_bindings.size()); + column_binding_map_t cte_binding_index; + for (idx_t i = 0; i < required_canonical_bindings.size(); i++) { + auto &cb = required_canonical_bindings[i]; + const auto entry = materialized_binding_index.find(cb); + D_ASSERT(entry != materialized_binding_index.end()); // guaranteed by FilterSubplans + const auto materialized_col_idx = entry->second; + cte_binding_index.emplace(cb, i); + types.push_back(primary_subplan.op.get()->types[materialized_col_idx]); + materialized_output_bindings.push_back(materialized_bindings[materialized_col_idx]); + } + + // Get names vector col_names; for (idx_t i = 0; i < types.size(); i++) { col_names.emplace_back(StringUtil::Format("%s_col_%llu", cte_name, i + 1)); } - const auto &primary_subplan_bindings = primary_subplan.canonical_bindings; - column_binding_map_t primary_binding_index; - for (idx_t i = 0; i < primary_subplan_bindings.size(); i++) { - primary_binding_index.emplace(primary_subplan_bindings[i], i); - } vector> cte_column_indexes(subplan_info.subplans.size()); vector needs_projection(subplan_info.subplans.size(), false); for (idx_t subplan_idx = 0; subplan_idx < subplan_info.subplans.size(); subplan_idx++) { @@ -809,11 +867,11 @@ class CommonSubplanFinder { needs_projection[subplan_idx] = canonical_bindings.size() != types.size(); for (idx_t i = 0; i < canonical_bindings.size(); i++) { const auto &cb = canonical_bindings[i]; - const auto entry = primary_binding_index.find(cb); - D_ASSERT(entry != primary_binding_index.end()); // guaranteed by FilterSubplans + const auto entry = cte_binding_index.find(cb); + D_ASSERT(entry != cte_binding_index.end()); // guaranteed by FilterSubplans const auto cte_col_idx = entry->second; // Types must match: same canonical binding = same base column = same type - D_ASSERT(subplan.op.get()->types[i] == types[cte_col_idx]); + D_ASSERT(old_types[subplan_idx][i] == types[cte_col_idx]); cte_column_indexes[subplan_idx].push_back(cte_col_idx); needs_projection[subplan_idx] = needs_projection[subplan_idx] || cte_col_idx != i; } @@ -829,11 +887,10 @@ class CommonSubplanFinder { if (subplan.op.get()->has_estimated_cardinality) { cte_refs.back()->SetEstimatedCardinality(subplan.op.get()->estimated_cardinality); } - const auto old_bindings = subplan.op.get()->GetColumnBindings(); auto new_bindings = cte_refs.back()->GetColumnBindings(); if (needs_projection[subplan_idx]) { // Preserve each subplan's original output order when it differs from the - // primary materialized CTE. + // materialized CTE. vector> select_list; for (auto cte_col_idx : cte_column_indexes[subplan_idx]) { select_list.emplace_back(make_uniq( @@ -847,9 +904,9 @@ class CommonSubplanFinder { cte_refs.back() = std::move(proj); new_bindings = cte_refs.back()->GetColumnBindings(); } - D_ASSERT(old_bindings.size() == new_bindings.size()); - for (idx_t i = 0; i < old_bindings.size(); i++) { - replacer.replacement_bindings.emplace_back(old_bindings[i], new_bindings[i]); + D_ASSERT(old_bindings[subplan_idx].size() == new_bindings.size()); + for (idx_t i = 0; i < old_bindings[subplan_idx].size(); i++) { + replacer.replacement_bindings.emplace_back(old_bindings[subplan_idx][i], new_bindings[i]); } } @@ -859,10 +916,9 @@ class CommonSubplanFinder { auto materialized_subplan = std::move(primary_subplan.op.get()); auto remainder = std::move(lowest_common_ancestor); vector> materialized_select_list; - const auto materialized_bindings = materialized_subplan->GetColumnBindings(); - for (idx_t i = 0; i < materialized_bindings.size(); i++) { + for (idx_t i = 0; i < materialized_output_bindings.size(); i++) { materialized_select_list.emplace_back( - make_uniq(types[i], materialized_bindings[i])); + make_uniq(types[i], materialized_output_bindings[i])); } auto materialized_projection = make_uniq(optimizer.binder.GenerateTableIndex(), std::move(materialized_select_list)); @@ -919,21 +975,8 @@ class CommonSubplanFinder { arena_vector GetCanonicalBindings(LogicalOperator &op) { // Compute the canonical column bindings coming out of this operator for convenience later - const auto &table_index_map = state.table_index_map.GetMap(); const auto original_bindings = op.GetColumnBindings(); - arena_vector canonical_bindings(state.allocator); - for (idx_t col_idx = 0; col_idx < original_bindings.size(); col_idx++) { - auto &cb = original_bindings[col_idx]; - const auto canonical_table_index = to_canonical_table_index.at(cb.table_index); - auto &table_map = table_index_map.at(cb.table_index); - if (table_map.Empty()) { - canonical_bindings.emplace_back(canonical_table_index, cb.column_index); - } else { - const auto canonical_col_idx = table_map.Get(cb.column_index); - canonical_bindings.emplace_back(canonical_table_index, canonical_col_idx); - } - } - return canonical_bindings; + return GetCanonicalBindings(original_bindings); } unique_ptr &LowestCommonAncestor(reference> a, @@ -966,6 +1009,86 @@ class CommonSubplanFinder { return a.get(); } + vector GetExpandedColumnBindings(LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_FILTER: + case LogicalOperatorType::LOGICAL_ORDER_BY: + return GetExpandedColumnBindings(*op.children[0]); + case LogicalOperatorType::LOGICAL_ANY_JOIN: + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { + auto &join = op.Cast(); + auto left_bindings = GetExpandedColumnBindings(*op.children[0]); + if (join.join_type == JoinType::SEMI || join.join_type == JoinType::ANTI) { + return left_bindings; + } + if (join.join_type == JoinType::MARK) { + left_bindings.emplace_back(join.mark_index, ProjectionIndex(0)); + return left_bindings; + } + auto right_bindings = GetExpandedColumnBindings(*op.children[1]); + if (join.join_type == JoinType::RIGHT_SEMI || join.join_type == JoinType::RIGHT_ANTI) { + return right_bindings; + } + left_bindings.insert(left_bindings.end(), right_bindings.begin(), right_bindings.end()); + return left_bindings; + } + case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: + case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: { + auto left_bindings = GetExpandedColumnBindings(*op.children[0]); + auto right_bindings = GetExpandedColumnBindings(*op.children[1]); + left_bindings.insert(left_bindings.end(), right_bindings.begin(), right_bindings.end()); + return left_bindings; + } + default: + return op.GetColumnBindings(); + } + } + + arena_vector GetCanonicalBindings(const vector &original_bindings) { + const auto &table_index_map = state.table_index_map.GetMap(); + arena_vector canonical_bindings(state.allocator); + for (auto &cb : original_bindings) { + const auto canonical_table_index = to_canonical_table_index.at(cb.table_index); + auto &table_map = table_index_map.at(cb.table_index); + if (table_map.Empty()) { + canonical_bindings.emplace_back(canonical_table_index, cb.column_index); + } else { + const auto canonical_col_idx = table_map.Get(cb.column_index); + canonical_bindings.emplace_back(canonical_table_index, canonical_col_idx); + } + } + return canonical_bindings; + } + + arena_vector GetExpandedCanonicalBindings(LogicalOperator &op) { + return GetCanonicalBindings(GetExpandedColumnBindings(op)); + } + + static void ClearProjectionMaps(LogicalOperator &op) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_ANY_JOIN: + case LogicalOperatorType::LOGICAL_ASOF_JOIN: + case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { + auto &join = op.Cast(); + join.left_projection_map.clear(); + join.right_projection_map.clear(); + break; + } + case LogicalOperatorType::LOGICAL_FILTER: + op.Cast().projection_map.clear(); + break; + case LogicalOperatorType::LOGICAL_ORDER_BY: + op.Cast().projection_map.clear(); + break; + default: + break; + } + for (auto &child : op.children) { + ClearProjectionMaps(*child); + } + } + bool ShouldMaterialize(const SubplanInfo &subplan_info) const { auto &subplan = subplan_info.subplans[0].op.get(); return CTEInlining::EndsInAggregateOrDistinct(*subplan) || IsSelectiveMultiTablePlan(subplan); diff --git a/src/duckdb/src/optimizer/join_order/relation_manager.cpp b/src/duckdb/src/optimizer/join_order/relation_manager.cpp index 14384075c..6eed88094 100644 --- a/src/duckdb/src/optimizer/join_order/relation_manager.cpp +++ b/src/duckdb/src/optimizer/join_order/relation_manager.cpp @@ -134,6 +134,11 @@ static bool OperatorIsNonReorderable(LogicalOperatorType op_type) { case LogicalOperatorType::LOGICAL_INTERSECT: case LogicalOperatorType::LOGICAL_ANY_JOIN: case LogicalOperatorType::LOGICAL_ASOF_JOIN: + // DML operators have side effects and must never be reordered away or passed through (a correlated + // trigger body can place an INSERT/UPDATE/DELETE mid-plan, as a cross-product child for example). + case LogicalOperatorType::LOGICAL_INSERT: + case LogicalOperatorType::LOGICAL_UPDATE: + case LogicalOperatorType::LOGICAL_DELETE: return true; default: return false; @@ -265,8 +270,9 @@ bool RelationManager::ExtractJoinRelations(JoinOrderOptimizer &optimizer, Logica optional_ptr op = &input_op; vector> datasource_filters; optional_ptr limit_op = nullptr; - // pass through single child operators - while (op->children.size() == 1 && !OperatorNeedsRelation(op->type)) { + // pass through single child operators (but never past a non-reorderable op such as a DML node, which + // must be preserved as its own relation rather than skipped over and dropped during reconstruction) + while (op->children.size() == 1 && !OperatorNeedsRelation(op->type) && !OperatorIsNonReorderable(op->type)) { if (op->type == LogicalOperatorType::LOGICAL_FILTER) { if (HasNonReorderableChild(*op)) { datasource_filters.push_back(*op); diff --git a/src/duckdb/src/optimizer/remove_unused_columns.cpp b/src/duckdb/src/optimizer/remove_unused_columns.cpp index ef6168622..a93189bec 100644 --- a/src/duckdb/src/optimizer/remove_unused_columns.cpp +++ b/src/duckdb/src/optimizer/remove_unused_columns.cpp @@ -357,7 +357,26 @@ void RemoveUnusedColumns::VisitOperator(unique_ptr &op_ref) { // Gather all scans of this CTE in the query and mark them as expected readers of this CTE GatherCTEScans(cte.table_index, *cte.children[1], cte_map_entry.expected_readers); cte_map_entry.everything_referenced = false; - RemoveUnusedColumns rhs_child_optimizer(*this, true); + auto output_bindings = cte.GetColumnBindings(); + bool has_output_references = false; + for (auto &binding : output_bindings) { + if (column_references.find(binding) != column_references.end()) { + has_output_references = true; + break; + } + } + + // A materialized CTE returns the RHS/continuation columns. If this pass is allowed to prune outputs and a + // parent registered explicit references to those columns, use them to prune the continuation. Respect + // everything_referenced barriers, e.g. DML operators that rely on a fixed child schema. + auto prune_rhs_outputs = !everything_referenced && has_output_references; + RemoveUnusedColumns rhs_child_optimizer(*this, !prune_rhs_outputs); + if (prune_rhs_outputs) { + rhs_child_optimizer.column_references.reserve(column_references.size()); + for (auto &entry : column_references) { + rhs_child_optimizer.column_references.emplace(entry.first, entry.second); + } + } rhs_child_optimizer.VisitOperator(cte.children[1]); unordered_set referenced_columns_in_rhs; diff --git a/src/duckdb/src/optimizer/statistics/expression/propagate_cast.cpp b/src/duckdb/src/optimizer/statistics/expression/propagate_cast.cpp index 6d972f4c5..c1a9bcab8 100644 --- a/src/duckdb/src/optimizer/statistics/expression/propagate_cast.cpp +++ b/src/duckdb/src/optimizer/statistics/expression/propagate_cast.cpp @@ -224,6 +224,10 @@ unique_ptr StatisticsPropagator::TryPropagateCast(const BaseStat if (source.id() == LogicalTypeId::VARIANT) { return StatisticsPropagateVariant(stats, target); } + if (target.id() == LogicalTypeId::VARIANT) { + // the cast shreds every value into a single bucket - mirror the (possibly nested) source as typed stats + return VariantStats::StatisticsPropagateToVariant(source, stats); + } if (!CanPropagateCast(source, target)) { return nullptr; } diff --git a/src/duckdb/src/parser/peg/transformer/peg_transformer_factory.cpp b/src/duckdb/src/parser/peg/transformer/peg_transformer_factory.cpp index f4a5091c7..31dce43e1 100644 --- a/src/duckdb/src/parser/peg/transformer/peg_transformer_factory.cpp +++ b/src/duckdb/src/parser/peg/transformer/peg_transformer_factory.cpp @@ -147,9 +147,6 @@ void PEGTransformerFactory::RegisterExpression() { } void PEGTransformerFactory::RegisterConnect() { - // connect.gram — both rules are hand-written; the generator skips them because of the - // optional SessionTarget sub-rule. - REGISTER_TRANSFORM(TransformConnectStatement); } void PEGTransformerFactory::RegisterPivot() { diff --git a/src/duckdb/src/parser/peg/transformer/transform_alter.cpp b/src/duckdb/src/parser/peg/transformer/transform_alter.cpp index 837308e34..1bbe66149 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_alter.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_alter.cpp @@ -40,7 +40,7 @@ unique_ptr PEGTransformerFactory::TransformAlterStatement(PEGTrans } unique_ptr -PEGTransformerFactory::TransformAlterTableStmt(PEGTransformer &transformer, const bool &if_exists, +PEGTransformerFactory::TransformAlterTableStmt(PEGTransformer &transformer, const optional &if_exists, unique_ptr base_table_name, vector> alter_table_options) { if (alter_table_options.size() > 1) { @@ -56,7 +56,7 @@ PEGTransformerFactory::TransformAlterTableStmt(PEGTransformer &transformer, cons } unique_ptr PEGTransformerFactory::TransformAlterDatabaseStmt(PEGTransformer &transformer, - const bool &if_exists, + const optional &if_exists, const Identifier &identifier, const Identifier &identifier_1) { OnEntryNotFound not_found = if_exists ? OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; @@ -66,7 +66,8 @@ unique_ptr PEGTransformerFactory::TransformAlterDatabaseStmt(PEGTrans return std::move(result); } -unique_ptr PEGTransformerFactory::TransformAlterViewStmt(PEGTransformer &transformer, const bool &if_exists, +unique_ptr PEGTransformerFactory::TransformAlterViewStmt(PEGTransformer &transformer, + const optional &if_exists, unique_ptr base_table_name, unique_ptr rename_alter) { auto rename_table = unique_ptr_cast(std::move(rename_alter)); @@ -79,14 +80,14 @@ unique_ptr PEGTransformerFactory::TransformAlterViewStmt(PEGTransform } unique_ptr PEGTransformerFactory::TransformAlterSchemaStmt(PEGTransformer &transformer, - const bool &if_exists, + const optional &if_exists, const QualifiedName &qualified_name, unique_ptr rename_alter) { throw NotImplementedException("Altering schemas is not yet supported"); } unique_ptr PEGTransformerFactory::TransformAlterSequenceStmt(PEGTransformer &transformer, - const bool &if_exists, + const optional &if_exists, const QualifiedName &qualified_sequence_name, unique_ptr alter_sequence_options) { if (qualified_sequence_name.schema.empty()) { @@ -101,12 +102,12 @@ unique_ptr PEGTransformerFactory::TransformAlterSequenceStmt(PEGTrans } QualifiedName PEGTransformerFactory::TransformQualifiedSequenceName(PEGTransformer &transformer, - const Identifier &catalog_qualification, - const Identifier &schema_qualification, + const optional &catalog_qualification, + const optional &schema_qualification, const Identifier &sequence_name) { QualifiedName result; - result.catalog = catalog_qualification.empty() ? INVALID_CATALOG : catalog_qualification; - result.schema = schema_qualification.empty() ? INVALID_SCHEMA : schema_qualification; + result.catalog = catalog_qualification ? *catalog_qualification : INVALID_CATALOG; + result.schema = schema_qualification ? *schema_qualification : INVALID_SCHEMA; result.name = sequence_name; return result; } @@ -206,7 +207,8 @@ unique_ptr PEGTransformerFactory::TransformAndMaterializeAlter( } unique_ptr PEGTransformerFactory::TransformAddColumn(PEGTransformer &transformer, - const bool &if_not_exists, + const bool &has_result, + const optional &if_not_exists, AddColumnEntry add_column_entry) { auto column_definition = ColumnDefinition(add_column_entry.column_path.back(), add_column_entry.type); if (add_column_entry.default_value) { @@ -214,26 +216,26 @@ unique_ptr PEGTransformerFactory::TransformAddColumn(PEGTransfor } unique_ptr result; + auto if_not_exists_value = if_not_exists.has_value(); if (add_column_entry.column_path.size() == 1) { - result = make_uniq(AlterEntryData(), std::move(column_definition), if_not_exists); + result = make_uniq(AlterEntryData(), std::move(column_definition), if_not_exists_value); } else { const auto parent_path = vector(add_column_entry.column_path.begin(), add_column_entry.column_path.end() - 1); - result = make_uniq(AlterEntryData(), parent_path, std::move(column_definition), if_not_exists); + result = + make_uniq(AlterEntryData(), parent_path, std::move(column_definition), if_not_exists_value); } return result; } -AddColumnEntry PEGTransformerFactory::TransformAddColumnEntry(PEGTransformer &transformer, - const vector &dotted_identifier, - const LogicalType &type, - GeneratedColumnDefinition generated_column, - vector column_constraint) { +AddColumnEntry PEGTransformerFactory::TransformAddColumnEntry( + PEGTransformer &transformer, const vector &dotted_identifier, const optional &type, + optional generated_column, optional> column_constraint) { AddColumnEntry new_column; new_column.column_path = StringsToIdentifiers(dotted_identifier); - bool has_type = type != LogicalType::INVALID; - bool has_generated = generated_column.expr != nullptr; + bool has_type = type.has_value(); + bool has_generated = generated_column && generated_column->expr != nullptr; // TODO(Dtenwolde) this checking logic should be moved to the binder if (!has_type && !has_generated) { throw ParserException("Column definition requires a type or generated expression"); @@ -241,34 +243,40 @@ AddColumnEntry PEGTransformerFactory::TransformAddColumnEntry(PEGTransformer &tr if (has_generated) { throw ParserException("Adding generated columns after table creation is not supported yet"); } - new_column.type = type; - for (auto &constraint : column_constraint) { - if (constraint.constraint_name == "DefaultValue") { - if (new_column.default_value) { - throw ParserException("Cannot define a default value twice"); + if (type) { + new_column.type = *type; + } + if (column_constraint) { + for (auto &constraint : *column_constraint) { + if (constraint.constraint_name == "DefaultValue") { + if (new_column.default_value) { + throw ParserException("Cannot define a default value twice"); + } + new_column.default_value = std::move(constraint.expression); } - new_column.default_value = std::move(constraint.expression); } } return new_column; } -unique_ptr -PEGTransformerFactory::TransformDropColumn(PEGTransformer &transformer, const bool &if_exists, - unique_ptr nested_column_name, - const bool &drop_behavior) { +unique_ptr PEGTransformerFactory::TransformDropColumn( + PEGTransformer &transformer, const bool &has_result, const optional &if_exists, + unique_ptr nested_column_name, const optional &drop_behavior) { + auto if_exists_value = if_exists.has_value(); + auto drop_behavior_value = drop_behavior ? *drop_behavior : false; if (nested_column_name->ColumnNames().size() == 1) { - auto result = make_uniq( - AlterEntryData(), nested_column_name->ColumnNames()[0].GetIdentifierName(), if_exists, drop_behavior); + auto result = + make_uniq(AlterEntryData(), nested_column_name->ColumnNames()[0].GetIdentifierName(), + if_exists_value, drop_behavior_value); return std::move(result); } - auto result = - make_uniq(AlterEntryData(), nested_column_name->ColumnNames(), if_exists, drop_behavior); + auto result = make_uniq(AlterEntryData(), nested_column_name->ColumnNames(), if_exists_value, + drop_behavior_value); return std::move(result); } unique_ptr -PEGTransformerFactory::TransformAlterColumn(PEGTransformer &transformer, +PEGTransformerFactory::TransformAlterColumn(PEGTransformer &transformer, const bool &has_result, unique_ptr nested_column_name, unique_ptr alter_column_entry) { if (alter_column_entry->alter_table_type == AlterTableType::SET_DEFAULT) { @@ -310,14 +318,19 @@ unique_ptr PEGTransformerFactory::TransformChangeNullability(PEG } } -unique_ptr PEGTransformerFactory::TransformAlterType(PEGTransformer &transformer, - const LogicalType &type, - unique_ptr using_expression) { - if (type == LogicalType::INVALID && !using_expression) { +unique_ptr +PEGTransformerFactory::TransformAlterType(PEGTransformer &transformer, const bool &has_result, + const optional &type, + optional> using_expression) { + if (!type && !using_expression) { throw ParserException("Omitting the type is only possible in combination with USING"); } - auto alter_type = type == LogicalType::INVALID ? LogicalType::UNKNOWN : type; - return make_uniq(AlterEntryData(), "", alter_type, std::move(using_expression)); + auto alter_type = type ? *type : LogicalType::UNKNOWN; + unique_ptr expression; + if (using_expression) { + expression = std::move(*using_expression); + } + return make_uniq(AlterEntryData(), "", alter_type, std::move(expression)); } unique_ptr PEGTransformerFactory::TransformUsingExpression(PEGTransformer &transformer, @@ -330,8 +343,10 @@ unique_ptr PEGTransformerFactory::TransformAddDefault(PEGTransfo return make_uniq(AlterEntryData(), "", std::move(expression)); } -unique_ptr PEGTransformerFactory::TransformRenameColumn( - PEGTransformer &transformer, unique_ptr nested_column_name, const Identifier &identifier) { +unique_ptr +PEGTransformerFactory::TransformRenameColumn(PEGTransformer &transformer, const bool &has_result, + unique_ptr nested_column_name, + const Identifier &identifier) { if (nested_column_name->ColumnNames().size() == 1) { auto result = make_uniq(AlterEntryData(), nested_column_name->ColumnNames()[0], identifier); return std::move(result); @@ -400,10 +415,12 @@ PEGTransformerFactory::TransformResetOptions(PEGTransformer &transformer, return make_uniq(AlterEntryData(), std::move(option_names)); } -unique_ptr -PEGTransformerFactory::TransformNestedColumnName(PEGTransformer &transformer, const vector &identifier_dot, - const Identifier &column_name) { - vector column_names = identifier_dot; +unique_ptr PEGTransformerFactory::TransformNestedColumnName( + PEGTransformer &transformer, const optional> &identifier_dot, const Identifier &column_name) { + vector column_names; + if (identifier_dot) { + column_names = *identifier_dot; + } column_names.push_back(column_name); return make_uniq(column_names); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_analyze.cpp b/src/duckdb/src/parser/peg/transformer/transform_analyze.cpp index 2b64f363a..8da07f636 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_analyze.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_analyze.cpp @@ -5,17 +5,17 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformAnalyzeStatement(PEGTransformer &transformer, - const bool &analyze_verbose, - AnalyzeTarget analyze_target) { + const optional &analyze_verbose, + optional analyze_target) { VacuumOptions vacuum_options; vacuum_options.analyze = true; auto result = make_uniq(vacuum_options); if (analyze_verbose) { throw NotImplementedException("ANALYZE VERBOSE is not implemented yet"); } - if (analyze_target.ref) { - result->info->columns = analyze_target.columns; - result->info->ref = std::move(analyze_target.ref); + if (analyze_target && analyze_target->ref) { + result->info->columns = analyze_target->columns; + result->info->ref = std::move(analyze_target->ref); result->info->has_table = true; } return std::move(result); @@ -23,10 +23,12 @@ unique_ptr PEGTransformerFactory::TransformAnalyzeStatement(PEGTra AnalyzeTarget PEGTransformerFactory::TransformAnalyzeTarget(PEGTransformer &transformer, unique_ptr base_table_name, - const vector &name_list) { + const optional> &name_list) { AnalyzeTarget result; result.ref = std::move(base_table_name); - result.columns = StringsToIdentifiers(name_list); + if (name_list) { + result.columns = StringsToIdentifiers(*name_list); + } return result; } diff --git a/src/duckdb/src/parser/peg/transformer/transform_attach.cpp b/src/duckdb/src/parser/peg/transformer/transform_attach.cpp index 174e78b9d..b26bd3b83 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_attach.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_attach.cpp @@ -4,11 +4,10 @@ namespace duckdb { -unique_ptr -PEGTransformerFactory::TransformAttachStatement(PEGTransformer &transformer, const bool &or_replace, - const bool &if_not_exists, unique_ptr database_path, - const Identifier &attach_alias, - const vector &attach_options) { +unique_ptr PEGTransformerFactory::TransformAttachStatement( + PEGTransformer &transformer, const optional &or_replace, const optional &if_not_exists, + const bool &has_result, unique_ptr database_path, const optional &attach_alias, + const optional> &attach_options) { auto result = make_uniq(); auto info = make_uniq(); @@ -25,26 +24,33 @@ PEGTransformerFactory::TransformAttachStatement(PEGTransformer &transformer, con } info->parsed_path = std::move(database_path); - info->name = Identifier(attach_alias); - for (const auto &attach_option : attach_options) { + if (attach_alias) { + info->name = Identifier(*attach_alias); + } + result->info = std::move(info); + if (!attach_options) { + return std::move(result); + } + + auto &attach_info = *result->info; + for (const auto &attach_option : *attach_options) { if (attach_option.expression) { - info->parsed_options[attach_option.name.GetIdentifierName()] = attach_option.expression->Copy(); + attach_info.parsed_options[attach_option.name.GetIdentifierName()] = attach_option.expression->Copy(); continue; } if (attach_option.children.empty()) { - info->options[attach_option.name.GetIdentifierName()] = Value(true); + attach_info.options[attach_option.name.GetIdentifierName()] = Value(true); } else if (attach_option.children.size() == 1) { auto val = attach_option.children[0]; if (val.IsNull()) { throw BinderException("NULL is not supported as a valid option for ATTACH option \"%s\"", attach_option.name); } - info->options[attach_option.name.GetIdentifierName()] = attach_option.children[0]; + attach_info.options[attach_option.name.GetIdentifierName()] = attach_option.children[0]; } else { throw ParserException("Option %s can only have one argument", attach_option.name); } } - result->info = std::move(info); return std::move(result); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_checkpoint.cpp b/src/duckdb/src/parser/peg/transformer/transform_checkpoint.cpp index ef6942d1f..29ec6e4a0 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_checkpoint.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_checkpoint.cpp @@ -5,16 +5,16 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformCheckpointStatement(PEGTransformer &transformer, - const bool &checkpoint_force, - const Identifier &catalog_name) { + const optional &checkpoint_force, + const optional &catalog_name) { auto checkpoint_name = checkpoint_force ? "force_checkpoint" : "checkpoint"; auto result = make_uniq(); vector> children; auto function = make_uniq(checkpoint_name, std::move(children)); function->CatalogMutable() = SYSTEM_CATALOG; function->SchemaMutable() = DEFAULT_SCHEMA; - if (!catalog_name.empty()) { - function->GetArgumentsMutable().emplace_back(make_uniq(catalog_name)); + if (catalog_name) { + function->GetArgumentsMutable().emplace_back(make_uniq(*catalog_name)); } result->function = std::move(function); return std::move(result); diff --git a/src/duckdb/src/parser/peg/transformer/transform_common.cpp b/src/duckdb/src/parser/peg/transformer/transform_common.cpp index 33e799813..ffab6d74f 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_common.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_common.cpp @@ -51,17 +51,19 @@ string PEGTransformerFactory::TransformIdentifierOrKeyword(PEGTransformer &trans LogicalType PEGTransformerFactory::TransformType(PEGTransformer &transformer, unique_ptr type_variations, - const vector &array_bounds) { + const optional> &array_bounds) { auto type = std::move(type_variations); - for (auto array_size : array_bounds) { - vector> children_types; - children_types.push_back(std::move(type)); - - if (array_size < 0) { - type = make_uniq(Identifier("list"), std::move(children_types)); - } else { - children_types.push_back(make_uniq(Value::BIGINT(array_size))); - type = make_uniq(Identifier("array"), std::move(children_types)); + if (array_bounds) { + for (auto array_size : *array_bounds) { + vector> children_types; + children_types.push_back(std::move(type)); + + if (array_size < 0) { + type = make_uniq(Identifier("list"), std::move(children_types)); + } else { + children_types.push_back(make_uniq(Value::BIGINT(array_size))); + type = make_uniq(Identifier("array"), std::move(children_types)); + } } } return LogicalType::UNBOUND(std::move(type)); @@ -72,15 +74,16 @@ int64_t PEGTransformerFactory::TransformArrayKeyword(PEGTransformer &transformer } int64_t PEGTransformerFactory::TransformSquareBracketsArray(PEGTransformer &transformer, - unique_ptr expression) { + optional> expression) { if (!expression) { // Empty array so we return -1 to signify it's a list return -1; } - if (expression->GetExpressionClass() != ExpressionClass::CONSTANT) { + auto &array_size = *expression; + if (array_size->GetExpressionClass() != ExpressionClass::CONSTANT) { throw ParserException("Expected a constant number as array size"); } - auto &const_number = expression->Cast(); + auto &const_number = array_size->Cast(); if (!const_number.GetValue().type().IsIntegral()) { throw BinderException("Expected an integer as array bound instead of %s", const_number.GetValue().ToString()); } @@ -93,10 +96,14 @@ int64_t PEGTransformerFactory::TransformSquareBracketsArray(PEGTransformer &tran unique_ptr PEGTransformerFactory::TransformTimeType(PEGTransformer &transformer, const LogicalTypeId &time_or_timestamp, - vector> type_modifiers, const bool &time_zone) { + optional>> type_modifiers, + const optional &time_zone) { auto type = time_or_timestamp; - auto modifiers = std::move(type_modifiers); - auto with_timezone = time_zone; + vector> modifiers; + if (type_modifiers) { + modifiers = std::move(*type_modifiers); + } + auto with_timezone = time_zone && *time_zone; if (type == LogicalTypeId::TIME) { if (!modifiers.empty()) { throw ParserException("Type TIME does not allow any modifiers"); @@ -202,56 +209,72 @@ string PEGTransformerFactory::TransformDoubleType(PEGTransformer &transformer) { return LogicalTypeIdToString(LogicalTypeId::DOUBLE); } -unique_ptr PEGTransformerFactory::TransformFloatType(PEGTransformer &transformer, - unique_ptr number_literal) { +unique_ptr +PEGTransformerFactory::TransformFloatType(PEGTransformer &transformer, + optional> number_literal) { return make_uniq(Identifier("FLOAT"), vector> {}); } unique_ptr PEGTransformerFactory::TransformDecimalType(PEGTransformer &transformer, - vector> type_modifiers) { - return make_uniq(Identifier("DECIMAL"), std::move(type_modifiers)); + optional>> type_modifiers) { + vector> modifiers; + if (type_modifiers) { + modifiers = std::move(*type_modifiers); + } + return make_uniq(Identifier("DECIMAL"), std::move(modifiers)); } unique_ptr PEGTransformerFactory::TransformDecType(PEGTransformer &transformer, - vector> type_modifiers) { + optional>> type_modifiers) { return TransformDecimalType(transformer, std::move(type_modifiers)); } unique_ptr PEGTransformerFactory::TransformNumericModType(PEGTransformer &transformer, - vector> type_modifiers) { + optional>> type_modifiers) { return TransformDecimalType(transformer, std::move(type_modifiers)); } vector> PEGTransformerFactory::TransformTypeModifiers(PEGTransformer &transformer, - vector> expression) { - for (auto &expr : expression) { + optional>> expression) { + if (!expression) { + return vector> {}; + } + for (auto &expr : *expression) { if (expr->GetExpressionClass() != ExpressionClass::CONSTANT) { throw ParserException("Expected a constant as type modifier"); } } - return expression; + return std::move(*expression); } unique_ptr -PEGTransformerFactory::TransformCharacterSimpleType(PEGTransformer &transformer, const string &character_type, - vector> type_modifiers) { - return make_uniq(character_type, std::move(type_modifiers)); +PEGTransformerFactory::TransformCharacterSimpleType(PEGTransformer &transformer, + optional>> type_modifiers) { + vector> modifiers; + if (type_modifiers) { + modifiers = std::move(*type_modifiers); + } + return make_uniq(Identifier("VARCHAR"), std::move(modifiers)); } unique_ptr PEGTransformerFactory::TransformQualifiedSimpleType(PEGTransformer &transformer, const QualifiedName &qualified_type_name, - vector> type_modifiers) { + optional>> type_modifiers) { auto result = qualified_type_name; if (result.schema.empty()) { result.schema = result.catalog; result.catalog = INVALID_CATALOG; } - return make_uniq(result.catalog, result.schema, result.name, std::move(type_modifiers)); + vector> modifiers; + if (type_modifiers) { + modifiers = std::move(*type_modifiers); + } + return make_uniq(result.catalog, result.schema, result.name, std::move(modifiers)); } QualifiedName PEGTransformerFactory::TransformTypeNameAsQualifiedName(PEGTransformer &transformer, @@ -283,10 +306,6 @@ QualifiedName PEGTransformerFactory::TransformCatalogReservedSchemaTypeName( return result; } -string PEGTransformerFactory::TransformCharacterType(PEGTransformer &transformer) { - return "VARCHAR"; -} - unique_ptr PEGTransformerFactory::TransformMapType(PEGTransformer &transformer, const vector &type) { if (type.size() != 2) { @@ -311,16 +330,18 @@ PEGTransformerFactory::TransformRowType(PEGTransformer &transformer, return make_uniq(Identifier("STRUCT"), std::move(struct_children)); } -unique_ptr PEGTransformerFactory::TransformGeometryType(PEGTransformer &transformer, - unique_ptr expression) { +unique_ptr +PEGTransformerFactory::TransformGeometryType(PEGTransformer &transformer, + optional> expression) { if (!expression) { return make_uniq(Identifier("GEOMETRY"), vector> {}); } + auto geo_modifier = std::move(*expression); vector> geo_children; - if (expression->GetExpressionClass() != ExpressionClass::CONSTANT) { + if (geo_modifier->GetExpressionClass() != ExpressionClass::CONSTANT) { throw ParserException("Expected a constant as type modifier"); } - geo_children.push_back(std::move(expression)); + geo_children.push_back(std::move(geo_modifier)); return make_uniq(Identifier("GEOMETRY"), std::move(geo_children)); } @@ -356,8 +377,8 @@ pair PEGTransformerFactory::TransformColIdType(PEGTrans } unique_ptr PEGTransformerFactory::TransformBitType( - PEGTransformer &transformer, - vector> expression) { // NOLINT(performance-unnecessary-value-param) + PEGTransformer &transformer, const bool &has_result, + optional>> expression) { // NOLINT(performance-unnecessary-value-param) return make_uniq(Identifier("BIT"), vector> {}); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_connect.cpp b/src/duckdb/src/parser/peg/transformer/transform_connect.cpp index a0ad3b38c..11d9bc317 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_connect.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_connect.cpp @@ -4,47 +4,33 @@ namespace duckdb { -//! Shape captured from `SessionTarget <- 'LOCAL' / StringLiteral / CatalogName`. The framework -//! wraps the named sub-rule in a List whose only child is the Choice over the alternatives. -struct SessionTargetCapture { - string name; - bool target_is_local = false; - bool name_is_string_literal = false; -}; +unique_ptr PEGTransformerFactory::TransformLocalSessionTarget(PEGTransformer &transformer) { + auto result = make_uniq(); + result->target_is_local = true; + return result; +} -static SessionTargetCapture TransformSessionTarget(PEGTransformer &transformer, ParseResult &target_result) { - auto &list = target_result.Cast(); - auto &inner = list.Child(0).GetResult(); - SessionTargetCapture result; - switch (inner.type) { - case ParseResultType::KEYWORD: - // 'LOCAL' alternative — name stays empty, just flip the flag. - result.target_is_local = true; - break; - case ParseResultType::STRING: - result.name = inner.Cast().GetRawString(); - result.name_is_string_literal = true; - break; - case ParseResultType::IDENTIFIER: - result.name = inner.Cast().identifier.GetIdentifierName(); - break; - default: - throw InternalException("Unexpected SessionTarget alternative type: %s", ParseResultToString(inner.type)); - } +unique_ptr PEGTransformerFactory::TransformStringSessionTarget(PEGTransformer &transformer, + const string &string_literal) { + auto result = make_uniq(); + result->name = Identifier(string_literal); + result->name_is_string_literal = true; + return result; +} + +unique_ptr PEGTransformerFactory::TransformCatalogSessionTarget(PEGTransformer &transformer, + const Identifier &catalog_name) { + auto result = make_uniq(); + result->name = catalog_name; return result; } -// ConnectStatement <- 'CONNECT' SessionTarget? -unique_ptr PEGTransformerFactory::TransformConnectStatement(PEGTransformer &transformer, - ParseResult &parse_result) { - auto &list_pr = parse_result.Cast(); +unique_ptr +PEGTransformerFactory::TransformConnectStatement(PEGTransformer &transformer, + optional> session_target) { auto info = make_uniq(); - auto &target_opt = list_pr.Child(1); - if (target_opt.HasResult()) { - auto captured = TransformSessionTarget(transformer, target_opt.GetResult()); - info->name = Identifier(std::move(captured.name)); - info->target_is_local = captured.target_is_local; - info->name_is_string_literal = captured.name_is_string_literal; + if (session_target) { + info = std::move(*session_target); } auto result = make_uniq(); result->info = std::move(info); diff --git a/src/duckdb/src/parser/peg/transformer/transform_copy.cpp b/src/duckdb/src/parser/peg/transformer/transform_copy.cpp index ffd44aed0..8f0bec34b 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_copy.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_copy.cpp @@ -70,7 +70,7 @@ void SetCopyOptions(unique_ptr &info, vector &optio unique_ptr PEGTransformerFactory::TransformCopySelect( PEGTransformer &transformer, unique_ptr select_statement_internal, - unique_ptr copy_file_name, const vector ©_options) { + unique_ptr copy_file_name, const optional> ©_options) { auto result = make_uniq(); auto info = make_uniq(); info->is_from = false; @@ -80,8 +80,10 @@ unique_ptr PEGTransformerFactory::TransformCopySelect( } else { info->file_path_expression = std::move(copy_file_name); } - auto options = copy_options; - SetCopyOptions(info, options); + if (copy_options) { + auto options = *copy_options; + SetCopyOptions(info, options); + } info->select_statement = std::move(select_statement_internal->node); result->info = std::move(info); return std::move(result); @@ -124,19 +126,20 @@ string PEGTransformerFactory::ExtractFormat(const string &file_path) { return format.substr(dot_pos + 1); } -unique_ptr PEGTransformerFactory::TransformCopyTable(PEGTransformer &transformer, - unique_ptr base_table_name, - const vector &insert_column_list, - const bool &from_or_to, - unique_ptr copy_file_name, - const vector ©_options) { +unique_ptr +PEGTransformerFactory::TransformCopyTable(PEGTransformer &transformer, unique_ptr base_table_name, + const optional> &insert_column_list, const bool &from_or_to, + unique_ptr copy_file_name, + const optional> ©_options) { auto result = make_uniq(); auto info = make_uniq(); info->table = base_table_name->table_name; info->schema = base_table_name->schema_name; info->catalog = base_table_name->catalog_name; - info->select_list = StringsToIdentifiers(insert_column_list); + if (insert_column_list) { + info->select_list = StringsToIdentifiers(*insert_column_list); + } info->is_from = from_or_to; if (copy_file_name->GetExpressionClass() == ExpressionClass::CONSTANT) { auto &const_expr = copy_file_name->Cast(); @@ -146,8 +149,10 @@ unique_ptr PEGTransformerFactory::TransformCopyTable(PEGTransforme } info->format = ExtractFormat(info->file_path); - auto generic_options = copy_options; - SetCopyOptions(info, generic_options); + if (copy_options) { + auto generic_options = *copy_options; + SetCopyOptions(info, generic_options); + } result->info = std::move(info); return std::move(result); @@ -188,15 +193,18 @@ Identifier PEGTransformerFactory::TransformIdentifierColId(PEGTransformer &trans } vector -PEGTransformerFactory::TransformCopyOptions(PEGTransformer &transformer, +PEGTransformerFactory::TransformCopyOptions(PEGTransformer &transformer, const bool &has_result, const vector ©_option_list) { return copy_option_list; } vector PEGTransformerFactory::TransformSpecializedOptionList(PEGTransformer &transformer, - const vector &specialized_option) { - return specialized_option; + const optional> &specialized_option) { + if (!specialized_option) { + return {}; + } + return *specialized_option; } GenericCopyOption PEGTransformerFactory::TransformEncodingOption(PEGTransformer &transformer, @@ -204,7 +212,8 @@ GenericCopyOption PEGTransformerFactory::TransformEncodingOption(PEGTransformer return GenericCopyOption("encoding", string_literal); } -GenericCopyOption PEGTransformerFactory::TransformForceQuoteOption(PEGTransformer &transformer, const bool &force_quote, +GenericCopyOption PEGTransformerFactory::TransformForceQuoteOption(PEGTransformer &transformer, + const optional &force_quote, const vector &star_symbol_column_list) { string func_name = force_quote ? "force_quote" : "quote"; auto result = GenericCopyOption(); @@ -219,13 +228,13 @@ GenericCopyOption PEGTransformerFactory::TransformForceQuoteOption(PEGTransforme return result; } -GenericCopyOption PEGTransformerFactory::TransformQuoteAsOption(PEGTransformer &transformer, +GenericCopyOption PEGTransformerFactory::TransformQuoteAsOption(PEGTransformer &transformer, const bool &has_result, const string &string_literal) { return GenericCopyOption("quote", string_literal); } GenericCopyOption PEGTransformerFactory::TransformForceNullOption(PEGTransformer &transformer, - const bool &force_not_null, + const optional &force_not_null, const vector &column_list) { auto result = GenericCopyOption(); result.name = force_not_null ? "force_not_null" : "force_null"; @@ -249,17 +258,17 @@ GenericCopyOption PEGTransformerFactory::TransformPartitionByOption(PEGTransform return result; } -GenericCopyOption PEGTransformerFactory::TransformNullAsOption(PEGTransformer &transformer, +GenericCopyOption PEGTransformerFactory::TransformNullAsOption(PEGTransformer &transformer, const bool &has_result, const string &string_literal) { return GenericCopyOption("null", string_literal); } -GenericCopyOption PEGTransformerFactory::TransformDelimiterAsOption(PEGTransformer &transformer, +GenericCopyOption PEGTransformerFactory::TransformDelimiterAsOption(PEGTransformer &transformer, const bool &has_result, const string &string_literal) { return GenericCopyOption("delimiter", string_literal); } -GenericCopyOption PEGTransformerFactory::TransformEscapeAsOption(PEGTransformer &transformer, +GenericCopyOption PEGTransformerFactory::TransformEscapeAsOption(PEGTransformer &transformer, const bool &has_result, const string &string_literal) { return GenericCopyOption("escape", string_literal); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_index.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_index.cpp index f4a6d23bf..3426e724d 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_index.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_index.cpp @@ -4,44 +4,52 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformCreateIndexStmt( - PEGTransformer &transformer, const bool &unique_index, const bool &if_not_exists, const Identifier &index_name, - unique_ptr base_table_name, const vector &insert_column_list, const Identifier &index_type, - vector> index_element, case_insensitive_map_t> with_list, - unique_ptr where_clause) { + PEGTransformer &transformer, const optional &unique_index, const optional &if_not_exists, + const optional &index_name, unique_ptr base_table_name, + const optional> &insert_column_list, const optional &index_type, + optional>> index_element, + optional>> with_list, + optional> where_clause) { auto result = make_uniq(); auto index_info = make_uniq(); index_info->constraint_type = unique_index ? IndexConstraintType::UNIQUE : IndexConstraintType::NONE; index_info->on_conflict = if_not_exists ? OnCreateConflict::IGNORE_ON_CONFLICT : OnCreateConflict::ERROR_ON_CONFLICT; - if (index_name.empty()) { + if (!index_name) { throw NotImplementedException("Please provide an index name, e.g., CREATE INDEX my_name ..."); } - index_info->index_name = index_name; + index_info->index_name = *index_name; index_info->table = base_table_name->table_name; index_info->catalog = base_table_name->catalog_name; index_info->schema = base_table_name->schema_name; - index_info->index_type = index_type.empty() ? "ART" : index_type.GetIdentifierName(); - for (auto &column : insert_column_list) { - index_info->expressions.push_back( - make_uniq(Identifier(column), base_table_name->table_name)); - index_info->parsed_expressions.push_back( - make_uniq(Identifier(column), base_table_name->table_name)); + index_info->index_type = index_type ? index_type->GetIdentifierName() : "ART"; + if (insert_column_list) { + for (auto &column : *insert_column_list) { + index_info->expressions.push_back( + make_uniq(Identifier(column), base_table_name->table_name)); + index_info->parsed_expressions.push_back( + make_uniq(Identifier(column), base_table_name->table_name)); + } } - for (auto &expr : index_element) { - if (expr->GetExpressionType() == ExpressionType::COLLATE) { - throw NotImplementedException("Index with collation not supported yet!"); + if (index_element) { + for (auto &expr : *index_element) { + if (expr->GetExpressionType() == ExpressionType::COLLATE) { + throw NotImplementedException("Index with collation not supported yet!"); + } + index_info->expressions.push_back(expr->Copy()); + index_info->parsed_expressions.push_back(std::move(expr)); } - index_info->expressions.push_back(expr->Copy()); - index_info->parsed_expressions.push_back(std::move(expr)); } if (where_clause) { throw NotImplementedException("Creating partial indexes is not supported currently"); } - for (auto &option_entry : with_list) { - if (option_entry.second->GetExpressionClass() != ExpressionClass::CONSTANT) { - throw InvalidInputException("Create index option must be a constant value"); + if (with_list) { + for (auto &option_entry : *with_list) { + if (option_entry.second->GetExpressionClass() != ExpressionClass::CONSTANT) { + throw InvalidInputException("Create index option must be a constant value"); + } + index_info->options[option_entry.first] = option_entry.second->Cast().GetValue(); } - index_info->options[option_entry.first] = option_entry.second->Cast().GetValue(); } result->info = std::move(index_info); return result; @@ -56,10 +64,10 @@ Identifier PEGTransformerFactory::TransformIndexType(PEGTransformer &transformer return identifier; } -unique_ptr PEGTransformerFactory::TransformIndexElement(PEGTransformer &transformer, - unique_ptr expression, - const OrderType &desc_or_asc, - const OrderByNullType &nulls_first_or_last) { +unique_ptr +PEGTransformerFactory::TransformIndexElement(PEGTransformer &transformer, unique_ptr expression, + const optional &desc_or_asc, + const optional &nulls_first_or_last) { // TODO(Dtenwolde): We currently ignore desc_or_asc and nulls_first_or_last return expression; } @@ -106,11 +114,11 @@ Identifier PEGTransformerFactory::TransformRelOptionName(PEGTransformer &transfo pair> PEGTransformerFactory::TransformRelOption(PEGTransformer &transformer, const Identifier &rel_option_name, - unique_ptr rel_option_argument_opt) { + optional> rel_option_argument_opt) { if (!rel_option_argument_opt) { return {rel_option_name, make_uniq(Value())}; } - return {rel_option_name, std::move(rel_option_argument_opt)}; + return {rel_option_name, std::move(*rel_option_argument_opt)}; } // RelOptionArgumentOpt <- '=' DefArg diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_macro.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_macro.cpp index 1f2e991ca..2dc7c8bfd 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_macro.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_macro.cpp @@ -5,10 +5,9 @@ #include "duckdb/function/scalar_macro_function.hpp" namespace duckdb { -unique_ptr -PEGTransformerFactory::TransformCreateMacroStmt(PEGTransformer &transformer, const bool ¯o_or_function, - const bool &if_not_exists, const QualifiedName &qualified_name, - vector> macro_definition) { +unique_ptr PEGTransformerFactory::TransformCreateMacroStmt( + PEGTransformer &transformer, const bool ¯o_or_function, const optional &if_not_exists, + const QualifiedName &qualified_name, vector> macro_definition) { auto result = make_uniq(); auto info = make_uniq(CatalogType::MACRO_ENTRY); @@ -48,11 +47,15 @@ bool PEGTransformerFactory::TransformFunctionKeyword(PEGTransformer &transformer } unique_ptr -PEGTransformerFactory::TransformMacroDefinition(PEGTransformer &transformer, vector macro_parameters, +PEGTransformerFactory::TransformMacroDefinition(PEGTransformer &transformer, + optional> macro_parameters, unique_ptr macro_definition_body) { + if (!macro_parameters) { + return macro_definition_body; + } bool default_value_found = false; identifier_set_t parameter_names; - for (auto ¶meter : macro_parameters) { + for (auto ¶meter : *macro_parameters) { D_ASSERT(!parameter.name.empty()); if (parameter_names.find(parameter.name) != parameter_names.end()) { throw ParserException("Duplicate parameter '%s' in macro definition", parameter.name.GetIdentifierName()); @@ -99,12 +102,12 @@ vector PEGTransformerFactory::TransformMacroParameters(PEGTransf MacroParameter PEGTransformerFactory::TransformSimpleParameter(PEGTransformer &transformer, const Identifier &type_func_name, - const LogicalType &type) { + const optional &type) { MacroParameter result; result.name = Identifier(type_func_name); result.expression = make_uniq(Identifier(type_func_name)); - if (type.id() != LogicalTypeId::INVALID) { - result.type = type; + if (type) { + result.type = *type; } result.is_default = false; return result; diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_schema.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_schema.cpp index 7063518e3..5acff3941 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_schema.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_schema.cpp @@ -3,7 +3,7 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformCreateSchemaStmt(PEGTransformer &transformer, - const bool &if_not_exists, + const optional &if_not_exists, const QualifiedName &qualified_name) { if (!qualified_name.catalog.empty()) { throw ParserException("CREATE SCHEMA too many dots: expected \"catalog.schema\" or \"schema\""); diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_secret.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_secret.cpp index 2e4da4ddd..79e3a0843 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_secret.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_secret.cpp @@ -14,16 +14,16 @@ Value PEGTransformerFactory::GetConstantExpressionValue(unique_ptr PEGTransformerFactory::TransformCreateSecretStmt( - PEGTransformer &transformer, const bool &if_not_exists, const Identifier &secret_name, - const Identifier &secret_storage_specifier, const vector &generic_copy_option_list) { + PEGTransformer &transformer, const optional &if_not_exists, const optional &secret_name, + const optional &secret_storage_specifier, const vector &generic_copy_option_list) { auto result = make_uniq(); auto on_conflict = if_not_exists ? OnCreateConflict::IGNORE_ON_CONFLICT : OnCreateConflict::ERROR_ON_CONFLICT; auto info = make_uniq(on_conflict, SecretPersistType::DEFAULT); - if (!secret_name.empty()) { - info->name = secret_name; + if (secret_name) { + info->name = *secret_name; } - if (!secret_storage_specifier.empty()) { - info->storage_type = Identifier(StringUtil::Lower(secret_storage_specifier.GetIdentifierName())); + if (secret_storage_specifier) { + info->storage_type = Identifier(StringUtil::Lower(secret_storage_specifier->GetIdentifierName())); } for (const auto &option : generic_copy_option_list) { auto lower_name = StringUtil::Lower(option.name.GetIdentifierName()); diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_sequence.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_sequence.cpp index a779e4b9b..3e0dab0c5 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_sequence.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_sequence.cpp @@ -4,10 +4,9 @@ namespace duckdb { -unique_ptr -PEGTransformerFactory::TransformCreateSequenceStmt(PEGTransformer &transformer, const bool &if_not_exists, - const QualifiedName &qualified_name, - vector>> sequence_option) { +unique_ptr PEGTransformerFactory::TransformCreateSequenceStmt( + PEGTransformer &transformer, const optional &if_not_exists, const QualifiedName &qualified_name, + optional>>> sequence_option) { auto result = make_uniq(); auto info = make_uniq(); info->catalog = qualified_name.catalog; @@ -15,13 +14,15 @@ PEGTransformerFactory::TransformCreateSequenceStmt(PEGTransformer &transformer, info->name = qualified_name.name; info->on_conflict = if_not_exists ? OnCreateConflict::IGNORE_ON_CONFLICT : OnCreateConflict::ERROR_ON_CONFLICT; case_insensitive_map_t> sequence_options; - for (auto &seq_option : sequence_option) { - if (sequence_options.find(seq_option.first) != sequence_options.end()) { - auto seq_option_capital = StringUtil::Lower(seq_option.first); - seq_option_capital[0] = StringUtil::CharacterToUpper(seq_option_capital[0]); - throw ParserException("%s should be passed at most once", seq_option_capital); + if (sequence_option) { + for (auto &seq_option : *sequence_option) { + if (sequence_options.find(seq_option.first) != sequence_options.end()) { + auto seq_option_capital = StringUtil::Lower(seq_option.first); + seq_option_capital[0] = StringUtil::CharacterToUpper(seq_option_capital[0]); + throw ParserException("%s should be passed at most once", seq_option_capital); + } + sequence_options.insert(std::move(seq_option)); } - sequence_options.insert(std::move(seq_option)); } bool no_min = false; bool no_max = false; @@ -120,7 +121,8 @@ pair> PEGTransformerFactory::TransformSeqNoCy } pair> -PEGTransformerFactory::TransformSeqSetIncrement(PEGTransformer &transformer, unique_ptr expression) { +PEGTransformerFactory::TransformSeqSetIncrement(PEGTransformer &transformer, const bool &has_result, + unique_ptr expression) { if (expression->GetExpressionClass() == ExpressionClass::FUNCTION) { auto func_expr = unique_ptr_cast(std::move(expression)); if (func_expr->FunctionName() != "-") { @@ -173,7 +175,8 @@ pair> PEGTransformerFactory::TransformSeqNoMi } pair> -PEGTransformerFactory::TransformSeqStartWith(PEGTransformer &transformer, unique_ptr expression) { +PEGTransformerFactory::TransformSeqStartWith(PEGTransformer &transformer, const bool &has_result, + unique_ptr expression) { if (expression->GetExpressionClass() != ExpressionClass::CONSTANT) { throw ParserException("Expected constant expression."); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_table.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_table.cpp index 0b3ac06fb..54ca47161 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_table.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_table.cpp @@ -22,8 +22,8 @@ namespace duckdb { unique_ptr -PEGTransformerFactory::TransformCreateStatement(PEGTransformer &transformer, const bool &or_replace, - const SecretPersistType &temporary, +PEGTransformerFactory::TransformCreateStatement(PEGTransformer &transformer, const optional &or_replace, + const optional &temporary, unique_ptr create_statement_variation) { auto result = std::move(create_statement_variation); auto &conflict_policy = result->info->on_conflict; @@ -35,9 +35,9 @@ PEGTransformerFactory::TransformCreateStatement(PEGTransformer &transformer, con } if (result->info->type == CatalogType::SECRET_ENTRY) { auto &secret_info = result->info->Cast(); - secret_info.persist_type = temporary; + secret_info.persist_type = temporary ? *temporary : SecretPersistType::DEFAULT; } - result->info->temporary = temporary == SecretPersistType::TEMPORARY; + result->info->temporary = temporary && *temporary == SecretPersistType::TEMPORARY; return std::move(result); } @@ -54,8 +54,8 @@ SecretPersistType PEGTransformerFactory::TransformTemporaryPersistent(PEGTransfo } unique_ptr PEGTransformerFactory::TransformCreateTableStmt( - PEGTransformer &transformer, const bool &if_not_exists, const QualifiedName &qualified_name, - CreateTableDefinition create_table_definition, const bool &commit_action) { + PEGTransformer &transformer, const optional &if_not_exists, const QualifiedName &qualified_name, + CreateTableDefinition create_table_definition, const optional &commit_action) { auto result = make_uniq(); if (qualified_name.name.empty()) { throw ParserException("Empty table name not supported"); @@ -76,20 +76,26 @@ unique_ptr PEGTransformerFactory::TransformCreateTableStmt( } CreateTableDefinition -PEGTransformerFactory::TransformCreateTableAs(PEGTransformer &transformer, ColumnList identifier_list, - PartitionSortedOptions partition_sorted_options, - case_insensitive_map_t> with_list, - unique_ptr statement, const bool &with_data) { +PEGTransformerFactory::TransformCreateTableAs(PEGTransformer &transformer, optional identifier_list, + optional partition_sorted_options, + optional>> with_list, + unique_ptr statement, const optional &with_data) { CreateTableDefinition result; - result.columns = std::move(identifier_list); - result.partition_keys = std::move(partition_sorted_options.partition_keys); - result.sort_keys = std::move(partition_sorted_options.sort_keys); - result.options = std::move(with_list); + if (identifier_list) { + result.columns = std::move(*identifier_list); + } + if (partition_sorted_options) { + result.partition_keys = std::move(partition_sorted_options->partition_keys); + result.sort_keys = std::move(partition_sorted_options->sort_keys); + } + if (with_list) { + result.options = std::move(*with_list); + } if (statement->type != StatementType::SELECT_STATEMENT) { throw ParserException("CREATE TABLE AS requires a SELECT clause"); } result.select_statement = unique_ptr_cast(std::move(statement)); - if (with_data) { + if (with_data && *with_data) { auto limit_modifier = make_uniq(); limit_modifier->limit = make_uniq(0); result.select_statement->node->modifiers.push_back(std::move(limit_modifier)); @@ -106,19 +112,23 @@ ColumnList PEGTransformerFactory::TransformIdentifierList(PEGTransformer &transf return result; } -CreateTableDefinition -PEGTransformerFactory::TransformCreateColumnList(PEGTransformer &transformer, ColumnElements create_table_column_list, - PartitionSortedOptions partition_sorted_options, - case_insensitive_map_t> with_list) { - if (create_table_column_list.columns.empty()) { +CreateTableDefinition PEGTransformerFactory::TransformCreateColumnList( + PEGTransformer &transformer, optional create_table_column_list, + optional partition_sorted_options, + optional>> with_list) { + if (!create_table_column_list || create_table_column_list->columns.empty()) { throw ParserException("Table must have at least one column!"); } CreateTableDefinition result; - result.columns = std::move(create_table_column_list.columns); - result.constraints = std::move(create_table_column_list.constraints); - result.partition_keys = std::move(partition_sorted_options.partition_keys); - result.sort_keys = std::move(partition_sorted_options.sort_keys); - result.options = std::move(with_list); + result.columns = std::move(create_table_column_list->columns); + result.constraints = std::move(create_table_column_list->constraints); + if (partition_sorted_options) { + result.partition_keys = std::move(partition_sorted_options->partition_keys); + result.sort_keys = std::move(partition_sorted_options->sort_keys); + } + if (with_list) { + result.options = std::move(*with_list); + } return result; } @@ -205,74 +215,80 @@ string PEGTransformerFactory::TransformIdentifier(PEGTransformer &transformer, P vector PEGTransformerFactory::TransformDottedIdentifier(PEGTransformer &transformer, const Identifier &identifier, - const vector &dot_col_label) { + const optional> &dot_col_label) { vector parts {identifier.GetIdentifierName()}; - parts.insert(parts.end(), dot_col_label.begin(), dot_col_label.end()); + if (dot_col_label) { + parts.insert(parts.end(), dot_col_label->begin(), dot_col_label->end()); + } return parts; } -ConstraintColumnDefinition -PEGTransformerFactory::TransformColumnDefinition(PEGTransformer &transformer, const vector &dotted_identifier, - const LogicalType &type, GeneratedColumnDefinition generated_column, - vector column_constraint) { +ConstraintColumnDefinition PEGTransformerFactory::TransformColumnDefinition( + PEGTransformer &transformer, const vector &dotted_identifier, const optional &type, + optional generated_column, const bool &has_result, + optional> column_constraint) { auto qualified_name = StringToQualifiedName(dotted_identifier); - bool has_type = type != LogicalType::INVALID; - bool has_generated = generated_column.expr != nullptr; + bool has_type = type.has_value(); + bool has_generated = generated_column && generated_column->expr != nullptr; if (!has_type && !has_generated) { throw ParserException("Column %s must have a type or be defined as a GENERATED column.", qualified_name.ToString()); } - auto column_type = has_type ? type : LogicalType::ANY; + auto column_type = has_type ? *type : LogicalType::ANY; CompressionType compression_type = CompressionType::COMPRESSION_AUTO; ColumnConstraint accumulated_constraints; - for (auto &cc_entry : column_constraint) { - if (cc_entry.constraint_name == "DefaultValue") { - if (accumulated_constraints.default_value) { - throw ParserException("Cannot define a default value twice"); - } - accumulated_constraints.default_value = std::move(cc_entry.expression); - } else if (cc_entry.constraint_name == "NotNullConstraint" || cc_entry.constraint_name == "UniqueConstraint" || - cc_entry.constraint_name == "PrimaryKeyConstraint") { - accumulated_constraints.constraint_types.push_back(cc_entry.constraint_type_info); - } else if (cc_entry.constraint_name == "ColumnCompression") { - compression_type = cc_entry.compression_type; - if (compression_type == CompressionType::COMPRESSION_AUTO) { - throw ParserException("Unrecognized option for column compression, expected none, uncompressed, rle, " - "dictionary, pfor, bitpacking, fsst, chimp, patas, zstd, alp, alprd or roaring"); - } - } else if (cc_entry.constraint_name == "ForeignKeyConstraint") { - auto &fk_constraint = cc_entry.constraint->Cast(); - fk_constraint.fk_columns.push_back(qualified_name.name); - accumulated_constraints.constraints.push_back(std::move(cc_entry.constraint)); - } else if (cc_entry.constraint_name == "ColumnCollation") { - if (has_generated) { - throw ParserException("Collations are not supported on generated columns"); - } - if (column_type.id() == LogicalTypeId::ANY) { - throw ParserException("Specify the VARCHAR type for column \"%s\" with collation.", - qualified_name.ToString()); - } else if (column_type.IsUnbound()) { - auto &expr = UnboundType::GetTypeExpression(column_type); - if (expr->GetExpressionClass() != ExpressionClass::TYPE) { - throw InternalException("Expected a type expression"); + if (column_constraint) { + for (auto &cc_entry : *column_constraint) { + if (cc_entry.constraint_name == "DefaultValue") { + if (accumulated_constraints.default_value) { + throw ParserException("Cannot define a default value twice"); } - auto &type_expr = expr->Cast(); - if (DefaultTypeGenerator::GetDefaultType(type_expr.GetTypeName()) != LogicalTypeId::VARCHAR) { - throw ParserException("Only VARCHAR columns can have collations!"); + accumulated_constraints.default_value = std::move(cc_entry.expression); + } else if (cc_entry.constraint_name == "NotNullConstraint" || + cc_entry.constraint_name == "UniqueConstraint" || + cc_entry.constraint_name == "PrimaryKeyConstraint") { + accumulated_constraints.constraint_types.push_back(cc_entry.constraint_type_info); + } else if (cc_entry.constraint_name == "ColumnCompression") { + compression_type = cc_entry.compression_type; + if (compression_type == CompressionType::COMPRESSION_AUTO) { + throw ParserException("Unrecognized option for column compression, expected none, uncompressed, " + "rle, dictionary, pfor, bitpacking, fsst, chimp, patas, zstd, alp, alprd or " + "roaring"); } + } else if (cc_entry.constraint_name == "ForeignKeyConstraint") { + auto &fk_constraint = cc_entry.constraint->Cast(); + fk_constraint.fk_columns.push_back(qualified_name.name); + accumulated_constraints.constraints.push_back(std::move(cc_entry.constraint)); + } else if (cc_entry.constraint_name == "ColumnCollation") { + if (has_generated) { + throw ParserException("Collations are not supported on generated columns"); + } + if (column_type.id() == LogicalTypeId::ANY) { + throw ParserException("Specify the VARCHAR type for column \"%s\" with collation.", + qualified_name.ToString()); + } else if (column_type.IsUnbound()) { + auto &expr = UnboundType::GetTypeExpression(column_type); + if (expr->GetExpressionClass() != ExpressionClass::TYPE) { + throw InternalException("Expected a type expression"); + } + auto &type_expr = expr->Cast(); + if (DefaultTypeGenerator::GetDefaultType(type_expr.GetTypeName()) != LogicalTypeId::VARCHAR) { + throw ParserException("Only VARCHAR columns can have collations!"); + } + } else { + throw InternalException("Expected only unbound types here"); + } + vector> type_children; + type_children.push_back(std::move(cc_entry.expression)); + column_type = + LogicalType::UNBOUND(make_uniq(Identifier("VARCHAR"), std::move(type_children))); } else { - throw InternalException("Expected only unbound types here"); + accumulated_constraints.constraints.push_back(std::move(cc_entry.constraint)); } - vector> type_children; - type_children.push_back(std::move(cc_entry.expression)); - column_type = - LogicalType::UNBOUND(make_uniq(Identifier("VARCHAR"), std::move(type_children))); - } else { - accumulated_constraints.constraints.push_back(std::move(cc_entry.constraint)); } } if (has_generated) { - auto generated = std::move(generated_column); + auto generated = std::move(*generated_column); if (generated.expr->HasSubquery()) { throw ParserException("Expression of generated column \"%s\" contains a subquery, which isn't allowed", qualified_name.name); @@ -307,8 +323,9 @@ PEGTransformerFactory::TransformColumnDefinition(PEGTransformer &transformer, co } GeneratedColumnDefinition PEGTransformerFactory::TransformGeneratedColumn(PEGTransformer &transformer, + const bool &has_result, unique_ptr expression, - const bool &generated_column_type) { + const optional &generated_column_type) { GeneratedColumnDefinition generated; generated.expr = std::move(expression); VerifyColumnRefs(*generated.expr); @@ -324,7 +341,7 @@ ColumnConstraintEntry PEGTransformerFactory::TransformDefaultValue(PEGTransforme } unique_ptr -PEGTransformerFactory::TransformTopLevelConstraint(PEGTransformer &transformer, +PEGTransformerFactory::TransformTopLevelConstraint(PEGTransformer &transformer, const bool &has_result, unique_ptr top_level_constraint_list) { return top_level_constraint_list; } @@ -386,7 +403,7 @@ ColumnConstraintEntry PEGTransformerFactory::TransformColumnCompression(PEGTrans ColumnConstraintEntry PEGTransformerFactory::TransformForeignKeyConstraint(PEGTransformer &transformer, unique_ptr base_table_name, - const vector &column_list, + const optional> &column_list, const KeyActions &key_actions) { ForeignKeyInfo fk_info; fk_info.schema = base_table_name->schema_name; @@ -395,16 +412,24 @@ ColumnConstraintEntry PEGTransformerFactory::TransformForeignKeyConstraint(PEGTr ColumnConstraintEntry entry; entry.constraint_name = "ForeignKeyConstraint"; - entry.constraint = - make_uniq(StringsToIdentifiers(column_list), vector(), fk_info); + vector columns; + if (column_list) { + columns = StringsToIdentifiers(*column_list); + } + entry.constraint = make_uniq(columns, vector(), fk_info); return entry; } -KeyActions PEGTransformerFactory::TransformKeyActions(PEGTransformer &transformer, const string &update_action, - const string &delete_action) { +KeyActions PEGTransformerFactory::TransformKeyActions(PEGTransformer &transformer, + const optional &update_action, + const optional &delete_action) { KeyActions results; - results.update_action = update_action; - results.delete_action = delete_action; + if (update_action) { + results.update_action = *update_action; + } + if (delete_action) { + results.delete_action = *delete_action; + } return results; } @@ -526,23 +551,25 @@ PEGTransformerFactory::TransformSortedOptions(PEGTransformer &transformer, return expression; } -PartitionSortedOptions -PEGTransformerFactory::TransformPartitionOptSortedOptions(PEGTransformer &transformer, - vector> partition_options, - vector> sorted_options) { +PartitionSortedOptions PEGTransformerFactory::TransformPartitionOptSortedOptions( + PEGTransformer &transformer, vector> partition_options, + optional>> sorted_options) { PartitionSortedOptions result; result.partition_keys = std::move(partition_options); - result.sort_keys = std::move(sorted_options); + if (sorted_options) { + result.sort_keys = std::move(*sorted_options); + } return result; } -PartitionSortedOptions -PEGTransformerFactory::TransformSortedOptPartitionOptions(PEGTransformer &transformer, - vector> sorted_options, - vector> partition_options) { +PartitionSortedOptions PEGTransformerFactory::TransformSortedOptPartitionOptions( + PEGTransformer &transformer, vector> sorted_options, + optional>> partition_options) { PartitionSortedOptions result; result.sort_keys = std::move(sorted_options); - result.partition_keys = std::move(partition_options); + if (partition_options) { + result.partition_keys = std::move(*partition_options); + } return result; } diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_trigger.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_trigger.cpp index dcc93b045..b9b0dd3e5 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_trigger.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_trigger.cpp @@ -27,10 +27,10 @@ static unique_ptr ExtractQueryNode(unique_ptr stmt) { } unique_ptr PEGTransformerFactory::TransformCreateTriggerStmt( - PEGTransformer &transformer, const bool &if_not_exists, const Identifier &trigger_name, + PEGTransformer &transformer, const optional &if_not_exists, const Identifier &trigger_name, const TriggerTiming &trigger_timing, const TriggerEventInfo &trigger_event, - unique_ptr base_table_name, const TriggerTableReferencingInfo &referencing_clause, - const TriggerForEach &for_each_clause, unique_ptr trigger_body) { + unique_ptr base_table_name, const optional &referencing_clause, + const optional &for_each_clause, unique_ptr trigger_body) { auto result = make_uniq(); auto info = make_uniq(); info->on_conflict = if_not_exists ? OnCreateConflict::IGNORE_ON_CONFLICT : OnCreateConflict::ERROR_ON_CONFLICT; @@ -39,9 +39,13 @@ unique_ptr PEGTransformerFactory::TransformCreateTriggerStmt( info->event_type = trigger_event.event_type; info->columns = trigger_event.columns; info->base_table = std::move(base_table_name); - info->referencing_new_table = referencing_clause.new_table; - info->referencing_old_table = referencing_clause.old_table; - info->for_each = for_each_clause; + if (referencing_clause) { + info->referencing_new_table = referencing_clause->new_table; + info->referencing_old_table = referencing_clause->old_table; + } + if (for_each_clause) { + info->for_each = *for_each_clause; + } info->trigger_action = ExtractQueryNode(std::move(trigger_body)); result->info = std::move(info); return result; @@ -111,19 +115,22 @@ TriggerTableReferencingInfo PEGTransformerFactory::TransformReferencingOldTableA TriggerTableReferencingInfo PEGTransformerFactory::TransformReferencingClause(PEGTransformer &transformer, const TriggerTableReferencingInfo &referencing_item, - const TriggerTableReferencingInfo &referencing_item_1) { + const optional &referencing_item_1) { auto result = referencing_item; - if (!referencing_item_1.new_table.empty()) { + if (!referencing_item_1) { + return result; + } + if (!referencing_item_1->new_table.empty()) { if (!result.new_table.empty()) { throw ParserException("NEW TABLE cannot be specified multiple times in REFERENCING clause"); } - result.new_table = referencing_item_1.new_table; + result.new_table = referencing_item_1->new_table; } - if (!referencing_item_1.old_table.empty()) { + if (!referencing_item_1->old_table.empty()) { if (!result.old_table.empty()) { throw ParserException("OLD TABLE cannot be specified multiple times in REFERENCING clause"); } - result.old_table = referencing_item_1.old_table; + result.old_table = referencing_item_1->old_table; } if (!result.new_table.empty() && !result.old_table.empty() && result.new_table == result.old_table) { throw ParserException("REFERENCING aliases must be distinct"); diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_type.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_type.cpp index 179c4b8f0..bdfffa0d0 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_type.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_type.cpp @@ -6,7 +6,7 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformCreateTypeStmt(PEGTransformer &transformer, - const bool &if_not_exists, + const optional &if_not_exists, const QualifiedName &qualified_name, unique_ptr create_type) { auto result = make_uniq(); @@ -35,15 +35,19 @@ PEGTransformerFactory::TransformEnumSelectType(PEGTransformer &transformer, return result; } -unique_ptr PEGTransformerFactory::TransformEnumStringLiteralList(PEGTransformer &transformer, - const vector &string_literal) { +unique_ptr +PEGTransformerFactory::TransformEnumStringLiteralList(PEGTransformer &transformer, + const optional> &string_literal) { auto result = make_uniq(); - Vector enum_vector(LogicalType::VARCHAR, string_literal.size()); - auto string_data = FlatVector::Writer(enum_vector, string_literal.size()); - for (auto &literal : string_literal) { - string_data.WriteValue(string_t(literal)); + idx_t enum_count = string_literal ? string_literal->size() : 0; + Vector enum_vector(LogicalType::VARCHAR, enum_count); + auto string_data = FlatVector::Writer(enum_vector, enum_count); + if (string_literal) { + for (auto &literal : *string_literal) { + string_data.WriteValue(string_t(literal)); + } } - result->type = LogicalType::ENUM(enum_vector, string_literal.size()); + result->type = LogicalType::ENUM(enum_vector, enum_count); return result; } diff --git a/src/duckdb/src/parser/peg/transformer/transform_create_view.cpp b/src/duckdb/src/parser/peg/transformer/transform_create_view.cpp index a281e7604..ea1c4373d 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_create_view.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_create_view.cpp @@ -87,10 +87,10 @@ void PEGTransformerFactory::ConvertToRecursiveView(unique_ptr &i } unique_ptr -PEGTransformerFactory::TransformCreateViewStmt(PEGTransformer &transformer, const bool &create_recursive, - const bool &if_not_exists, const QualifiedName &qualified_name, - const vector &insert_column_list, - case_insensitive_map_t> with_list, +PEGTransformerFactory::TransformCreateViewStmt(PEGTransformer &transformer, const optional &create_recursive, + const optional &if_not_exists, const QualifiedName &qualified_name, + const optional> &insert_column_list, + optional>> with_list, unique_ptr select_statement_internal) { auto result = make_uniq(); auto info = make_uniq(); @@ -98,9 +98,11 @@ PEGTransformerFactory::TransformCreateViewStmt(PEGTransformer &transformer, cons info->catalog = qualified_name.catalog; info->schema = qualified_name.schema; info->view_name = qualified_name.name; - info->aliases = StringsToIdentifiers(insert_column_list); - if (!with_list.empty()) { - for (auto &option_entry : with_list) { + if (insert_column_list) { + info->aliases = StringsToIdentifiers(*insert_column_list); + } + if (with_list) { + for (auto &option_entry : *with_list) { if (!StringUtil::CIEquals(option_entry.first, "defer_binding")) { throw ParserException("Only DEFER_BINDING is currently supported as option for CREATE VIEW"); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_deallocate.cpp b/src/duckdb/src/parser/peg/transformer/transform_deallocate.cpp index c1403a0a8..e93574225 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_deallocate.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_deallocate.cpp @@ -4,7 +4,7 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformDeallocateStatement(PEGTransformer &transformer, - const bool &deallocate_prepare, + const optional &deallocate_prepare, const Identifier &identifier) { auto result = make_uniq(); result->info->type = CatalogType::PREPARED_STATEMENT; diff --git a/src/duckdb/src/parser/peg/transformer/transform_delete.cpp b/src/duckdb/src/parser/peg/transformer/transform_delete.cpp index 5e0ff952f..577dfca05 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_delete.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_delete.cpp @@ -5,26 +5,34 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformDeleteStatement( - PEGTransformer &transformer, CommonTableExpressionMap with_clause, unique_ptr target_opt_alias, - vector> delete_using_clause, unique_ptr where_clause, - vector> returning_clause) { + PEGTransformer &transformer, optional with_clause, + unique_ptr target_opt_alias, optional>> delete_using_clause, + optional> where_clause, + optional>> returning_clause) { auto result = make_uniq(); auto &node = *result->node; - if (!with_clause.map.empty()) { - node.cte_map = std::move(with_clause); + if (with_clause && !with_clause->map.empty()) { + node.cte_map = std::move(*with_clause); } node.table = std::move(target_opt_alias); - node.using_clauses = std::move(delete_using_clause); - node.condition = std::move(where_clause); - node.returning_list = std::move(returning_clause); + if (delete_using_clause) { + node.using_clauses = std::move(*delete_using_clause); + } + if (where_clause) { + node.condition = std::move(*where_clause); + } + if (returning_clause) { + node.returning_list = std::move(*returning_clause); + } return std::move(result); } unique_ptr PEGTransformerFactory::TransformTargetOptAlias(PEGTransformer &transformer, unique_ptr base_table_name, - const Identifier &col_id) { - if (!col_id.empty()) { - base_table_name->alias = Identifier(col_id); + const bool &has_result, + const optional &col_id) { + if (col_id && !col_id->empty()) { + base_table_name->alias = Identifier(*col_id); } return base_table_name; } @@ -35,6 +43,7 @@ vector> PEGTransformerFactory::TransformDeleteUsingClause(P } unique_ptr PEGTransformerFactory::TransformTruncateStatement(PEGTransformer &transformer, + const bool &has_result, unique_ptr base_table_name) { auto result = make_uniq(); result->node->table = std::move(base_table_name); diff --git a/src/duckdb/src/parser/peg/transformer/transform_describe.cpp b/src/duckdb/src/parser/peg/transformer/transform_describe.cpp index 6302ea496..570cecad6 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_describe.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_describe.cpp @@ -57,17 +57,21 @@ unique_ptr PEGTransformerFactory::TransformShowAllTables(PEGTransform unique_ptr PEGTransformerFactory::TransformShowQualifiedName(PEGTransformer &transformer, const ShowType &show_or_describe_or_summarize, - DescribeTarget describe_target) { + optional describe_target) { auto showref = make_uniq(); showref->show_type = show_or_describe_or_summarize; + DescribeTarget target; + if (describe_target) { + target = std::move(*describe_target); + } - if (describe_target.is_table_name || describe_target.table_ref) { - if (describe_target.is_table_name) { + if (target.is_table_name || target.table_ref) { + if (target.is_table_name) { // Case: SHOW 'something' or DESCRIBE 'something' - showref->table_name = describe_target.table_name; + showref->table_name = target.table_name; } else { // Case: A relation/table reference - auto &base_table = *describe_target.table_ref; + auto &base_table = *target.table_ref; if (showref->show_type == ShowType::SHOW_FROM) { // Logic for SHOW TABLES FROM [database].[schema] @@ -90,14 +94,14 @@ unique_ptr PEGTransformerFactory::TransformShowQualifiedName(PEGTrans if (showref->table_name.empty() && showref->show_type != ShowType::SHOW_FROM) { auto show_select_node = make_uniq(); show_select_node->select_list.push_back(make_uniq()); - if (describe_target.is_table_name) { + if (target.is_table_name) { // Case: SHOW 'something' or DESCRIBE 'something' auto table_ref = make_uniq(); - table_ref->table_name = describe_target.table_name; + table_ref->table_name = target.table_name; show_select_node->from_table = std::move(table_ref); } else { // Case: A relation/table reference - show_select_node->from_table = std::move(describe_target.table_ref); + show_select_node->from_table = std::move(target.table_ref); } showref->query = std::move(show_select_node); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_detach.cpp b/src/duckdb/src/parser/peg/transformer/transform_detach.cpp index 6ada2bd52..24db3ac84 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_detach.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_detach.cpp @@ -4,7 +4,8 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformDetachStatement(PEGTransformer &transformer, - const bool &if_exists, + const bool &has_result, + const optional &if_exists, const Identifier &catalog_name) { auto result = make_uniq(); auto info = make_uniq(); diff --git a/src/duckdb/src/parser/peg/transformer/transform_drop.cpp b/src/duckdb/src/parser/peg/transformer/transform_drop.cpp index d835c7cc5..050d3a276 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_drop.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_drop.cpp @@ -6,14 +6,16 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformDropStatement(PEGTransformer &transformer, unique_ptr drop_entries, - const bool &drop_behavior) { - drop_entries->info->cascade = drop_behavior; + const optional &drop_behavior) { + if (drop_behavior) { + drop_entries->info->cascade = *drop_behavior; + } return std::move(drop_entries); } unique_ptr PEGTransformerFactory::TransformDropTable(PEGTransformer &transformer, const CatalogType &table_or_view, - const bool &if_exists, + const optional &if_exists, vector> base_table_name) { auto result = make_uniq(); auto info = make_uniq(); @@ -44,7 +46,7 @@ bool PEGTransformerFactory::TransformFunctionTypeFunction(PEGTransformer &transf unique_ptr PEGTransformerFactory::TransformDropTableFunction(PEGTransformer &transformer, const CatalogType &comment_macro_table, - const bool &if_exists, + const optional &if_exists, const vector &table_function_name) { auto result = make_uniq(); auto info = make_uniq(); @@ -62,7 +64,8 @@ PEGTransformerFactory::TransformDropTableFunction(PEGTransformer &transformer, c unique_ptr PEGTransformerFactory::TransformDropFunction(PEGTransformer &transformer, const bool &function_type_macro, - const bool &if_exists, const vector &function_identifier) { + const optional &if_exists, + const vector &function_identifier) { auto result = make_uniq(); auto info = make_uniq(); auto catalog_type = CatalogType::MACRO_ENTRY; @@ -80,7 +83,7 @@ PEGTransformerFactory::TransformDropFunction(PEGTransformer &transformer, const } unique_ptr -PEGTransformerFactory::TransformDropSchema(PEGTransformer &transformer, const bool &if_exists, +PEGTransformerFactory::TransformDropSchema(PEGTransformer &transformer, const optional &if_exists, const vector &qualified_schema_name) { auto result = make_uniq(); auto info = make_uniq(); @@ -113,7 +116,8 @@ QualifiedName PEGTransformerFactory::TransformCatalogReservedSchema(PEGTransform return result; } -unique_ptr PEGTransformerFactory::TransformDropIndex(PEGTransformer &transformer, const bool &if_exists, +unique_ptr PEGTransformerFactory::TransformDropIndex(PEGTransformer &transformer, + const optional &if_exists, const vector &qualified_index_name) { auto result = make_uniq(); auto info = make_uniq(); @@ -160,7 +164,7 @@ QualifiedName PEGTransformerFactory::TransformCatalogReservedSchemaIndex( } unique_ptr -PEGTransformerFactory::TransformDropSequence(PEGTransformer &transformer, const bool &if_exists, +PEGTransformerFactory::TransformDropSequence(PEGTransformer &transformer, const optional &if_exists, const vector &qualified_sequence_name) { auto result = make_uniq(); auto info = make_uniq(); @@ -186,7 +190,7 @@ Identifier PEGTransformerFactory::TransformCollationName(PEGTransformer &transfo } unique_ptr PEGTransformerFactory::TransformDropCollation(PEGTransformer &transformer, - const bool &if_exists, + const optional &if_exists, const vector &collation_name) { throw NotImplementedException("Cannot drop collation yet"); /* @@ -206,7 +210,8 @@ unique_ptr PEGTransformerFactory::TransformDropCollation(PEGTrans */ } -unique_ptr PEGTransformerFactory::TransformDropType(PEGTransformer &transformer, const bool &if_exists, +unique_ptr PEGTransformerFactory::TransformDropType(PEGTransformer &transformer, + const optional &if_exists, const vector &qualified_type_name) { auto result = make_uniq(); auto info = make_uniq(); @@ -240,19 +245,23 @@ bool PEGTransformerFactory::TransformIfExists(PEGTransformer &transformer) { } unique_ptr PEGTransformerFactory::TransformDropSecret(PEGTransformer &transformer, - const SecretPersistType &temporary, - const bool &if_exists, + const optional &temporary, + const optional &if_exists, const Identifier &secret_name, - const Identifier &drop_secret_storage) { + const optional &drop_secret_storage) { auto result = make_uniq(); auto info = make_uniq(); info->type = CatalogType::SECRET_ENTRY; auto extra_drop_info = make_uniq(); - extra_drop_info->persist_mode = temporary; + if (temporary) { + extra_drop_info->persist_mode = *temporary; + } info->if_not_found = if_exists ? OnEntryNotFound::RETURN_NULL : OnEntryNotFound::THROW_EXCEPTION; info->name = secret_name; - extra_drop_info->secret_storage = drop_secret_storage.GetIdentifierName(); + if (drop_secret_storage) { + extra_drop_info->secret_storage = drop_secret_storage->GetIdentifierName(); + } info->extra_drop_info = std::move(extra_drop_info); result->info = std::move(info); return result; @@ -264,7 +273,7 @@ Identifier PEGTransformerFactory::TransformDropSecretStorage(PEGTransformer &tra } unique_ptr PEGTransformerFactory::TransformDropTrigger(PEGTransformer &transformer, - const bool &if_exists, + const optional &if_exists, const Identifier &trigger_name, unique_ptr base_table_name) { auto result = make_uniq(); diff --git a/src/duckdb/src/parser/peg/transformer/transform_execute.cpp b/src/duckdb/src/parser/peg/transformer/transform_execute.cpp index 1275e60fc..bd012cdf8 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_execute.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_execute.cpp @@ -5,23 +5,24 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformExecuteStatement(PEGTransformer &transformer, const Identifier &identifier, - vector table_function_arguments) { + optional> table_function_arguments) { auto result = make_uniq(); result->name = identifier; - if (table_function_arguments.empty()) { + if (!table_function_arguments) { return std::move(result); } idx_t param_idx = 0; - for (idx_t i = 0; i < table_function_arguments.size(); i++) { - auto &arg = table_function_arguments[i]; - if (!table_function_arguments[i].GetExpression().IsScalar()) { + auto &arguments = *table_function_arguments; + for (idx_t i = 0; i < arguments.size(); i++) { + auto &arg = arguments[i]; + if (!arguments[i].GetExpression().IsScalar()) { throw InvalidInputException("Only scalar parameters, named parameters or NULL supported for EXECUTE"); } - if (!table_function_arguments[i].GetName().empty() && param_idx != 0) { + if (!arguments[i].GetName().empty() && param_idx != 0) { throw NotImplementedException("Mixing named parameters and positional parameters is not supported yet"); } auto param_name = arg.GetName(); - if (table_function_arguments[i].GetName().empty()) { + if (arguments[i].GetName().empty()) { param_name = Identifier(std::to_string(param_idx + 1)); if (param_idx != i) { throw NotImplementedException("Mixing named parameters and positional parameters is not supported yet"); diff --git a/src/duckdb/src/parser/peg/transformer/transform_explain.cpp b/src/duckdb/src/parser/peg/transformer/transform_explain.cpp index cd38816bf..4c28a5bfb 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_explain.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_explain.cpp @@ -14,14 +14,14 @@ ProfilerPrintFormat ParseProfilerPrintFormat(const Value &val) { } unique_ptr -PEGTransformerFactory::TransformExplainStatement(PEGTransformer &transformer, const bool &explain_analyze, - const vector &explain_option_list, +PEGTransformerFactory::TransformExplainStatement(PEGTransformer &transformer, const optional &explain_analyze, + const optional> &explain_option_list, unique_ptr explainable_statements) { auto explain_type = explain_analyze ? ExplainType::EXPLAIN_ANALYZE : ExplainType::EXPLAIN_STANDARD; bool format_is_set = false; auto format = ProfilerPrintFormat::Default(); - if (!explain_option_list.empty()) { - for (auto option : explain_option_list) { + if (explain_option_list) { + for (auto option : *explain_option_list) { auto option_name = StringUtil::Lower(option.name.GetIdentifierName()); if (option_name == "format") { if (format_is_set) { @@ -58,18 +58,19 @@ PEGTransformerFactory::TransformExplainOptionList(PEGTransformer &transformer, GenericCopyOption PEGTransformerFactory::TransformExplainOption(PEGTransformer &transformer, const Identifier &explain_option_name, - unique_ptr expression) { + optional> expression) { GenericCopyOption copy_option; copy_option.name = Identifier(StringUtil::Lower(explain_option_name.GetIdentifierName())); if (!expression) { return copy_option; } - if (expression->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { - copy_option.children.push_back(Value(expression->Cast().GetValue())); - } else if (expression->GetExpressionType() == ExpressionType::COLUMN_REF) { - copy_option.children.push_back(Value(expression->Cast().GetColumnName())); + auto &expr = *expression; + if (expr->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { + copy_option.children.push_back(Value(expr->Cast().GetValue())); + } else if (expr->GetExpressionType() == ExpressionType::COLUMN_REF) { + copy_option.children.push_back(Value(expr->Cast().GetColumnName())); } else { - copy_option.expression = std::move(expression); + copy_option.expression = std::move(expr); } return copy_option; } diff --git a/src/duckdb/src/parser/peg/transformer/transform_export.cpp b/src/duckdb/src/parser/peg/transformer/transform_export.cpp index 0f7e53ea2..1d718255a 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_export.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_export.cpp @@ -6,27 +6,31 @@ namespace duckdb { unique_ptr -PEGTransformerFactory::TransformExportStatement(PEGTransformer &transformer, const string &export_source, +PEGTransformerFactory::TransformExportStatement(PEGTransformer &transformer, const optional &export_source, const string &string_literal, - const vector &generic_copy_option_list) { + const optional> &generic_copy_option_list) { auto info = make_uniq(); info->file_path = string_literal; info->format = "csv"; info->is_from = false; - for (const auto &option : generic_copy_option_list) { - if (option.name == "format") { - info->format = option.children[0].GetValue(); - info->is_format_auto_detected = false; - } else if (option.expression) { - info->parsed_options[StringUtil::Upper(option.name.GetIdentifierName())] = option.expression->Copy(); - } else { - info->options[StringUtil::Upper(option.name.GetIdentifierName())] = option.children; + if (generic_copy_option_list) { + for (const auto &option : *generic_copy_option_list) { + if (option.name == "format") { + info->format = option.children[0].GetValue(); + info->is_format_auto_detected = false; + } else if (option.expression) { + info->parsed_options[StringUtil::Upper(option.name.GetIdentifierName())] = option.expression->Copy(); + } else { + info->options[StringUtil::Upper(option.name.GetIdentifierName())] = option.children; + } } } auto result = make_uniq(std::move(info)); - result->database = export_source; + if (export_source) { + result->database = *export_source; + } return std::move(result); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_generated.cpp b/src/duckdb/src/parser/peg/transformer/transform_generated.cpp index 4c030764d..ddd7bc131 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_generated.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_generated.cpp @@ -22,7 +22,7 @@ unique_ptr PEGTransformerFactory::TransformAlterOptionsInt unique_ptr PEGTransformerFactory::TransformAlterTableStmtInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(1).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); @@ -44,7 +44,7 @@ unique_ptr PEGTransformerFactory::TransformAlterTableStmtI unique_ptr PEGTransformerFactory::TransformAlterSchemaStmtInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(1).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); @@ -75,14 +75,17 @@ unique_ptr PEGTransformerFactory::TransformAddConstraintIn unique_ptr PEGTransformerFactory::TransformAddColumnInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_not_exists {}; + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); + optional if_not_exists {}; auto &if_not_exists_opt = list_pr.GetChild(2).Cast(); if (if_not_exists_opt.HasResult()) { auto if_not_exists_value = transformer.Transform(if_not_exists_opt.GetResult()); if_not_exists = if_not_exists_value; } auto add_column_entry = transformer.Transform(list_pr.GetChild(3)); - auto result = TransformAddColumn(transformer, if_not_exists, std::move(add_column_entry)); + auto result = TransformAddColumn(transformer, has_result, if_not_exists, std::move(add_column_entry)); return make_uniq>>(std::move(result)); } @@ -90,20 +93,20 @@ unique_ptr PEGTransformerFactory::TransformAddColumnEntryI ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto dotted_identifier = transformer.Transform>(list_pr.GetChild(0)); - LogicalType type {}; + optional type {}; auto &type_opt = list_pr.GetChild(1).Cast(); if (type_opt.HasResult()) { auto type_value = transformer.Transform(type_opt.GetResult()); type = type_value; } - GeneratedColumnDefinition generated_column {}; + optional generated_column {}; auto &generated_column_opt = list_pr.GetChild(2).Cast(); if (generated_column_opt.HasResult()) { auto generated_column_value = transformer.Transform(generated_column_opt.GetResult()); generated_column = std::move(generated_column_value); } - vector column_constraint {}; + optional> column_constraint {}; auto &column_constraint_opt = list_pr.GetChild(3).Cast(); if (column_constraint_opt.HasResult()) { vector column_constraint_value; @@ -123,45 +126,55 @@ unique_ptr PEGTransformerFactory::TransformAddColumnEntryI unique_ptr PEGTransformerFactory::TransformDropColumnInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_exists {}; + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(2).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); if_exists = if_exists_value; } auto nested_column_name = transformer.Transform>(list_pr.GetChild(3)); - bool drop_behavior {}; + optional drop_behavior {}; auto &drop_behavior_opt = list_pr.GetChild(4).Cast(); if (drop_behavior_opt.HasResult()) { auto drop_behavior_value = transformer.Transform(drop_behavior_opt.GetResult()); drop_behavior = drop_behavior_value; } - auto result = TransformDropColumn(transformer, if_exists, std::move(nested_column_name), drop_behavior); + auto result = TransformDropColumn(transformer, has_result, if_exists, std::move(nested_column_name), drop_behavior); return make_uniq>>(std::move(result)); } unique_ptr PEGTransformerFactory::TransformAlterColumnInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); auto nested_column_name = transformer.Transform>(list_pr.GetChild(2)); auto alter_column_entry = transformer.Transform>(list_pr.GetChild(3)); - auto result = TransformAlterColumn(transformer, std::move(nested_column_name), std::move(alter_column_entry)); + auto result = + TransformAlterColumn(transformer, has_result, std::move(nested_column_name), std::move(alter_column_entry)); return make_uniq>>(std::move(result)); } unique_ptr PEGTransformerFactory::TransformRenameColumnInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); auto nested_column_name = transformer.Transform>(list_pr.GetChild(2)); auto identifier = list_pr.GetChild(4).Cast().identifier; - auto result = TransformRenameColumn(transformer, std::move(nested_column_name), identifier); + auto result = TransformRenameColumn(transformer, has_result, std::move(nested_column_name), identifier); return make_uniq>>(std::move(result)); } unique_ptr PEGTransformerFactory::TransformNestedColumnNameInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - vector identifier_dot {}; + optional> identifier_dot {}; auto &identifier_dot_opt = list_pr.GetChild(0).Cast(); if (identifier_dot_opt.HasResult()) { vector identifier_dot_value; @@ -306,20 +319,23 @@ unique_ptr PEGTransformerFactory::TransformSetNullabilityI unique_ptr PEGTransformerFactory::TransformAlterTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - LogicalType type {}; + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(0).Cast(); + has_result = has_result_opt.HasResult(); + optional type {}; auto &type_opt = list_pr.GetChild(2).Cast(); if (type_opt.HasResult()) { auto type_value = transformer.Transform(type_opt.GetResult()); type = type_value; } - unique_ptr using_expression {}; + optional> using_expression {}; auto &using_expression_opt = list_pr.GetChild(3).Cast(); if (using_expression_opt.HasResult()) { auto using_expression_value = transformer.Transform>(using_expression_opt.GetResult()); using_expression = std::move(using_expression_value); } - auto result = TransformAlterType(transformer, type, std::move(using_expression)); + auto result = TransformAlterType(transformer, has_result, type, std::move(using_expression)); return make_uniq>>(std::move(result)); } @@ -334,7 +350,7 @@ unique_ptr PEGTransformerFactory::TransformUsingExpression unique_ptr PEGTransformerFactory::TransformAlterViewStmtInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(1).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); @@ -349,7 +365,7 @@ unique_ptr PEGTransformerFactory::TransformAlterViewStmtIn unique_ptr PEGTransformerFactory::TransformAlterSequenceStmtInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(1).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); @@ -365,13 +381,13 @@ unique_ptr PEGTransformerFactory::TransformAlterSequenceSt unique_ptr PEGTransformerFactory::TransformQualifiedSequenceNameInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - Identifier catalog_qualification {}; + optional catalog_qualification {}; auto &catalog_qualification_opt = list_pr.GetChild(0).Cast(); if (catalog_qualification_opt.HasResult()) { auto catalog_qualification_value = transformer.Transform(catalog_qualification_opt.GetResult()); catalog_qualification = catalog_qualification_value; } - Identifier schema_qualification {}; + optional schema_qualification {}; auto &schema_qualification_opt = list_pr.GetChild(1).Cast(); if (schema_qualification_opt.HasResult()) { auto schema_qualification_value = transformer.Transform(schema_qualification_opt.GetResult()); @@ -408,7 +424,7 @@ unique_ptr PEGTransformerFactory::TransformSetSequenceOpti unique_ptr PEGTransformerFactory::TransformAlterDatabaseStmtInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(1).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); @@ -423,13 +439,13 @@ unique_ptr PEGTransformerFactory::TransformAlterDatabaseSt unique_ptr PEGTransformerFactory::TransformAnalyzeStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool analyze_verbose {}; + optional analyze_verbose {}; auto &analyze_verbose_opt = list_pr.GetChild(1).Cast(); if (analyze_verbose_opt.HasResult()) { auto analyze_verbose_value = transformer.Transform(analyze_verbose_opt.GetResult()); analyze_verbose = analyze_verbose_value; } - AnalyzeTarget analyze_target {}; + optional analyze_target {}; auto &analyze_target_opt = list_pr.GetChild(2).Cast(); if (analyze_target_opt.HasResult()) { auto analyze_target_value = transformer.Transform(analyze_target_opt.GetResult()); @@ -443,7 +459,7 @@ unique_ptr PEGTransformerFactory::TransformAnalyzeTargetIn ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto base_table_name = transformer.Transform>(list_pr.GetChild(0)); - vector name_list {}; + optional> name_list {}; auto &name_list_opt = list_pr.GetChild(1).Cast(); if (name_list_opt.HasResult()) { auto name_list_value = transformer.Transform>(name_list_opt.GetResult()); @@ -462,32 +478,35 @@ unique_ptr PEGTransformerFactory::TransformAnalyzeVerboseI unique_ptr PEGTransformerFactory::TransformAttachStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool or_replace {}; + optional or_replace {}; auto &or_replace_opt = list_pr.GetChild(1).Cast(); if (or_replace_opt.HasResult()) { auto or_replace_value = transformer.Transform(or_replace_opt.GetResult()); or_replace = or_replace_value; } - bool if_not_exists {}; + optional if_not_exists {}; auto &if_not_exists_opt = list_pr.GetChild(2).Cast(); if (if_not_exists_opt.HasResult()) { auto if_not_exists_value = transformer.Transform(if_not_exists_opt.GetResult()); if_not_exists = if_not_exists_value; } + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(3).Cast(); + has_result = has_result_opt.HasResult(); auto database_path = transformer.Transform>(list_pr.GetChild(4)); - Identifier attach_alias {}; + optional attach_alias {}; auto &attach_alias_opt = list_pr.GetChild(5).Cast(); if (attach_alias_opt.HasResult()) { auto attach_alias_value = transformer.Transform(attach_alias_opt.GetResult()); attach_alias = attach_alias_value; } - vector attach_options {}; + optional> attach_options {}; auto &attach_options_opt = list_pr.GetChild(6).Cast(); if (attach_options_opt.HasResult()) { auto attach_options_value = transformer.Transform>(attach_options_opt.GetResult()); attach_options = attach_options_value; } - auto result = TransformAttachStatement(transformer, or_replace, if_not_exists, std::move(database_path), + auto result = TransformAttachStatement(transformer, or_replace, if_not_exists, has_result, std::move(database_path), attach_alias, attach_options); return make_uniq>>(std::move(result)); } @@ -528,13 +547,13 @@ unique_ptr PEGTransformerFactory::TransformCallStatementIn unique_ptr PEGTransformerFactory::TransformCheckpointStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool checkpoint_force {}; + optional checkpoint_force {}; auto &checkpoint_force_opt = list_pr.GetChild(0).Cast(); if (checkpoint_force_opt.HasResult()) { auto checkpoint_force_value = transformer.Transform(checkpoint_force_opt.GetResult()); checkpoint_force = checkpoint_force_value; } - Identifier catalog_name {}; + optional catalog_name {}; auto &catalog_name_opt = list_pr.GetChild(2).Cast(); if (catalog_name_opt.HasResult()) { auto catalog_name_value = catalog_name_opt.GetResult().Cast().identifier; @@ -675,7 +694,7 @@ unique_ptr PEGTransformerFactory::TransformTypeInternal(PE ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto type_variations = transformer.Transform>(list_pr.GetChild(0)); - vector array_bounds {}; + optional> array_bounds {}; auto &array_bounds_opt = list_pr.GetChild(1).Cast(); if (array_bounds_opt.HasResult()) { vector array_bounds_value; @@ -709,15 +728,14 @@ unique_ptr PEGTransformerFactory::TransformSimpleTypeInter unique_ptr PEGTransformerFactory::TransformCharacterSimpleTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - auto character_type = transformer.Transform(list_pr.GetChild(0)); - vector> type_modifiers {}; + optional>> type_modifiers {}; auto &type_modifiers_opt = list_pr.GetChild(1).Cast(); if (type_modifiers_opt.HasResult()) { auto type_modifiers_value = transformer.Transform>>(type_modifiers_opt.GetResult()); type_modifiers = std::move(type_modifiers_value); } - auto result = TransformCharacterSimpleType(transformer, character_type, std::move(type_modifiers)); + auto result = TransformCharacterSimpleType(transformer, std::move(type_modifiers)); return make_uniq>>(std::move(result)); } @@ -725,7 +743,7 @@ unique_ptr PEGTransformerFactory::TransformQualifiedSimpleTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto qualified_type_name = transformer.Transform(list_pr.GetChild(0)); - vector> type_modifiers {}; + optional>> type_modifiers {}; auto &type_modifiers_opt = list_pr.GetChild(1).Cast(); if (type_modifiers_opt.HasResult()) { auto type_modifiers_value = @@ -736,12 +754,6 @@ PEGTransformerFactory::TransformQualifiedSimpleTypeInternal(PEGTransformer &tran return make_uniq>>(std::move(result)); } -unique_ptr PEGTransformerFactory::TransformCharacterTypeInternal(PEGTransformer &transformer, - ParseResult &parse_result) { - auto result = TransformCharacterType(transformer); - return make_uniq>(result); -} - unique_ptr PEGTransformerFactory::TransformIntervalTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); @@ -951,7 +963,10 @@ unique_ptr PEGTransformerFactory::TransformMinuteToSecondI unique_ptr PEGTransformerFactory::TransformBitTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - vector> expression {}; + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); + optional>> expression {}; auto &expression_opt = list_pr.GetChild(2).Cast(); if (expression_opt.HasResult()) { vector> expression_value; @@ -964,14 +979,14 @@ unique_ptr PEGTransformerFactory::TransformBitTypeInternal } expression = std::move(expression_value); } - auto result = TransformBitType(transformer, std::move(expression)); + auto result = TransformBitType(transformer, has_result, std::move(expression)); return make_uniq>>(std::move(result)); } unique_ptr PEGTransformerFactory::TransformGeometryTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - unique_ptr expression {}; + optional> expression {}; auto &expression_opt = list_pr.GetChild(1).Cast(); if (expression_opt.HasResult()) { auto expression_value = @@ -1068,7 +1083,7 @@ unique_ptr PEGTransformerFactory::TransformDoubleTypeInter unique_ptr PEGTransformerFactory::TransformFloatTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - unique_ptr number_literal {}; + optional> number_literal {}; auto &number_literal_opt = list_pr.GetChild(1).Cast(); if (number_literal_opt.HasResult()) { auto number_literal_value = transformer.Transform>( @@ -1082,7 +1097,7 @@ unique_ptr PEGTransformerFactory::TransformFloatTypeIntern unique_ptr PEGTransformerFactory::TransformDecimalTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - vector> type_modifiers {}; + optional>> type_modifiers {}; auto &type_modifiers_opt = list_pr.GetChild(1).Cast(); if (type_modifiers_opt.HasResult()) { auto type_modifiers_value = @@ -1096,7 +1111,7 @@ unique_ptr PEGTransformerFactory::TransformDecimalTypeInte unique_ptr PEGTransformerFactory::TransformDecTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - vector> type_modifiers {}; + optional>> type_modifiers {}; auto &type_modifiers_opt = list_pr.GetChild(1).Cast(); if (type_modifiers_opt.HasResult()) { auto type_modifiers_value = @@ -1110,7 +1125,7 @@ unique_ptr PEGTransformerFactory::TransformDecTypeInternal unique_ptr PEGTransformerFactory::TransformNumericModTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - vector> type_modifiers {}; + optional>> type_modifiers {}; auto &type_modifiers_opt = list_pr.GetChild(1).Cast(); if (type_modifiers_opt.HasResult()) { auto type_modifiers_value = @@ -1162,7 +1177,7 @@ PEGTransformerFactory::TransformSchemaReservedTypeNameInternal(PEGTransformer &t unique_ptr PEGTransformerFactory::TransformTypeModifiersInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - vector> expression {}; + optional>> expression {}; auto &expression_opt = ExtractResultFromParens(list_pr.GetChild(0)).Cast(); if (expression_opt.HasResult()) { vector> expression_value; @@ -1181,7 +1196,13 @@ unique_ptr PEGTransformerFactory::TransformTypeModifiersIn unique_ptr PEGTransformerFactory::TransformRowTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - auto col_id_type_list = transformer.Transform>(list_pr.GetChild(1)); + child_list_t col_id_type_list {}; + auto &col_id_type_list_opt = list_pr.GetChild(1).Cast(); + if (col_id_type_list_opt.HasResult()) { + auto col_id_type_list_value = + transformer.Transform>(col_id_type_list_opt.GetResult()); + col_id_type_list = col_id_type_list_value; + } auto result = TransformRowType(transformer, col_id_type_list); return make_uniq>>(std::move(result)); } @@ -1254,7 +1275,7 @@ unique_ptr PEGTransformerFactory::TransformArrayKeywordInt unique_ptr PEGTransformerFactory::TransformSquareBracketsArrayInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - unique_ptr expression {}; + optional> expression {}; auto &expression_opt = list_pr.GetChild(1).Cast(); if (expression_opt.HasResult()) { auto expression_value = transformer.Transform>(expression_opt.GetResult()); @@ -1268,14 +1289,14 @@ unique_ptr PEGTransformerFactory::TransformTimeTypeInterna ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto time_or_timestamp = transformer.Transform(list_pr.GetChild(0)); - vector> type_modifiers {}; + optional>> type_modifiers {}; auto &type_modifiers_opt = list_pr.GetChild(1).Cast(); if (type_modifiers_opt.HasResult()) { auto type_modifiers_value = transformer.Transform>>(type_modifiers_opt.GetResult()); type_modifiers = std::move(type_modifiers_value); } - bool time_zone {}; + optional time_zone {}; auto &time_zone_opt = list_pr.GetChild(2).Cast(); if (time_zone_opt.HasResult()) { auto time_zone_value = transformer.Transform(time_zone_opt.GetResult()); @@ -1333,12 +1354,55 @@ unique_ptr PEGTransformerFactory::TransformWithoutRuleInte return make_uniq>(result); } +unique_ptr PEGTransformerFactory::TransformConnectStatementInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + optional> session_target {}; + auto &session_target_opt = list_pr.GetChild(1).Cast(); + if (session_target_opt.HasResult()) { + auto session_target_value = transformer.Transform>(session_target_opt.GetResult()); + session_target = std::move(session_target_value); + } + auto result = TransformConnectStatement(transformer, std::move(session_target)); + return make_uniq>>(std::move(result)); +} + unique_ptr PEGTransformerFactory::TransformDisconnectStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto result = TransformDisconnectStatement(transformer); return make_uniq>>(std::move(result)); } +unique_ptr PEGTransformerFactory::TransformSessionTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto &choice_pr = list_pr.Child(0); + auto result = transformer.Transform>(choice_pr.GetResult()); + return make_uniq>>(std::move(result)); +} + +unique_ptr PEGTransformerFactory::TransformLocalSessionTargetInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto result = TransformLocalSessionTarget(transformer); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformStringSessionTargetInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto string_literal = transformer.Transform(list_pr.GetChild(0)); + auto result = TransformStringSessionTarget(transformer, string_literal); + return make_uniq>>(std::move(result)); +} + +unique_ptr +PEGTransformerFactory::TransformCatalogSessionTargetInternal(PEGTransformer &transformer, ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto catalog_name = list_pr.GetChild(0).Cast().identifier; + auto result = TransformCatalogSessionTarget(transformer, catalog_name); + return make_uniq>>(std::move(result)); +} + unique_ptr PEGTransformerFactory::TransformCopyStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); @@ -1359,7 +1423,7 @@ unique_ptr PEGTransformerFactory::TransformCopyTableIntern ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto base_table_name = transformer.Transform>(list_pr.GetChild(0)); - vector insert_column_list {}; + optional> insert_column_list {}; auto &insert_column_list_opt = list_pr.GetChild(1).Cast(); if (insert_column_list_opt.HasResult()) { auto insert_column_list_value = transformer.Transform>(insert_column_list_opt.GetResult()); @@ -1367,7 +1431,7 @@ unique_ptr PEGTransformerFactory::TransformCopyTableIntern } auto from_or_to = transformer.Transform(list_pr.GetChild(2)); auto copy_file_name = transformer.Transform>(list_pr.GetChild(3)); - vector copy_options {}; + optional> copy_options {}; auto ©_options_opt = list_pr.GetChild(4).Cast(); if (copy_options_opt.HasResult()) { auto copy_options_value = transformer.Transform>(copy_options_opt.GetResult()); @@ -1404,7 +1468,7 @@ unique_ptr PEGTransformerFactory::TransformCopySelectInter auto select_statement_internal = transformer.Transform>(ExtractResultFromParens(list_pr.GetChild(0))); auto copy_file_name = transformer.Transform>(list_pr.GetChild(2)); - vector copy_options {}; + optional> copy_options {}; auto ©_options_opt = list_pr.GetChild(3).Cast(); if (copy_options_opt.HasResult()) { auto copy_options_value = transformer.Transform>(copy_options_opt.GetResult()); @@ -1469,8 +1533,11 @@ unique_ptr PEGTransformerFactory::TransformIdentifierColId unique_ptr PEGTransformerFactory::TransformCopyOptionsInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(0).Cast(); + has_result = has_result_opt.HasResult(); auto copy_option_list = transformer.Transform>(list_pr.GetChild(1)); - auto result = TransformCopyOptions(transformer, copy_option_list); + auto result = TransformCopyOptions(transformer, has_result, copy_option_list); return make_uniq>>(result); } @@ -1485,7 +1552,7 @@ unique_ptr PEGTransformerFactory::TransformCopyOptionListI unique_ptr PEGTransformerFactory::TransformSpecializedOptionListInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - vector specialized_option {}; + optional> specialized_option {}; auto &specialized_option_opt = list_pr.GetChild(0).Cast(); if (specialized_option_opt.HasResult()) { vector specialized_option_value; @@ -1550,32 +1617,44 @@ unique_ptr PEGTransformerFactory::TransformHeaderOptionInt unique_ptr PEGTransformerFactory::TransformNullAsOptionInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); auto string_literal = transformer.Transform(list_pr.GetChild(2)); - auto result = TransformNullAsOption(transformer, string_literal); + auto result = TransformNullAsOption(transformer, has_result, string_literal); return make_uniq>(result); } unique_ptr PEGTransformerFactory::TransformDelimiterAsOptionInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); auto string_literal = transformer.Transform(list_pr.GetChild(2)); - auto result = TransformDelimiterAsOption(transformer, string_literal); + auto result = TransformDelimiterAsOption(transformer, has_result, string_literal); return make_uniq>(result); } unique_ptr PEGTransformerFactory::TransformQuoteAsOptionInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); auto string_literal = transformer.Transform(list_pr.GetChild(2)); - auto result = TransformQuoteAsOption(transformer, string_literal); + auto result = TransformQuoteAsOption(transformer, has_result, string_literal); return make_uniq>(result); } unique_ptr PEGTransformerFactory::TransformEscapeAsOptionInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); auto string_literal = transformer.Transform(list_pr.GetChild(2)); - auto result = TransformEscapeAsOption(transformer, string_literal); + auto result = TransformEscapeAsOption(transformer, has_result, string_literal); return make_uniq>(result); } @@ -1590,7 +1669,7 @@ unique_ptr PEGTransformerFactory::TransformEncodingOptionI unique_ptr PEGTransformerFactory::TransformForceQuoteOptionInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool force_quote {}; + optional force_quote {}; auto &force_quote_opt = list_pr.GetChild(0).Cast(); if (force_quote_opt.HasResult()) { auto force_quote_value = transformer.Transform(force_quote_opt.GetResult()); @@ -1630,7 +1709,7 @@ unique_ptr PEGTransformerFactory::TransformPartitionByOpti unique_ptr PEGTransformerFactory::TransformForceNullOptionInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool force_not_null {}; + optional force_not_null {}; auto &force_not_null_opt = list_pr.GetChild(1).Cast(); if (force_not_null_opt.HasResult()) { auto force_not_null_value = transformer.Transform(force_not_null_opt.GetResult()); @@ -1664,7 +1743,7 @@ unique_ptr PEGTransformerFactory::TransformGenericCopyOpti ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto copy_option_name = list_pr.GetChild(0).Cast().identifier; - GenericCopyOptionValue generic_copy_option_value {}; + optional generic_copy_option_value {}; auto &generic_copy_option_value_opt = list_pr.GetChild(1).Cast(); if (generic_copy_option_value_opt.HasResult()) { auto generic_copy_option_value_value = @@ -1774,38 +1853,38 @@ unique_ptr PEGTransformerFactory::TransformCopyDataInterna unique_ptr PEGTransformerFactory::TransformCreateIndexStmtInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool unique_index {}; + optional unique_index {}; auto &unique_index_opt = list_pr.GetChild(0).Cast(); if (unique_index_opt.HasResult()) { auto unique_index_value = transformer.Transform(unique_index_opt.GetResult()); unique_index = unique_index_value; } - bool if_not_exists {}; + optional if_not_exists {}; auto &if_not_exists_opt = list_pr.GetChild(2).Cast(); if (if_not_exists_opt.HasResult()) { auto if_not_exists_value = transformer.Transform(if_not_exists_opt.GetResult()); if_not_exists = if_not_exists_value; } - Identifier index_name {}; + optional index_name {}; auto &index_name_opt = list_pr.GetChild(3).Cast(); if (index_name_opt.HasResult()) { auto index_name_value = index_name_opt.GetResult().Cast().identifier; index_name = index_name_value; } auto base_table_name = transformer.Transform>(list_pr.GetChild(5)); - vector insert_column_list {}; + optional> insert_column_list {}; auto &insert_column_list_opt = list_pr.GetChild(6).Cast(); if (insert_column_list_opt.HasResult()) { auto insert_column_list_value = transformer.Transform>(insert_column_list_opt.GetResult()); insert_column_list = insert_column_list_value; } - Identifier index_type {}; + optional index_type {}; auto &index_type_opt = list_pr.GetChild(7).Cast(); if (index_type_opt.HasResult()) { auto index_type_value = transformer.Transform(index_type_opt.GetResult()); index_type = index_type_value; } - vector> index_element {}; + optional>> index_element {}; auto &index_element_opt = list_pr.GetChild(8).Cast(); if (index_element_opt.HasResult()) { vector> index_element_value; @@ -1818,14 +1897,14 @@ unique_ptr PEGTransformerFactory::TransformCreateIndexStmt } index_element = std::move(index_element_value); } - case_insensitive_map_t> with_list {}; + optional>> with_list {}; auto &with_list_opt = list_pr.GetChild(9).Cast(); if (with_list_opt.HasResult()) { auto with_list_value = transformer.Transform>>(with_list_opt.GetResult()); with_list = std::move(with_list_value); } - unique_ptr where_clause {}; + optional> where_clause {}; auto &where_clause_opt = list_pr.GetChild(10).Cast(); if (where_clause_opt.HasResult()) { auto where_clause_value = transformer.Transform>(where_clause_opt.GetResult()); @@ -1900,13 +1979,13 @@ unique_ptr PEGTransformerFactory::TransformIndexElementInt ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto expression = transformer.Transform>(list_pr.GetChild(0)); - OrderType desc_or_asc {}; + optional desc_or_asc {}; auto &desc_or_asc_opt = list_pr.GetChild(1).Cast(); if (desc_or_asc_opt.HasResult()) { auto desc_or_asc_value = transformer.Transform(desc_or_asc_opt.GetResult()); desc_or_asc = desc_or_asc_value; } - OrderByNullType nulls_first_or_last {}; + optional nulls_first_or_last {}; auto &nulls_first_or_last_opt = list_pr.GetChild(2).Cast(); if (nulls_first_or_last_opt.HasResult()) { auto nulls_first_or_last_value = transformer.Transform(nulls_first_or_last_opt.GetResult()); @@ -1934,7 +2013,7 @@ unique_ptr PEGTransformerFactory::TransformRelOptionIntern ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto rel_option_name = transformer.Transform(list_pr.GetChild(0)); - unique_ptr rel_option_argument_opt {}; + optional> rel_option_argument_opt {}; auto &rel_option_argument_opt_opt = list_pr.GetChild(1).Cast(); if (rel_option_argument_opt_opt.HasResult()) { auto rel_option_argument_opt_value = @@ -2022,7 +2101,7 @@ unique_ptr PEGTransformerFactory::TransformCreateMacroStmt ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto macro_or_function = transformer.Transform(list_pr.GetChild(0)); - bool if_not_exists {}; + optional if_not_exists {}; auto &if_not_exists_opt = list_pr.GetChild(1).Cast(); if (if_not_exists_opt.HasResult()) { auto if_not_exists_value = transformer.Transform(if_not_exists_opt.GetResult()); @@ -2063,7 +2142,7 @@ unique_ptr PEGTransformerFactory::TransformFunctionKeyword unique_ptr PEGTransformerFactory::TransformMacroDefinitionInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - vector macro_parameters {}; + optional> macro_parameters {}; auto ¯o_parameters_opt = ExtractResultFromParens(list_pr.GetChild(0)).Cast(); if (macro_parameters_opt.HasResult()) { auto macro_parameters_value = transformer.Transform>(macro_parameters_opt.GetResult()); @@ -2107,7 +2186,7 @@ unique_ptr PEGTransformerFactory::TransformSimpleParameter ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto type_func_name = transformer.Transform(list_pr.GetChild(0)); - LogicalType type {}; + optional type {}; auto &type_opt = list_pr.GetChild(1).Cast(); if (type_opt.HasResult()) { auto type_value = transformer.Transform(type_opt.GetResult()); @@ -2136,7 +2215,7 @@ PEGTransformerFactory::TransformTableMacroDefinitionInternal(PEGTransformer &tra unique_ptr PEGTransformerFactory::TransformCreateSchemaStmtInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_not_exists {}; + optional if_not_exists {}; auto &if_not_exists_opt = list_pr.GetChild(1).Cast(); if (if_not_exists_opt.HasResult()) { auto if_not_exists_value = transformer.Transform(if_not_exists_opt.GetResult()); @@ -2150,19 +2229,19 @@ unique_ptr PEGTransformerFactory::TransformCreateSchemaStm unique_ptr PEGTransformerFactory::TransformCreateSecretStmtInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_not_exists {}; + optional if_not_exists {}; auto &if_not_exists_opt = list_pr.GetChild(1).Cast(); if (if_not_exists_opt.HasResult()) { auto if_not_exists_value = transformer.Transform(if_not_exists_opt.GetResult()); if_not_exists = if_not_exists_value; } - Identifier secret_name {}; + optional secret_name {}; auto &secret_name_opt = list_pr.GetChild(2).Cast(); if (secret_name_opt.HasResult()) { auto secret_name_value = transformer.Transform(secret_name_opt.GetResult()); secret_name = secret_name_value; } - Identifier secret_storage_specifier {}; + optional secret_storage_specifier {}; auto &secret_storage_specifier_opt = list_pr.GetChild(3).Cast(); if (secret_storage_specifier_opt.HasResult()) { auto secret_storage_specifier_value = @@ -2194,14 +2273,14 @@ unique_ptr PEGTransformerFactory::TransformSecretNameInter unique_ptr PEGTransformerFactory::TransformCreateSequenceStmtInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_not_exists {}; + optional if_not_exists {}; auto &if_not_exists_opt = list_pr.GetChild(1).Cast(); if (if_not_exists_opt.HasResult()) { auto if_not_exists_value = transformer.Transform(if_not_exists_opt.GetResult()); if_not_exists = if_not_exists_value; } auto qualified_name = transformer.Transform(list_pr.GetChild(2)); - vector>> sequence_option {}; + optional>>> sequence_option {}; auto &sequence_option_opt = list_pr.GetChild(3).Cast(); if (sequence_option_opt.HasResult()) { vector>> sequence_option_value; @@ -2248,8 +2327,11 @@ unique_ptr PEGTransformerFactory::TransformSeqNoCycleInter unique_ptr PEGTransformerFactory::TransformSeqSetIncrementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); auto expression = transformer.Transform>(list_pr.GetChild(2)); - auto result = TransformSeqSetIncrement(transformer, std::move(expression)); + auto result = TransformSeqSetIncrement(transformer, has_result, std::move(expression)); return make_uniq>>>(std::move(result)); } @@ -2273,8 +2355,11 @@ unique_ptr PEGTransformerFactory::TransformSeqNoMinMaxInte unique_ptr PEGTransformerFactory::TransformSeqStartWithInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); auto expression = transformer.Transform>(list_pr.GetChild(2)); - auto result = TransformSeqStartWith(transformer, std::move(expression)); + auto result = TransformSeqStartWith(transformer, has_result, std::move(expression)); return make_uniq>>>(std::move(result)); } @@ -2309,13 +2394,13 @@ unique_ptr PEGTransformerFactory::TransformMaxValueInterna unique_ptr PEGTransformerFactory::TransformCreateStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool or_replace {}; + optional or_replace {}; auto &or_replace_opt = list_pr.GetChild(1).Cast(); if (or_replace_opt.HasResult()) { auto or_replace_value = transformer.Transform(or_replace_opt.GetResult()); or_replace = or_replace_value; } - SecretPersistType temporary {}; + optional temporary {}; auto &temporary_opt = list_pr.GetChild(2).Cast(); if (temporary_opt.HasResult()) { auto temporary_value = transformer.Transform(temporary_opt.GetResult()); @@ -2370,7 +2455,7 @@ PEGTransformerFactory::TransformTemporaryPersistentInternal(PEGTransformer &tran unique_ptr PEGTransformerFactory::TransformCreateTableStmtInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_not_exists {}; + optional if_not_exists {}; auto &if_not_exists_opt = list_pr.GetChild(1).Cast(); if (if_not_exists_opt.HasResult()) { auto if_not_exists_value = transformer.Transform(if_not_exists_opt.GetResult()); @@ -2378,7 +2463,7 @@ unique_ptr PEGTransformerFactory::TransformCreateTableStmt } auto qualified_name = transformer.Transform(list_pr.GetChild(2)); auto create_table_definition = transformer.Transform(list_pr.GetChild(3)); - bool commit_action {}; + optional commit_action {}; auto &commit_action_opt = list_pr.GetChild(4).Cast(); if (commit_action_opt.HasResult()) { auto commit_action_value = transformer.Transform(commit_action_opt.GetResult()); @@ -2400,20 +2485,20 @@ PEGTransformerFactory::TransformCreateTableDefinitionInternal(PEGTransformer &tr unique_ptr PEGTransformerFactory::TransformCreateTableAsInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - ColumnList identifier_list {}; + optional identifier_list {}; auto &identifier_list_opt = list_pr.GetChild(0).Cast(); if (identifier_list_opt.HasResult()) { auto identifier_list_value = transformer.Transform(identifier_list_opt.GetResult()); identifier_list = std::move(identifier_list_value); } - PartitionSortedOptions partition_sorted_options {}; + optional partition_sorted_options {}; auto &partition_sorted_options_opt = list_pr.GetChild(1).Cast(); if (partition_sorted_options_opt.HasResult()) { auto partition_sorted_options_value = transformer.Transform(partition_sorted_options_opt.GetResult()); partition_sorted_options = std::move(partition_sorted_options_value); } - case_insensitive_map_t> with_list {}; + optional>> with_list {}; auto &with_list_opt = list_pr.GetChild(2).Cast(); if (with_list_opt.HasResult()) { auto with_list_value = @@ -2421,7 +2506,7 @@ unique_ptr PEGTransformerFactory::TransformCreateTableAsIn with_list = std::move(with_list_value); } auto statement = transformer.Transform>(list_pr.GetChild(4)); - bool with_data {}; + optional with_data {}; auto &with_data_opt = list_pr.GetChild(5).Cast(); if (with_data_opt.HasResult()) { auto with_data_value = transformer.Transform(with_data_opt.GetResult()); @@ -2445,7 +2530,7 @@ PEGTransformerFactory::TransformPartitionOptSortedOptionsInternal(PEGTransformer ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto partition_options = transformer.Transform>>(list_pr.GetChild(0)); - vector> sorted_options {}; + optional>> sorted_options {}; auto &sorted_options_opt = list_pr.GetChild(1).Cast(); if (sorted_options_opt.HasResult()) { auto sorted_options_value = @@ -2462,7 +2547,7 @@ PEGTransformerFactory::TransformSortedOptPartitionOptionsInternal(PEGTransformer ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto sorted_options = transformer.Transform>>(list_pr.GetChild(0)); - vector> partition_options {}; + optional>> partition_options {}; auto &partition_options_opt = list_pr.GetChild(1).Cast(); if (partition_options_opt.HasResult()) { auto partition_options_value = @@ -2536,21 +2621,21 @@ unique_ptr PEGTransformerFactory::TransformIdentifierListI unique_ptr PEGTransformerFactory::TransformCreateColumnListInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - ColumnElements create_table_column_list {}; + optional create_table_column_list {}; auto &create_table_column_list_opt = ExtractResultFromParens(list_pr.GetChild(0)).Cast(); if (create_table_column_list_opt.HasResult()) { auto create_table_column_list_value = transformer.Transform(create_table_column_list_opt.GetResult()); create_table_column_list = std::move(create_table_column_list_value); } - PartitionSortedOptions partition_sorted_options {}; + optional partition_sorted_options {}; auto &partition_sorted_options_opt = list_pr.GetChild(1).Cast(); if (partition_sorted_options_opt.HasResult()) { auto partition_sorted_options_value = transformer.Transform(partition_sorted_options_opt.GetResult()); partition_sorted_options = std::move(partition_sorted_options_value); } - case_insensitive_map_t> with_list {}; + optional>> with_list {}; auto &with_list_opt = list_pr.GetChild(2).Cast(); if (with_list_opt.HasResult()) { auto with_list_value = @@ -2723,20 +2808,23 @@ unique_ptr PEGTransformerFactory::TransformColumnDefinitio ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto dotted_identifier = transformer.Transform>(list_pr.GetChild(0)); - LogicalType type {}; + optional type {}; auto &type_opt = list_pr.GetChild(1).Cast(); if (type_opt.HasResult()) { auto type_value = transformer.Transform(type_opt.GetResult()); type = type_value; } - GeneratedColumnDefinition generated_column {}; + optional generated_column {}; auto &generated_column_opt = list_pr.GetChild(2).Cast(); if (generated_column_opt.HasResult()) { auto generated_column_value = transformer.Transform(generated_column_opt.GetResult()); generated_column = std::move(generated_column_value); } - vector column_constraint {}; + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(3).Cast(); + has_result = has_result_opt.HasResult(); + optional> column_constraint {}; auto &column_constraint_opt = list_pr.GetChild(4).Cast(); if (column_constraint_opt.HasResult()) { vector column_constraint_value; @@ -2749,7 +2837,7 @@ unique_ptr PEGTransformerFactory::TransformColumnDefinitio column_constraint = std::move(column_constraint_value); } auto result = TransformColumnDefinition(transformer, dotted_identifier, type, std::move(generated_column), - std::move(column_constraint)); + has_result, std::move(column_constraint)); return make_uniq>(std::move(result)); } @@ -2815,7 +2903,7 @@ unique_ptr PEGTransformerFactory::TransformForeignKeyConstraintInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto base_table_name = transformer.Transform>(list_pr.GetChild(1)); - vector column_list {}; + optional> column_list {}; auto &column_list_opt = list_pr.GetChild(2).Cast(); if (column_list_opt.HasResult()) { auto column_list_value = @@ -2846,13 +2934,13 @@ unique_ptr PEGTransformerFactory::TransformColumnCompressi unique_ptr PEGTransformerFactory::TransformKeyActionsInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - string update_action {}; + optional update_action {}; auto &update_action_opt = list_pr.GetChild(0).Cast(); if (update_action_opt.HasResult()) { auto update_action_value = transformer.Transform(update_action_opt.GetResult()); update_action = update_action_value; } - string delete_action {}; + optional delete_action {}; auto &delete_action_opt = list_pr.GetChild(1).Cast(); if (delete_action_opt.HasResult()) { auto delete_action_value = transformer.Transform(delete_action_opt.GetResult()); @@ -2919,8 +3007,11 @@ PEGTransformerFactory::TransformSetDefaultKeyActionInternal(PEGTransformer &tran unique_ptr PEGTransformerFactory::TransformTopLevelConstraintInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(0).Cast(); + has_result = has_result_opt.HasResult(); auto top_level_constraint_list = transformer.Transform>(list_pr.GetChild(1)); - auto result = TransformTopLevelConstraint(transformer, std::move(top_level_constraint_list)); + auto result = TransformTopLevelConstraint(transformer, has_result, std::move(top_level_constraint_list)); return make_uniq>>(std::move(result)); } @@ -2976,7 +3067,7 @@ unique_ptr PEGTransformerFactory::TransformDottedIdentifie ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto identifier = list_pr.GetChild(0).Cast().identifier; - vector dot_col_label {}; + optional> dot_col_label {}; auto &dot_col_label_opt = list_pr.GetChild(1).Cast(); if (dot_col_label_opt.HasResult()) { vector dot_col_label_value; @@ -3044,14 +3135,17 @@ unique_ptr PEGTransformerFactory::TransformTypeFuncNameInt unique_ptr PEGTransformerFactory::TransformGeneratedColumnInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(0).Cast(); + has_result = has_result_opt.HasResult(); auto expression = transformer.Transform>(ExtractResultFromParens(list_pr.GetChild(2))); - bool generated_column_type {}; + optional generated_column_type {}; auto &generated_column_type_opt = list_pr.GetChild(3).Cast(); if (generated_column_type_opt.HasResult()) { auto generated_column_type_value = transformer.Transform(generated_column_type_opt.GetResult()); generated_column_type = generated_column_type_value; } - auto result = TransformGeneratedColumn(transformer, std::move(expression), generated_column_type); + auto result = TransformGeneratedColumn(transformer, has_result, std::move(expression), generated_column_type); return make_uniq>(std::move(result)); } @@ -3106,7 +3200,7 @@ PEGTransformerFactory::TransformStoredGeneratedColumnInternal(PEGTransformer &tr unique_ptr PEGTransformerFactory::TransformCreateTriggerStmtInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_not_exists {}; + optional if_not_exists {}; auto &if_not_exists_opt = list_pr.GetChild(1).Cast(); if (if_not_exists_opt.HasResult()) { auto if_not_exists_value = transformer.Transform(if_not_exists_opt.GetResult()); @@ -3116,14 +3210,14 @@ unique_ptr PEGTransformerFactory::TransformCreateTriggerSt auto trigger_timing = transformer.Transform(list_pr.GetChild(3)); auto trigger_event = transformer.Transform(list_pr.GetChild(4)); auto base_table_name = transformer.Transform>(list_pr.GetChild(6)); - TriggerTableReferencingInfo referencing_clause {}; + optional referencing_clause {}; auto &referencing_clause_opt = list_pr.GetChild(7).Cast(); if (referencing_clause_opt.HasResult()) { auto referencing_clause_value = transformer.Transform(referencing_clause_opt.GetResult()); referencing_clause = referencing_clause_value; } - TriggerForEach for_each_clause {}; + optional for_each_clause {}; auto &for_each_clause_opt = list_pr.GetChild(8).Cast(); if (for_each_clause_opt.HasResult()) { auto for_each_clause_value = transformer.Transform(for_each_clause_opt.GetResult()); @@ -3156,7 +3250,7 @@ unique_ptr PEGTransformerFactory::TransformReferencingClau ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto referencing_item = transformer.Transform(list_pr.GetChild(1)); - TriggerTableReferencingInfo referencing_item_1 {}; + optional referencing_item_1 {}; auto &referencing_item_1_opt = list_pr.GetChild(2).Cast(); if (referencing_item_1_opt.HasResult()) { auto referencing_item_1_value = @@ -3287,7 +3381,7 @@ unique_ptr PEGTransformerFactory::TransformForEachStatemen unique_ptr PEGTransformerFactory::TransformCreateTypeStmtInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_not_exists {}; + optional if_not_exists {}; auto &if_not_exists_opt = list_pr.GetChild(1).Cast(); if (if_not_exists_opt.HasResult()) { auto if_not_exists_value = transformer.Transform(if_not_exists_opt.GetResult()); @@ -3327,7 +3421,7 @@ unique_ptr PEGTransformerFactory::TransformEnumSelectTypeI unique_ptr PEGTransformerFactory::TransformEnumStringLiteralListInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - vector string_literal {}; + optional> string_literal {}; auto &string_literal_opt = ExtractResultFromParens(list_pr.GetChild(1)).Cast(); if (string_literal_opt.HasResult()) { vector string_literal_value; @@ -3345,26 +3439,26 @@ PEGTransformerFactory::TransformEnumStringLiteralListInternal(PEGTransformer &tr unique_ptr PEGTransformerFactory::TransformCreateViewStmtInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool create_recursive {}; + optional create_recursive {}; auto &create_recursive_opt = list_pr.GetChild(0).Cast(); if (create_recursive_opt.HasResult()) { auto create_recursive_value = transformer.Transform(create_recursive_opt.GetResult()); create_recursive = create_recursive_value; } - bool if_not_exists {}; + optional if_not_exists {}; auto &if_not_exists_opt = list_pr.GetChild(2).Cast(); if (if_not_exists_opt.HasResult()) { auto if_not_exists_value = transformer.Transform(if_not_exists_opt.GetResult()); if_not_exists = if_not_exists_value; } auto qualified_name = transformer.Transform(list_pr.GetChild(3)); - vector insert_column_list {}; + optional> insert_column_list {}; auto &insert_column_list_opt = list_pr.GetChild(4).Cast(); if (insert_column_list_opt.HasResult()) { auto insert_column_list_value = transformer.Transform>(insert_column_list_opt.GetResult()); insert_column_list = insert_column_list_value; } - case_insensitive_map_t> with_list {}; + optional>> with_list {}; auto &with_list_opt = list_pr.GetChild(5).Cast(); if (with_list_opt.HasResult()) { auto with_list_value = @@ -3387,7 +3481,7 @@ unique_ptr PEGTransformerFactory::TransformCreateRecursive unique_ptr PEGTransformerFactory::TransformDeallocateStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool deallocate_prepare {}; + optional deallocate_prepare {}; auto &deallocate_prepare_opt = list_pr.GetChild(1).Cast(); if (deallocate_prepare_opt.HasResult()) { auto deallocate_prepare_value = transformer.Transform(deallocate_prepare_opt.GetResult()); @@ -3407,27 +3501,27 @@ unique_ptr PEGTransformerFactory::TransformDeallocatePrepa unique_ptr PEGTransformerFactory::TransformDeleteStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - CommonTableExpressionMap with_clause {}; + optional with_clause {}; auto &with_clause_opt = list_pr.GetChild(0).Cast(); if (with_clause_opt.HasResult()) { auto with_clause_value = transformer.Transform(with_clause_opt.GetResult()); with_clause = std::move(with_clause_value); } auto target_opt_alias = transformer.Transform>(list_pr.GetChild(3)); - vector> delete_using_clause {}; + optional>> delete_using_clause {}; auto &delete_using_clause_opt = list_pr.GetChild(4).Cast(); if (delete_using_clause_opt.HasResult()) { auto delete_using_clause_value = transformer.Transform>>(delete_using_clause_opt.GetResult()); delete_using_clause = std::move(delete_using_clause_value); } - unique_ptr where_clause {}; + optional> where_clause {}; auto &where_clause_opt = list_pr.GetChild(5).Cast(); if (where_clause_opt.HasResult()) { auto where_clause_value = transformer.Transform>(where_clause_opt.GetResult()); where_clause = std::move(where_clause_value); } - vector> returning_clause {}; + optional>> returning_clause {}; auto &returning_clause_opt = list_pr.GetChild(6).Cast(); if (returning_clause_opt.HasResult()) { auto returning_clause_value = @@ -3443,8 +3537,11 @@ unique_ptr PEGTransformerFactory::TransformDeleteStatement unique_ptr PEGTransformerFactory::TransformTruncateStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); auto base_table_name = transformer.Transform>(list_pr.GetChild(2)); - auto result = TransformTruncateStatement(transformer, std::move(base_table_name)); + auto result = TransformTruncateStatement(transformer, has_result, std::move(base_table_name)); return make_uniq>>(std::move(result)); } @@ -3452,13 +3549,16 @@ unique_ptr PEGTransformerFactory::TransformTargetOptAliasI ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto base_table_name = transformer.Transform>(list_pr.GetChild(0)); - Identifier col_id {}; + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); + optional col_id {}; auto &col_id_opt = list_pr.GetChild(2).Cast(); if (col_id_opt.HasResult()) { auto col_id_value = transformer.Transform(col_id_opt.GetResult()); col_id = col_id_value; } - auto result = TransformTargetOptAlias(transformer, std::move(base_table_name), col_id); + auto result = TransformTargetOptAlias(transformer, std::move(base_table_name), has_result, col_id); return make_uniq>>(std::move(result)); } @@ -3505,7 +3605,7 @@ unique_ptr PEGTransformerFactory::TransformShowQualifiedNa ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto show_or_describe_or_summarize = transformer.Transform(list_pr.GetChild(0)); - DescribeTarget describe_target {}; + optional describe_target {}; auto &describe_target_opt = list_pr.GetChild(1).Cast(); if (describe_target_opt.HasResult()) { auto describe_target_value = transformer.Transform(describe_target_opt.GetResult()); @@ -3608,14 +3708,17 @@ unique_ptr PEGTransformerFactory::TransformDescRuleInterna unique_ptr PEGTransformerFactory::TransformDetachStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_exists {}; + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(2).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); if_exists = if_exists_value; } auto catalog_name = list_pr.GetChild(3).Cast().identifier; - auto result = TransformDetachStatement(transformer, if_exists, catalog_name); + auto result = TransformDetachStatement(transformer, has_result, if_exists, catalog_name); return make_uniq>>(std::move(result)); } @@ -3623,7 +3726,7 @@ unique_ptr PEGTransformerFactory::TransformDropStatementIn ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto drop_entries = transformer.Transform>(list_pr.GetChild(1)); - bool drop_behavior {}; + optional drop_behavior {}; auto &drop_behavior_opt = list_pr.GetChild(2).Cast(); if (drop_behavior_opt.HasResult()) { auto drop_behavior_value = transformer.Transform(drop_behavior_opt.GetResult()); @@ -3644,7 +3747,7 @@ unique_ptr PEGTransformerFactory::TransformDropEntriesInte unique_ptr PEGTransformerFactory::TransformDropTriggerInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(1).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); @@ -3660,7 +3763,7 @@ unique_ptr PEGTransformerFactory::TransformDropTableIntern ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto table_or_view = transformer.Transform(list_pr.GetChild(0)); - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(1).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); @@ -3680,7 +3783,7 @@ unique_ptr PEGTransformerFactory::TransformDropTableFuncti ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto comment_macro_table = transformer.Transform(list_pr.GetChild(0)); - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(1).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); @@ -3700,7 +3803,7 @@ unique_ptr PEGTransformerFactory::TransformDropFunctionInt ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto function_type_macro = transformer.Transform(list_pr.GetChild(0)); - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(1).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); @@ -3719,7 +3822,7 @@ unique_ptr PEGTransformerFactory::TransformDropFunctionInt unique_ptr PEGTransformerFactory::TransformDropSchemaInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(1).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); @@ -3738,7 +3841,7 @@ unique_ptr PEGTransformerFactory::TransformDropSchemaInter unique_ptr PEGTransformerFactory::TransformDropIndexInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(1).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); @@ -3795,7 +3898,7 @@ PEGTransformerFactory::TransformCatalogReservedSchemaIndexInternal(PEGTransforme unique_ptr PEGTransformerFactory::TransformDropSequenceInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(1).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); @@ -3814,7 +3917,7 @@ unique_ptr PEGTransformerFactory::TransformDropSequenceInt unique_ptr PEGTransformerFactory::TransformDropCollationInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(1).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); @@ -3833,7 +3936,7 @@ unique_ptr PEGTransformerFactory::TransformDropCollationIn unique_ptr PEGTransformerFactory::TransformDropTypeInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(1).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); @@ -3852,20 +3955,20 @@ unique_ptr PEGTransformerFactory::TransformDropTypeInterna unique_ptr PEGTransformerFactory::TransformDropSecretInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - SecretPersistType temporary {}; + optional temporary {}; auto &temporary_opt = list_pr.GetChild(0).Cast(); if (temporary_opt.HasResult()) { auto temporary_value = transformer.Transform(temporary_opt.GetResult()); temporary = temporary_value; } - bool if_exists {}; + optional if_exists {}; auto &if_exists_opt = list_pr.GetChild(2).Cast(); if (if_exists_opt.HasResult()) { auto if_exists_value = transformer.Transform(if_exists_opt.GetResult()); if_exists = if_exists_value; } auto secret_name = transformer.Transform(list_pr.GetChild(3)); - Identifier drop_secret_storage {}; + optional drop_secret_storage {}; auto &drop_secret_storage_opt = list_pr.GetChild(4).Cast(); if (drop_secret_storage_opt.HasResult()) { auto drop_secret_storage_value = transformer.Transform(drop_secret_storage_opt.GetResult()); @@ -3974,7 +4077,7 @@ unique_ptr PEGTransformerFactory::TransformExecuteStatemen ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto identifier = list_pr.GetChild(1).Cast().identifier; - vector table_function_arguments {}; + optional> table_function_arguments {}; auto &table_function_arguments_opt = list_pr.GetChild(2).Cast(); if (table_function_arguments_opt.HasResult()) { auto table_function_arguments_value = @@ -3988,13 +4091,13 @@ unique_ptr PEGTransformerFactory::TransformExecuteStatemen unique_ptr PEGTransformerFactory::TransformExplainStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - bool explain_analyze {}; + optional explain_analyze {}; auto &explain_analyze_opt = list_pr.GetChild(1).Cast(); if (explain_analyze_opt.HasResult()) { auto explain_analyze_value = transformer.Transform(explain_analyze_opt.GetResult()); explain_analyze = explain_analyze_value; } - vector explain_option_list {}; + optional> explain_option_list {}; auto &explain_option_list_opt = list_pr.GetChild(2).Cast(); if (explain_option_list_opt.HasResult()) { auto explain_option_list_value = @@ -4030,7 +4133,7 @@ unique_ptr PEGTransformerFactory::TransformExplainOptionIn ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto explain_option_name = transformer.Transform(list_pr.GetChild(0)); - unique_ptr expression {}; + optional> expression {}; auto &expression_opt = list_pr.GetChild(1).Cast(); if (expression_opt.HasResult()) { auto expression_value = transformer.Transform>(expression_opt.GetResult()); @@ -4059,14 +4162,14 @@ PEGTransformerFactory::TransformExplainableStatementsInternal(PEGTransformer &tr unique_ptr PEGTransformerFactory::TransformExportStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - string export_source {}; + optional export_source {}; auto &export_source_opt = list_pr.GetChild(2).Cast(); if (export_source_opt.HasResult()) { auto export_source_value = transformer.Transform(export_source_opt.GetResult()); export_source = export_source_value; } auto string_literal = transformer.Transform(list_pr.GetChild(3)); - vector generic_copy_option_list {}; + optional> generic_copy_option_list {}; auto &generic_copy_option_list_opt = list_pr.GetChild(4).Cast(); if (generic_copy_option_list_opt.HasResult()) { auto generic_copy_option_list_value = @@ -6830,40 +6933,40 @@ unique_ptr PEGTransformerFactory::TransformExtractDatePart unique_ptr PEGTransformerFactory::TransformInsertStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - CommonTableExpressionMap with_clause {}; + optional with_clause {}; auto &with_clause_opt = list_pr.GetChild(0).Cast(); if (with_clause_opt.HasResult()) { auto with_clause_value = transformer.Transform(with_clause_opt.GetResult()); with_clause = std::move(with_clause_value); } - OnConflictAction or_action {}; + optional or_action {}; auto &or_action_opt = list_pr.GetChild(2).Cast(); if (or_action_opt.HasResult()) { auto or_action_value = transformer.Transform(or_action_opt.GetResult()); or_action = or_action_value; } auto insert_target = transformer.Transform>(list_pr.GetChild(4)); - InsertColumnOrder by_name_or_position {}; + optional by_name_or_position {}; auto &by_name_or_position_opt = list_pr.GetChild(5).Cast(); if (by_name_or_position_opt.HasResult()) { auto by_name_or_position_value = transformer.Transform(by_name_or_position_opt.GetResult()); by_name_or_position = by_name_or_position_value; } - vector insert_column_list {}; + optional> insert_column_list {}; auto &insert_column_list_opt = list_pr.GetChild(6).Cast(); if (insert_column_list_opt.HasResult()) { auto insert_column_list_value = transformer.Transform>(insert_column_list_opt.GetResult()); insert_column_list = insert_column_list_value; } auto insert_values = transformer.Transform(list_pr.GetChild(7)); - unique_ptr on_conflict_clause {}; + optional> on_conflict_clause {}; auto &on_conflict_clause_opt = list_pr.GetChild(8).Cast(); if (on_conflict_clause_opt.HasResult()) { auto on_conflict_clause_value = transformer.Transform>(on_conflict_clause_opt.GetResult()); on_conflict_clause = std::move(on_conflict_clause_value); } - vector> returning_clause {}; + optional>> returning_clause {}; auto &returning_clause_opt = list_pr.GetChild(9).Cast(); if (returning_clause_opt.HasResult()) { auto returning_clause_value = @@ -6936,7 +7039,7 @@ unique_ptr PEGTransformerFactory::TransformInsertTargetInt ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto base_table_name = transformer.Transform>(list_pr.GetChild(0)); - Identifier insert_alias {}; + optional insert_alias {}; auto &insert_alias_opt = list_pr.GetChild(1).Cast(); if (insert_alias_opt.HasResult()) { auto insert_alias_value = transformer.Transform(insert_alias_opt.GetResult()); @@ -7000,7 +7103,7 @@ unique_ptr PEGTransformerFactory::TransformDefaultValuesIn unique_ptr PEGTransformerFactory::TransformOnConflictClauseInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - OnConflictExpressionTarget on_conflict_target {}; + optional on_conflict_target {}; auto &on_conflict_target_opt = list_pr.GetChild(2).Cast(); if (on_conflict_target_opt.HasResult()) { auto on_conflict_target_value = @@ -7025,7 +7128,7 @@ PEGTransformerFactory::TransformOnConflictExpressionTargetInternal(PEGTransforme ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto column_id_list = transformer.Transform>(list_pr.GetChild(0)); - unique_ptr where_clause {}; + optional> where_clause {}; auto &where_clause_opt = list_pr.GetChild(1).Cast(); if (where_clause_opt.HasResult()) { auto where_clause_value = transformer.Transform>(where_clause_opt.GetResult()); @@ -7055,7 +7158,7 @@ unique_ptr PEGTransformerFactory::TransformOnConflictUpdat ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto update_set_clause = transformer.Transform>(list_pr.GetChild(3)); - unique_ptr where_clause {}; + optional> where_clause {}; auto &where_clause_opt = list_pr.GetChild(4).Cast(); if (where_clause_opt.HasResult()) { auto where_clause_value = transformer.Transform>(where_clause_opt.GetResult()); @@ -7083,7 +7186,7 @@ unique_ptr PEGTransformerFactory::TransformLoadStatementIn ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto col_id_or_string = transformer.Transform(list_pr.GetChild(1)); - Identifier extension_alias {}; + optional extension_alias {}; auto &extension_alias_opt = list_pr.GetChild(2).Cast(); if (extension_alias_opt.HasResult()) { auto extension_alias_value = transformer.Transform(extension_alias_opt.GetResult()); @@ -7104,20 +7207,24 @@ unique_ptr PEGTransformerFactory::TransformExtensionAliasI unique_ptr PEGTransformerFactory::TransformInstallStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(0).Cast(); + has_result = has_result_opt.HasResult(); auto identifier_or_string_literal = transformer.Transform(list_pr.GetChild(2)); - ExtensionRepositoryInfo from_source {}; + optional from_source {}; auto &from_source_opt = list_pr.GetChild(3).Cast(); if (from_source_opt.HasResult()) { auto from_source_value = transformer.Transform(from_source_opt.GetResult()); from_source = from_source_value; } - string version_number {}; + optional version_number {}; auto &version_number_opt = list_pr.GetChild(4).Cast(); if (version_number_opt.HasResult()) { auto version_number_value = transformer.Transform(version_number_opt.GetResult()); version_number = version_number_value; } - auto result = TransformInstallStatement(transformer, identifier_or_string_literal, from_source, version_number); + auto result = + TransformInstallStatement(transformer, has_result, identifier_or_string_literal, from_source, version_number); return make_uniq>>(std::move(result)); } @@ -7125,7 +7232,7 @@ unique_ptr PEGTransformerFactory::TransformUpdateExtensionsStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - vector identifier {}; + optional> identifier {}; auto &identifier_opt = list_pr.GetChild(2).Cast(); if (identifier_opt.HasResult()) { vector identifier_value; @@ -7176,7 +7283,7 @@ unique_ptr PEGTransformerFactory::TransformVersionNumberIn unique_ptr PEGTransformerFactory::TransformMergeIntoStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - CommonTableExpressionMap with_clause {}; + optional with_clause {}; auto &with_clause_opt = list_pr.GetChild(0).Cast(); if (with_clause_opt.HasResult()) { auto with_clause_value = transformer.Transform(with_clause_opt.GetResult()); @@ -7192,7 +7299,7 @@ unique_ptr PEGTransformerFactory::TransformMergeIntoStatem transformer.Transform>>(merge_match_item.get()); merge_match.push_back(std::move(merge_match_value)); } - vector> returning_clause {}; + optional>> returning_clause {}; auto &returning_clause_opt = list_pr.GetChild(7).Cast(); if (returning_clause_opt.HasResult()) { auto returning_clause_value = @@ -7224,7 +7331,7 @@ unique_ptr PEGTransformerFactory::TransformMergeMatchInter unique_ptr PEGTransformerFactory::TransformMatchedClauseInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - unique_ptr and_expression {}; + optional> and_expression {}; auto &and_expression_opt = list_pr.GetChild(2).Cast(); if (and_expression_opt.HasResult()) { auto and_expression_value = transformer.Transform>(and_expression_opt.GetResult()); @@ -7246,7 +7353,7 @@ PEGTransformerFactory::TransformMatchedClauseActionInternal(PEGTransformer &tran unique_ptr PEGTransformerFactory::TransformUpdateMatchClauseInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - unique_ptr update_match_info {}; + optional> update_match_info {}; auto &update_match_info_opt = list_pr.GetChild(1).Cast(); if (update_match_info_opt.HasResult()) { auto update_match_info_value = @@ -7290,7 +7397,7 @@ unique_ptr PEGTransformerFactory::TransformDeleteMatchClau unique_ptr PEGTransformerFactory::TransformInsertMatchClauseInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - unique_ptr insert_match_info {}; + optional> insert_match_info {}; auto &insert_match_info_opt = list_pr.GetChild(1).Cast(); if (insert_match_info_opt.HasResult()) { auto insert_match_info_value = @@ -7318,20 +7425,23 @@ PEGTransformerFactory::TransformInsertDefaultValuesInternal(PEGTransformer &tran unique_ptr PEGTransformerFactory::TransformInsertByNameOrPositionInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - InsertColumnOrder by_name_or_position {}; + optional by_name_or_position {}; auto &by_name_or_position_opt = list_pr.GetChild(0).Cast(); if (by_name_or_position_opt.HasResult()) { auto by_name_or_position_value = transformer.Transform(by_name_or_position_opt.GetResult()); by_name_or_position = by_name_or_position_value; } - auto result = TransformInsertByNameOrPosition(transformer, by_name_or_position); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); + auto result = TransformInsertByNameOrPosition(transformer, by_name_or_position, has_result); return make_uniq>>(std::move(result)); } unique_ptr PEGTransformerFactory::TransformInsertValuesListInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - vector insert_column_list {}; + optional> insert_column_list {}; auto &insert_column_list_opt = list_pr.GetChild(0).Cast(); if (insert_column_list_opt.HasResult()) { auto insert_column_list_value = transformer.Transform>(insert_column_list_opt.GetResult()); @@ -7356,7 +7466,7 @@ PEGTransformerFactory::TransformDoNothingMatchClauseInternal(PEGTransformer &tra unique_ptr PEGTransformerFactory::TransformErrorMatchClauseInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - unique_ptr expression {}; + optional> expression {}; auto &expression_opt = list_pr.GetChild(1).Cast(); if (expression_opt.HasResult()) { auto expression_value = transformer.Transform>(expression_opt.GetResult()); @@ -7397,14 +7507,14 @@ unique_ptr PEGTransformerFactory::TransformAndExpressionIn unique_ptr PEGTransformerFactory::TransformNotMatchedClauseInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - MergeActionCondition by_source_or_target {}; + optional by_source_or_target {}; auto &by_source_or_target_opt = list_pr.GetChild(3).Cast(); if (by_source_or_target_opt.HasResult()) { auto by_source_or_target_value = transformer.Transform(by_source_or_target_opt.GetResult()); by_source_or_target = by_source_or_target_value; } - unique_ptr and_expression {}; + optional> and_expression {}; auto &and_expression_opt = list_pr.GetChild(4).Cast(); if (and_expression_opt.HasResult()) { auto and_expression_value = transformer.Transform>(and_expression_opt.GetResult()); @@ -7584,7 +7694,7 @@ unique_ptr PEGTransformerFactory::TransformPragmaFunctionI ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto pragma_name = list_pr.GetChild(0).Cast().identifier; - vector> pragma_parameters {}; + optional>> pragma_parameters {}; auto &pragma_parameters_opt = list_pr.GetChild(1).Cast(); if (pragma_parameters_opt.HasResult()) { auto pragma_parameters_value = @@ -7612,7 +7722,7 @@ unique_ptr PEGTransformerFactory::TransformPrepareStatemen ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto identifier = list_pr.GetChild(1).Cast().identifier; - vector type_list {}; + optional> type_list {}; auto &type_list_opt = list_pr.GetChild(2).Cast(); if (type_list_opt.HasResult()) { auto type_list_value = transformer.Transform>(type_list_opt.GetResult()); @@ -9373,7 +9483,7 @@ PEGTransformerFactory::TransformZoneIntervalWithIntervalInternal(PEGTransformer ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto string_literal = transformer.Transform(list_pr.GetChild(1)); - DatePartSpecifier interval {}; + optional interval {}; auto &interval_opt = list_pr.GetChild(2).Cast(); if (interval_opt.HasResult()) { auto interval_value = transformer.Transform(interval_opt.GetResult()); @@ -9397,7 +9507,7 @@ PEGTransformerFactory::TransformZoneIntervalWithPrecisionInternal(PEGTransformer unique_ptr PEGTransformerFactory::TransformSetSettingInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - SetScope setting_scope {}; + optional setting_scope {}; auto &setting_scope_opt = list_pr.GetChild(0).Cast(); if (setting_scope_opt.HasResult()) { auto setting_scope_value = transformer.Transform(setting_scope_opt.GetResult()); @@ -9481,25 +9591,36 @@ PEGTransformerFactory::TransformTransactionStatementInternal(PEGTransformer &tra unique_ptr PEGTransformerFactory::TransformBeginTransactionInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - TransactionModifierType read_or_write {}; + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); + optional read_or_write {}; auto &read_or_write_opt = list_pr.GetChild(2).Cast(); if (read_or_write_opt.HasResult()) { auto read_or_write_value = transformer.Transform(read_or_write_opt.GetResult()); read_or_write = read_or_write_value; } - auto result = TransformBeginTransaction(transformer, read_or_write); + auto result = TransformBeginTransaction(transformer, has_result, read_or_write); return make_uniq>>(std::move(result)); } unique_ptr PEGTransformerFactory::TransformRollbackTransactionInternal(PEGTransformer &transformer, ParseResult &parse_result) { - auto result = TransformRollbackTransaction(transformer); + auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); + auto result = TransformRollbackTransaction(transformer, has_result); return make_uniq>>(std::move(result)); } unique_ptr PEGTransformerFactory::TransformCommitTransactionInternal(PEGTransformer &transformer, ParseResult &parse_result) { - auto result = TransformCommitTransaction(transformer); + auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(1).Cast(); + has_result = has_result_opt.HasResult(); + auto result = TransformCommitTransaction(transformer, has_result); return make_uniq>>(std::move(result)); } @@ -9534,7 +9655,7 @@ unique_ptr PEGTransformerFactory::TransformReadWriteIntern unique_ptr PEGTransformerFactory::TransformUpdateStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - CommonTableExpressionMap with_clause {}; + optional with_clause {}; auto &with_clause_opt = list_pr.GetChild(0).Cast(); if (with_clause_opt.HasResult()) { auto with_clause_value = transformer.Transform(with_clause_opt.GetResult()); @@ -9542,19 +9663,19 @@ unique_ptr PEGTransformerFactory::TransformUpdateStatement } auto update_target = transformer.Transform>(list_pr.GetChild(2)); auto update_set_clause = transformer.Transform>(list_pr.GetChild(3)); - unique_ptr from_clause {}; + optional> from_clause {}; auto &from_clause_opt = list_pr.GetChild(4).Cast(); if (from_clause_opt.HasResult()) { auto from_clause_value = transformer.Transform>(from_clause_opt.GetResult()); from_clause = std::move(from_clause_value); } - unique_ptr where_clause {}; + optional> where_clause {}; auto &where_clause_opt = list_pr.GetChild(5).Cast(); if (where_clause_opt.HasResult()) { auto where_clause_value = transformer.Transform>(where_clause_opt.GetResult()); where_clause = std::move(where_clause_value); } - vector> returning_clause {}; + optional>> returning_clause {}; auto &returning_clause_opt = list_pr.GetChild(6).Cast(); if (returning_clause_opt.HasResult()) { auto returning_clause_value = @@ -9587,7 +9708,7 @@ unique_ptr PEGTransformerFactory::TransformBaseTableAliasS ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto base_table_name = transformer.Transform>(list_pr.GetChild(0)); - Identifier update_alias {}; + optional update_alias {}; auto &update_alias_opt = list_pr.GetChild(1).Cast(); if (update_alias_opt.HasResult()) { auto update_alias_value = transformer.Transform(update_alias_opt.GetResult()); @@ -9600,8 +9721,11 @@ unique_ptr PEGTransformerFactory::TransformBaseTableAliasS unique_ptr PEGTransformerFactory::TransformUpdateAliasInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); + bool has_result {}; + auto &has_result_opt = list_pr.GetChild(0).Cast(); + has_result = has_result_opt.HasResult(); auto col_id = transformer.Transform(list_pr.GetChild(1)); - auto result = TransformUpdateAlias(transformer, col_id); + auto result = TransformUpdateAlias(transformer, has_result, col_id); return make_uniq>(result); } @@ -9654,7 +9778,7 @@ unique_ptr PEGTransformerFactory::TransformUpdateSetColumnTargetInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); auto column_name = list_pr.GetChild(0).Cast().identifier; - vector dot_identifier {}; + optional> dot_identifier {}; auto &dot_identifier_opt = list_pr.GetChild(1).Cast(); if (dot_identifier_opt.HasResult()) { vector dot_identifier_value; @@ -9706,7 +9830,7 @@ PEGTransformerFactory::TransformUseTargetCatalogSchemaInternal(PEGTransformer &t auto &list_pr = parse_result.Cast(); auto catalog_name = list_pr.GetChild(0).Cast().identifier; auto reserved_schema_name = list_pr.GetChild(2).Cast().identifier; - vector dot_identifier {}; + optional> dot_identifier {}; auto &dot_identifier_opt = list_pr.GetChild(3).Cast(); if (dot_identifier_opt.HasResult()) { vector dot_identifier_value; @@ -9732,13 +9856,13 @@ unique_ptr PEGTransformerFactory::TransformDotIdentifierIn unique_ptr PEGTransformerFactory::TransformVacuumStatementInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - VacuumOptions vacuum_options {}; + optional vacuum_options {}; auto &vacuum_options_opt = list_pr.GetChild(1).Cast(); if (vacuum_options_opt.HasResult()) { auto vacuum_options_value = transformer.Transform(vacuum_options_opt.GetResult()); vacuum_options = vacuum_options_value; } - AnalyzeTarget analyze_target {}; + optional analyze_target {}; auto &analyze_target_opt = list_pr.GetChild(2).Cast(); if (analyze_target_opt.HasResult()) { auto analyze_target_value = transformer.Transform(analyze_target_opt.GetResult()); @@ -9772,25 +9896,25 @@ PEGTransformerFactory::TransformVacuumParensOptionsInternal(PEGTransformer &tran unique_ptr PEGTransformerFactory::TransformVacuumLegacyOptionsInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto &list_pr = parse_result.Cast(); - string opt_full {}; + optional opt_full {}; auto &opt_full_opt = list_pr.GetChild(0).Cast(); if (opt_full_opt.HasResult()) { auto opt_full_value = transformer.Transform(opt_full_opt.GetResult()); opt_full = opt_full_value; } - string opt_freeze {}; + optional opt_freeze {}; auto &opt_freeze_opt = list_pr.GetChild(1).Cast(); if (opt_freeze_opt.HasResult()) { auto opt_freeze_value = transformer.Transform(opt_freeze_opt.GetResult()); opt_freeze = opt_freeze_value; } - string opt_verbose {}; + optional opt_verbose {}; auto &opt_verbose_opt = list_pr.GetChild(2).Cast(); if (opt_verbose_opt.HasResult()) { auto opt_verbose_value = transformer.Transform(opt_verbose_opt.GetResult()); opt_verbose = opt_verbose_value; } - string opt_analyze {}; + optional opt_analyze {}; auto &opt_analyze_opt = list_pr.GetChild(3).Cast(); if (opt_analyze_opt.HasResult()) { auto opt_analyze_value = transformer.Transform(opt_analyze_opt.GetResult()); @@ -9924,7 +10048,6 @@ void PEGTransformerFactory::RegisterGenerated() { {"SimpleType", &PEGTransformerFactory::TransformSimpleTypeInternal}, {"CharacterSimpleType", &PEGTransformerFactory::TransformCharacterSimpleTypeInternal}, {"QualifiedSimpleType", &PEGTransformerFactory::TransformQualifiedSimpleTypeInternal}, - {"CharacterType", &PEGTransformerFactory::TransformCharacterTypeInternal}, {"IntervalType", &PEGTransformerFactory::TransformIntervalTypeInternal}, {"IntervalInterval", &PEGTransformerFactory::TransformIntervalIntervalInternal}, {"IntervalWithSpecifier", &PEGTransformerFactory::TransformIntervalWithSpecifierInternal}, @@ -9992,7 +10115,12 @@ void PEGTransformerFactory::RegisterGenerated() { {"WithOrWithout", &PEGTransformerFactory::TransformWithOrWithoutInternal}, {"WithRule", &PEGTransformerFactory::TransformWithRuleInternal}, {"WithoutRule", &PEGTransformerFactory::TransformWithoutRuleInternal}, + {"ConnectStatement", &PEGTransformerFactory::TransformConnectStatementInternal}, {"DisconnectStatement", &PEGTransformerFactory::TransformDisconnectStatementInternal}, + {"SessionTarget", &PEGTransformerFactory::TransformSessionTargetInternal}, + {"LocalSessionTarget", &PEGTransformerFactory::TransformLocalSessionTargetInternal}, + {"StringSessionTarget", &PEGTransformerFactory::TransformStringSessionTargetInternal}, + {"CatalogSessionTarget", &PEGTransformerFactory::TransformCatalogSessionTargetInternal}, {"CopyStatement", &PEGTransformerFactory::TransformCopyStatementInternal}, {"CopyVariations", &PEGTransformerFactory::TransformCopyVariationsInternal}, {"CopyTable", &PEGTransformerFactory::TransformCopyTableInternal}, diff --git a/src/duckdb/src/parser/peg/transformer/transform_generic_copy_option.cpp b/src/duckdb/src/parser/peg/transformer/transform_generic_copy_option.cpp index dde50816a..39f3860d0 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_generic_copy_option.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_generic_copy_option.cpp @@ -77,17 +77,17 @@ static unique_ptr CreateExpressionRowFunction(vector generic_copy_option_value) { GenericCopyOption copy_option; copy_option.name = Identifier(StringUtil::Lower(copy_option_name.GetIdentifierName())); - if (!generic_copy_option_value.has_value) { + if (!generic_copy_option_value || !generic_copy_option_value->has_value) { return copy_option; } - if (generic_copy_option_value.is_order_list) { - auto &orders = generic_copy_option_value.order_list; + if (generic_copy_option_value->is_order_list) { + auto &orders = generic_copy_option_value->order_list; bool has_order_modifier = false; for (auto &order : orders) { if (order.type != OrderType::ORDER_DEFAULT || order.null_order != OrderByNullType::ORDER_DEFAULT) { @@ -103,10 +103,10 @@ GenericCopyOption PEGTransformerFactory::TransformGenericCopyOption(PEGTransform } else if (orders.size() == 1) { SetGenericCopyOptionExpression(copy_option, std::move(orders[0].expression)); } else { - copy_option.expression = CreateExpressionRowFunction(generic_copy_option_value.order_list); + copy_option.expression = CreateExpressionRowFunction(generic_copy_option_value->order_list); } } else { - SetGenericCopyOptionExpression(copy_option, std::move(generic_copy_option_value.expression)); + SetGenericCopyOptionExpression(copy_option, std::move(generic_copy_option_value->expression)); } return copy_option; } diff --git a/src/duckdb/src/parser/peg/transformer/transform_insert.cpp b/src/duckdb/src/parser/peg/transformer/transform_insert.cpp index f2bb7484c..64066880e 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_insert.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_insert.cpp @@ -7,18 +7,23 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformInsertStatement( - PEGTransformer &transformer, CommonTableExpressionMap with_clause, const OnConflictAction &or_action, - unique_ptr insert_target, const InsertColumnOrder &by_name_or_position, - const vector &insert_column_list, InsertValues insert_values, unique_ptr on_conflict_clause, - vector> returning_clause) { + PEGTransformer &transformer, optional with_clause, + const optional &or_action, unique_ptr insert_target, + const optional &by_name_or_position, const optional> &insert_column_list, + InsertValues insert_values, optional> on_conflict_clause, + optional>> returning_clause) { auto result = make_uniq(); auto &node = *result->node; - node.cte_map = std::move(with_clause); + if (with_clause) { + node.cte_map = std::move(*with_clause); + } node.catalog = insert_target->catalog_name; node.schema = insert_target->schema_name; node.table = insert_target->table_name; - node.column_order = by_name_or_position; - node.columns = StringsToIdentifiers(insert_column_list); + node.column_order = by_name_or_position ? *by_name_or_position : InsertColumnOrder::INSERT_BY_POSITION; + if (insert_column_list) { + node.columns = StringsToIdentifiers(*insert_column_list); + } if (!node.columns.empty() && insert_values.default_values) { throw ParserException( "You can not provide both a column list and DEFAULT VALUES, please remove one of the two"); @@ -29,21 +34,24 @@ unique_ptr PEGTransformerFactory::TransformInsertStatement( if (insert_values.select_statement) { node.select_statement = std::move(insert_values.select_statement); } + auto action = or_action.value_or(OnConflictAction::THROW); if (on_conflict_clause) { - if (or_action != OnConflictAction::THROW) { + if (action != OnConflictAction::THROW) { // OR REPLACE | OR IGNORE are shorthands for the ON CONFLICT clause throw ParserException("You can not provide both OR REPLACE|IGNORE and an ON CONFLICT clause, please remove " "the first if you want to have more granular control"); } - node.on_conflict_info = std::move(on_conflict_clause); + node.on_conflict_info = std::move(*on_conflict_clause); node.table_ref = std::move(insert_target); - } else if (or_action != OnConflictAction::THROW) { + } else if (action != OnConflictAction::THROW) { auto on_conflict_info = make_uniq(); - on_conflict_info->action_type = or_action; + on_conflict_info->action_type = action; node.on_conflict_info = std::move(on_conflict_info); node.table_ref = std::move(insert_target); } - node.returning_list = std::move(returning_clause); + if (returning_clause) { + node.returning_list = std::move(*returning_clause); + } return std::move(result); } @@ -57,8 +65,10 @@ OnConflictAction PEGTransformerFactory::TransformInsertOrIgnore(PEGTransformer & unique_ptr PEGTransformerFactory::TransformInsertTarget(PEGTransformer &transformer, unique_ptr base_table_name, - const Identifier &insert_alias) { - base_table_name->alias = insert_alias; + const optional &insert_alias) { + if (insert_alias) { + base_table_name->alias = *insert_alias; + } return base_table_name; } @@ -68,20 +78,26 @@ Identifier PEGTransformerFactory::TransformInsertAlias(PEGTransformer &transform unique_ptr PEGTransformerFactory::TransformOnConflictClause(PEGTransformer &transformer, - OnConflictExpressionTarget on_conflict_target, + optional on_conflict_target, unique_ptr on_conflict_action) { - on_conflict_action->indexed_columns = on_conflict_target.indexed_columns; - if (on_conflict_target.where_clause) { - on_conflict_action->condition = std::move(on_conflict_target.where_clause); + if (on_conflict_target) { + on_conflict_action->indexed_columns = on_conflict_target->indexed_columns; + if (on_conflict_target->where_clause) { + on_conflict_action->condition = std::move(on_conflict_target->where_clause); + } } return on_conflict_action; } -OnConflictExpressionTarget PEGTransformerFactory::TransformOnConflictExpressionTarget( - PEGTransformer &transformer, const vector &column_id_list, unique_ptr where_clause) { +OnConflictExpressionTarget +PEGTransformerFactory::TransformOnConflictExpressionTarget(PEGTransformer &transformer, + const vector &column_id_list, + optional> where_clause) { OnConflictExpressionTarget result; result.indexed_columns = StringsToIdentifiers(column_id_list); - result.where_clause = std::move(where_clause); + if (where_clause) { + result.where_clause = std::move(*where_clause); + } return result; } @@ -90,13 +106,16 @@ OnConflictExpressionTarget PEGTransformerFactory::TransformOnConflictIndexTarget throw NotImplementedException("ON CONSTRAINT conflict target is not supported yet"); } -unique_ptr PEGTransformerFactory::TransformOnConflictUpdate(PEGTransformer &transformer, - unique_ptr update_set_clause, - unique_ptr where_clause) { +unique_ptr +PEGTransformerFactory::TransformOnConflictUpdate(PEGTransformer &transformer, + unique_ptr update_set_clause, + optional> where_clause) { auto result = make_uniq(); result->action_type = OnConflictAction::UPDATE; result->set_info = std::move(update_set_clause); - result->set_info->condition = std::move(where_clause); + if (where_clause) { + result->set_info->condition = std::move(*where_clause); + } return result; } diff --git a/src/duckdb/src/parser/peg/transformer/transform_load.cpp b/src/duckdb/src/parser/peg/transformer/transform_load.cpp index fc3c6f068..64e398b34 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_load.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_load.cpp @@ -8,13 +8,13 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformLoadStatement(PEGTransformer &transformer, const Identifier &col_id_or_string, - const Identifier &extension_alias) { + const optional &extension_alias) { auto result = make_uniq(); auto info = make_uniq(); info->repo_is_alias = false; info->filename = col_id_or_string.GetIdentifierName(); - if (!extension_alias.empty()) { - info->alias = extension_alias; + if (extension_alias) { + info->alias = *extension_alias; info->load_type = LoadType::LOAD_AS; } else { info->load_type = LoadType::LOAD; @@ -28,19 +28,19 @@ Identifier PEGTransformerFactory::TransformExtensionAlias(PEGTransformer &transf } unique_ptr PEGTransformerFactory::TransformInstallStatement( - PEGTransformer &transformer, const QualifiedName &identifier_or_string_literal, - const ExtensionRepositoryInfo &from_source, const string &version_number) { + PEGTransformer &transformer, const bool &has_result, const QualifiedName &identifier_or_string_literal, + const optional &from_source, const optional &version_number) { auto result = make_uniq(); auto info = make_uniq(); info->load_type = LoadType::INSTALL; info->filename = identifier_or_string_literal.name.GetIdentifierName(); info->repo_is_alias = false; - if (!from_source.name.empty()) { - info->repository = from_source.name.GetIdentifierName(); - info->repo_is_alias = from_source.repository_is_alias; + if (from_source) { + info->repository = from_source->name.GetIdentifierName(); + info->repo_is_alias = from_source->repository_is_alias; } - if (!version_number.empty()) { - info->version = version_number; + if (version_number) { + info->version = *version_number; } result->info = std::move(info); return std::move(result); @@ -64,10 +64,12 @@ ExtensionRepositoryInfo PEGTransformerFactory::TransformFromSourceString(PEGTran unique_ptr PEGTransformerFactory::TransformUpdateExtensionsStatement(PEGTransformer &transformer, - const vector &identifier) { + const optional> &identifier) { auto result = make_uniq(); auto info = make_uniq(); - info->extensions_to_update = identifier; + if (identifier) { + info->extensions_to_update = *identifier; + } result->info = std::move(info); return std::move(result); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_merge_into.cpp b/src/duckdb/src/parser/peg/transformer/transform_merge_into.cpp index 6d416dbaa..b40a00cf6 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_merge_into.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_merge_into.cpp @@ -6,13 +6,15 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformMergeIntoStatement( - PEGTransformer &transformer, CommonTableExpressionMap with_clause, unique_ptr target_opt_alias, - unique_ptr merge_into_using_clause, JoinQualifier join_qualifier, - vector>> merge_match, - vector> returning_clause) { + PEGTransformer &transformer, optional with_clause, + unique_ptr target_opt_alias, unique_ptr merge_into_using_clause, + JoinQualifier join_qualifier, vector>> merge_match, + optional>> returning_clause) { auto result = make_uniq(); auto &node = *result->node; - node.cte_map = std::move(with_clause); + if (with_clause) { + node.cte_map = std::move(*with_clause); + } node.target = std::move(target_opt_alias); node.source = std::move(merge_into_using_clause); if (join_qualifier.on_clause) { @@ -38,7 +40,9 @@ unique_ptr PEGTransformerFactory::TransformMergeIntoStatement( } node.actions[action_condition].push_back(std::move(action)); } - node.returning_list = std::move(returning_clause); + if (returning_clause) { + node.returning_list = std::move(*returning_clause); + } return std::move(result); } @@ -48,17 +52,20 @@ unique_ptr PEGTransformerFactory::TransformMergeIntoUsingClause(PEGTra } pair> -PEGTransformerFactory::TransformMatchedClause(PEGTransformer &transformer, unique_ptr and_expression, +PEGTransformerFactory::TransformMatchedClause(PEGTransformer &transformer, + optional> and_expression, unique_ptr matched_clause_action) { - matched_clause_action->condition = std::move(and_expression); + if (and_expression) { + matched_clause_action->condition = std::move(*and_expression); + } return pair>(MergeActionCondition::WHEN_MATCHED, std::move(matched_clause_action)); } unique_ptr PEGTransformerFactory::TransformUpdateMatchClause(PEGTransformer &transformer, - unique_ptr update_match_info) { - auto result = std::move(update_match_info); + optional> update_match_info) { + auto result = update_match_info ? std::move(*update_match_info) : nullptr; if (!result) { result = make_uniq(); } @@ -90,8 +97,8 @@ unique_ptr PEGTransformerFactory::TransformDeleteMatchClause(PE unique_ptr PEGTransformerFactory::TransformInsertMatchClause(PEGTransformer &transformer, - unique_ptr insert_match_info) { - auto result = std::move(insert_match_info); + optional> insert_match_info) { + auto result = insert_match_info ? std::move(*insert_match_info) : nullptr; if (!result) { result = make_uniq(); } @@ -105,19 +112,23 @@ unique_ptr PEGTransformerFactory::TransformInsertDefaultValues( return result; } -unique_ptr -PEGTransformerFactory::TransformInsertByNameOrPosition(PEGTransformer &transformer, - const InsertColumnOrder &by_name_or_position) { +unique_ptr PEGTransformerFactory::TransformInsertByNameOrPosition( + PEGTransformer &transformer, const optional &by_name_or_position, const bool &has_result) { auto result = make_uniq(); - result->column_order = by_name_or_position; + if (by_name_or_position) { + result->column_order = *by_name_or_position; + } return result; } unique_ptr -PEGTransformerFactory::TransformInsertValuesList(PEGTransformer &transformer, const vector &insert_column_list, +PEGTransformerFactory::TransformInsertValuesList(PEGTransformer &transformer, + const optional> &insert_column_list, vector> expression) { auto result = make_uniq(); - result->insert_columns = StringsToIdentifiers(insert_column_list); + if (insert_column_list) { + result->insert_columns = StringsToIdentifiers(*insert_column_list); + } result->expressions = std::move(expression); return result; } @@ -128,12 +139,13 @@ unique_ptr PEGTransformerFactory::TransformDoNothingMatchClause return result; } -unique_ptr PEGTransformerFactory::TransformErrorMatchClause(PEGTransformer &transformer, - unique_ptr expression) { +unique_ptr +PEGTransformerFactory::TransformErrorMatchClause(PEGTransformer &transformer, + optional> expression) { auto result = make_uniq(); result->action_type = MergeActionType::MERGE_ERROR; if (expression) { - result->expressions.push_back(std::move(expression)); + result->expressions.push_back(std::move(*expression)); } return result; } @@ -144,10 +156,12 @@ unique_ptr PEGTransformerFactory::TransformAndExpression(PEGTr } pair> PEGTransformerFactory::TransformNotMatchedClause( - PEGTransformer &transformer, const MergeActionCondition &by_source_or_target, - unique_ptr and_expression, unique_ptr matched_clause_action) { - matched_clause_action->condition = std::move(and_expression); - auto action_condition = by_source_or_target; + PEGTransformer &transformer, const optional &by_source_or_target, + optional> and_expression, unique_ptr matched_clause_action) { + if (and_expression) { + matched_clause_action->condition = std::move(*and_expression); + } + auto action_condition = by_source_or_target.value_or(MergeActionCondition::WHEN_MATCHED); if (action_condition == MergeActionCondition::WHEN_MATCHED) { action_condition = MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET; } diff --git a/src/duckdb/src/parser/peg/transformer/transform_pragma.cpp b/src/duckdb/src/parser/peg/transformer/transform_pragma.cpp index 8381a41d5..2eec81384 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_pragma.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_pragma.cpp @@ -44,14 +44,14 @@ PEGTransformerFactory::TransformPragmaAssign(PEGTransformer &transformer, const unique_ptr PEGTransformerFactory::TransformPragmaFunction(PEGTransformer &transformer, const Identifier &pragma_name, - vector> pragma_parameters) { + optional>> pragma_parameters) { // Rule: PragmaFunction <- PragmaName PragmaParameters? auto result = make_uniq(); result->info->name = pragma_name; - if (pragma_parameters.empty()) { + if (!pragma_parameters) { return std::move(result); } - for (auto ¶meter : pragma_parameters) { + for (auto ¶meter : *pragma_parameters) { if (parameter->GetExpressionType() == ExpressionType::COMPARE_EQUAL) { auto &comp = parameter->Cast(); if (comp.Left().GetExpressionType() != ExpressionType::COLUMN_REF) { diff --git a/src/duckdb/src/parser/peg/transformer/transform_prepare.cpp b/src/duckdb/src/parser/peg/transformer/transform_prepare.cpp index 132e591e0..e310ee24f 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_prepare.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_prepare.cpp @@ -16,10 +16,10 @@ bool IsPrepareableStatement(StatementType type) { } } -unique_ptr PEGTransformerFactory::TransformPrepareStatement(PEGTransformer &transformer, - const Identifier &identifier, - const vector &type_list, - unique_ptr statement) { +unique_ptr +PEGTransformerFactory::TransformPrepareStatement(PEGTransformer &transformer, const Identifier &identifier, + const optional> &type_list, + unique_ptr statement) { auto result = make_uniq(); result->name = identifier; if (!IsPrepareableStatement(statement->type)) { diff --git a/src/duckdb/src/parser/peg/transformer/transform_set.cpp b/src/duckdb/src/parser/peg/transformer/transform_set.cpp index 6890ef587..4026bcb5e 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_set.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_set.cpp @@ -21,11 +21,14 @@ PEGTransformerFactory::TransformSetAssignment(PEGTransformer &transformer, } // SetSetting <- SettingScope? SettingName -SettingInfo PEGTransformerFactory::TransformSetSetting(PEGTransformer &transformer, const SetScope &setting_scope, +SettingInfo PEGTransformerFactory::TransformSetSetting(PEGTransformer &transformer, + const optional &setting_scope, const Identifier &setting_name) { SettingInfo result; result.name = setting_name; - result.scope = setting_scope; + if (setting_scope) { + result.scope = *setting_scope; + } return result; } @@ -129,7 +132,7 @@ SetScope PEGTransformerFactory::TransformGlobalScope(PEGTransformer &transformer // ZoneIntervalWithInterval <- 'INTERVAL' StringLiteral Interval? unique_ptr PEGTransformerFactory::TransformZoneIntervalWithInterval(PEGTransformer &transformer, const string &string_literal, - const DatePartSpecifier &interval) { + const optional &interval) { auto expr = make_uniq(Value(string_literal)); return make_uniq(LogicalType::INTERVAL, std::move(expr)); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_transaction.cpp b/src/duckdb/src/parser/peg/transformer/transform_transaction.cpp index f852077a9..487cf2614 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_transaction.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_transaction.cpp @@ -4,10 +4,12 @@ namespace duckdb { unique_ptr -PEGTransformerFactory::TransformBeginTransaction(PEGTransformer &transformer, - const TransactionModifierType &read_or_write) { +PEGTransformerFactory::TransformBeginTransaction(PEGTransformer &transformer, const bool &has_result, + const optional &read_or_write) { auto info = make_uniq(TransactionType::BEGIN_TRANSACTION); - info->modifier = read_or_write; + if (read_or_write) { + info->modifier = *read_or_write; + } return make_uniq(std::move(info)); } @@ -25,11 +27,13 @@ TransactionModifierType PEGTransformerFactory::TransformReadWrite(PEGTransformer return TransactionModifierType::TRANSACTION_READ_WRITE; } -unique_ptr PEGTransformerFactory::TransformCommitTransaction(PEGTransformer &transformer) { +unique_ptr PEGTransformerFactory::TransformCommitTransaction(PEGTransformer &transformer, + const bool &has_result) { return make_uniq(make_uniq(TransactionType::COMMIT)); } -unique_ptr PEGTransformerFactory::TransformRollbackTransaction(PEGTransformer &transformer) { +unique_ptr PEGTransformerFactory::TransformRollbackTransaction(PEGTransformer &transformer, + const bool &has_result) { return make_uniq(make_uniq(TransactionType::ROLLBACK)); } } // namespace duckdb diff --git a/src/duckdb/src/parser/peg/transformer/transform_update.cpp b/src/duckdb/src/parser/peg/transformer/transform_update.cpp index e3e5dd239..d8c447127 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_update.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_update.cpp @@ -6,17 +6,26 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformUpdateStatement( - PEGTransformer &transformer, CommonTableExpressionMap with_clause, unique_ptr update_target, - unique_ptr update_set_clause, unique_ptr from_clause, - unique_ptr where_clause, vector> returning_clause) { + PEGTransformer &transformer, optional with_clause, unique_ptr update_target, + unique_ptr update_set_clause, optional> from_clause, + optional> where_clause, + optional>> returning_clause) { auto result = make_uniq(); auto &node = *result->node; - node.cte_map = std::move(with_clause); + if (with_clause) { + node.cte_map = std::move(*with_clause); + } node.table = std::move(update_target); node.set_info = std::move(update_set_clause); - node.from_table = std::move(from_clause); - node.set_info->condition = std::move(where_clause); - node.returning_list = std::move(returning_clause); + if (from_clause) { + node.from_table = std::move(*from_clause); + } + if (where_clause) { + node.set_info->condition = std::move(*where_clause); + } + if (returning_clause) { + node.returning_list = std::move(*returning_clause); + } return std::move(result); } @@ -27,12 +36,15 @@ unique_ptr PEGTransformerFactory::TransformBaseTableSet(PEGTransformer unique_ptr PEGTransformerFactory::TransformBaseTableAliasSet(PEGTransformer &transformer, unique_ptr base_table_name, - const Identifier &update_alias) { - base_table_name->alias = update_alias; + const optional &update_alias) { + if (update_alias) { + base_table_name->alias = *update_alias; + } return std::move(base_table_name); } -Identifier PEGTransformerFactory::TransformUpdateAlias(PEGTransformer &transformer, const Identifier &col_id) { +Identifier PEGTransformerFactory::TransformUpdateAlias(PEGTransformer &transformer, const bool &has_result, + const Identifier &col_id) { return Identifier(col_id); } @@ -86,8 +98,8 @@ PEGTransformerFactory::TransformUpdateSetElement(PEGTransformer &transformer, co } string PEGTransformerFactory::TransformUpdateSetColumnTarget(PEGTransformer &transformer, const Identifier &column_name, - const vector &dot_identifier) { - if (!dot_identifier.empty()) { + const optional> &dot_identifier) { + if (dot_identifier) { throw ParserException("Qualified column names in UPDATE .. SET not supported"); } return column_name.GetIdentifierName(); diff --git a/src/duckdb/src/parser/peg/transformer/transform_use.cpp b/src/duckdb/src/parser/peg/transformer/transform_use.cpp index 061f815ab..50e23914a 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_use.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_use.cpp @@ -33,11 +33,11 @@ QualifiedName PEGTransformerFactory::TransformCatalogNameAsUseTarget(PEGTransfor } // UseTargetCatalogSchema <- CatalogName '.' ReservedSchemaName DotIdentifier* -QualifiedName PEGTransformerFactory::TransformUseTargetCatalogSchema(PEGTransformer &transformer, - const Identifier &catalog_name, - const Identifier &reserved_schema_name, - const vector &dot_identifier) { - if (!dot_identifier.empty()) { +QualifiedName +PEGTransformerFactory::TransformUseTargetCatalogSchema(PEGTransformer &transformer, const Identifier &catalog_name, + const Identifier &reserved_schema_name, + const optional> &dot_identifier) { + if (dot_identifier && !dot_identifier->empty()) { throw ParserException("Expected \"USE database\" or \"USE database.schema\""); } QualifiedName result; diff --git a/src/duckdb/src/parser/peg/transformer/transform_vacuum.cpp b/src/duckdb/src/parser/peg/transformer/transform_vacuum.cpp index 7d0139980..624e76f35 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_vacuum.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_vacuum.cpp @@ -4,30 +4,36 @@ namespace duckdb { unique_ptr PEGTransformerFactory::TransformVacuumStatement(PEGTransformer &transformer, - const VacuumOptions &vacuum_options, - AnalyzeTarget analyze_target) { - auto result = make_uniq(vacuum_options); - if (analyze_target.ref) { - result->info->columns = analyze_target.columns; - result->info->ref = std::move(analyze_target.ref); + const optional &vacuum_options, + optional analyze_target) { + VacuumOptions options; + if (vacuum_options) { + options = *vacuum_options; + } + auto result = make_uniq(options); + if (analyze_target && analyze_target->ref) { + result->info->columns = analyze_target->columns; + result->info->ref = std::move(analyze_target->ref); result->info->has_table = true; } return std::move(result); } -VacuumOptions PEGTransformerFactory::TransformVacuumLegacyOptions(PEGTransformer &transformer, const string &opt_full, - const string &opt_freeze, const string &opt_verbose, - const string &opt_analyze) { +VacuumOptions PEGTransformerFactory::TransformVacuumLegacyOptions(PEGTransformer &transformer, + const optional &opt_full, + const optional &opt_freeze, + const optional &opt_verbose, + const optional &opt_analyze) { VacuumOptions options; options.vacuum = true; - options.analyze = !opt_analyze.empty(); - if (!opt_full.empty()) { + options.analyze = opt_analyze.has_value(); + if (opt_full) { throw NotImplementedException("FULL is not yet implemented"); } - if (!opt_freeze.empty()) { + if (opt_freeze) { throw NotImplementedException("FREEZE is not yet implemented"); } - if (!opt_verbose.empty()) { + if (opt_verbose) { throw NotImplementedException("VERBOSE is not yet implemented"); } return options; diff --git a/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp index c9b597967..f4745e4ea 100644 --- a/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_function_expression.cpp @@ -378,14 +378,15 @@ BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, Sc auto &args = function.GetArgumentsMutable(); - // the first child is the list, the second child is the lambda expression - // constexpr idx_t list_ix = 0; - constexpr idx_t list_idx = 0; - constexpr idx_t lambda_expr_idx = 1; - D_ASSERT(args[lambda_expr_idx].GetExpression().GetExpressionClass() == ExpressionClass::LAMBDA); + // list lambda functions use the existing (list, lambda) shape; invoke is the only lambda + // function that accepts the lambda expression as the first argument. + const idx_t lambda_expr_idx = func.name == "invoke" ? 0 : 1; + if (args.size() <= lambda_expr_idx || + args[lambda_expr_idx].GetExpression().GetExpressionClass() != ExpressionClass::LAMBDA) { + return BindResult("This scalar function requires a lambda expression!"); + } vector function_child_types; - // bind the list ErrorData error; for (idx_t i = 0; i < function.GetArguments().size(); i++) { @@ -408,13 +409,15 @@ BindResult ExpressionBinder::BindLambdaFunction(FunctionExpression &function, Sc function_child_types.push_back(child->GetReturnType()); } - // get the logical type of the children of the list - auto &list_child = BoundExpression::GetExpression(*args[list_idx].GetExpressionMutable()); - if (list_child->GetReturnType().id() != LogicalTypeId::LIST && - list_child->GetReturnType().id() != LogicalTypeId::ARRAY && - list_child->GetReturnType().id() != LogicalTypeId::SQLNULL && - list_child->GetReturnType().id() != LogicalTypeId::UNKNOWN) { - return BindResult("Invalid LIST argument during lambda function binding!"); + if (lambda_expr_idx == 1) { + // get the logical type of the children of the list + auto &list_child = BoundExpression::GetExpression(*args[0].GetExpressionMutable()); + if (list_child->GetReturnType().id() != LogicalTypeId::LIST && + list_child->GetReturnType().id() != LogicalTypeId::ARRAY && + list_child->GetReturnType().id() != LogicalTypeId::SQLNULL && + list_child->GetReturnType().id() != LogicalTypeId::UNKNOWN) { + return BindResult("Invalid LIST argument during lambda function binding!"); + } } // bind the lambda parameter diff --git a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp index f9a1fdef2..7b1764412 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp @@ -638,7 +638,9 @@ BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from auto &expanded = expr->Cast(); auto &struct_expressions = expanded.GetChildrenMutable(); - D_ASSERT(!struct_expressions.empty()); + if (struct_expressions.empty()) { + throw BinderException("UNNEST of an empty struct is not supported"); + } for (auto &struct_expr : struct_expressions) { new_names.emplace_back(struct_expr->GetName()); diff --git a/src/duckdb/src/planner/binder/query_node/bind_trigger_expansion.cpp b/src/duckdb/src/planner/binder/query_node/bind_trigger_expansion.cpp index 14288adb1..44cc76f29 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_trigger_expansion.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_trigger_expansion.cpp @@ -21,8 +21,15 @@ #include "duckdb/common/identifier.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" +#include "duckdb/planner/expression_binder.hpp" #include "duckdb/planner/expression_binder/returning_binder.hpp" #include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_cteref.hpp" +#include "duckdb/planner/operator/logical_delete.hpp" +#include "duckdb/planner/operator/logical_insert.hpp" +#include "duckdb/planner/operator/logical_materialized_cte.hpp" +#include "duckdb/planner/operator/logical_trigger.hpp" +#include "duckdb/planner/operator/logical_update.hpp" #include #include @@ -46,8 +53,9 @@ unique_ptr Binder::TryExpandTriggers(QueryNode &node, TableCatal return nullptr; } auto txn = table.ParentCatalog().GetCatalogTransaction(context); - auto before_triggers = table.GetTriggersForEvent(txn, TriggerTiming::BEFORE, event_type); - auto after_triggers = table.GetTriggersForEvent(txn, TriggerTiming::AFTER, event_type); + // FOR EACH ROW triggers are expanded separately via TryExpandRowTriggers; this path only fires statement triggers. + auto before_triggers = table.GetTriggersForEvent(txn, TriggerTiming::BEFORE, event_type, TriggerForEach::STATEMENT); + auto after_triggers = table.GetTriggersForEvent(txn, TriggerTiming::AFTER, event_type, TriggerForEach::STATEMENT); // UPDATE OF : drop triggers whose OF list is disjoint from the SET list. // Triggers without an OF list are unrestricted and always fire. @@ -392,4 +400,179 @@ BoundStatement Binder::ExpandTriggers(QueryNode &node, TableCatalogEntry &table, return result; } +static bool BoundBodyTargetsTable(const LogicalOperator &op, const TableCatalogEntry &table) { + if (op.type == LogicalOperatorType::LOGICAL_INSERT) { + return &op.Cast().table == &table; + } + if (op.type == LogicalOperatorType::LOGICAL_DELETE) { + return &op.Cast().table == &table; + } + if (op.type == LogicalOperatorType::LOGICAL_UPDATE) { + return &op.Cast().table == &table; + } + for (auto &child : op.children) { + if (child && BoundBodyTargetsTable(*child, table)) { + return true; + } + } + return false; +} + +// Detects a nested LogicalTrigger in a bound trigger body, which means the body targets a table that itself has a +// FOR EACH ROW trigger (a cascade). Cascades are rejected for now for two reasons: +// 1. Over-firing: the inner trigger's affected-row set gets cross-joined with the outer rows during decorrelation, +// so it fires once per (outer row x inner row) instead of once per inner row. +// 2. all firings run as one set-based batch against a single snapshot, so a downstream trigger cannot +// see rows an upstream trigger wrote earlier in the same statement. Correct semantics need per-row execution. +bool BoundBodyContainsTrigger(const LogicalOperator &op) { + if (op.type == LogicalOperatorType::LOGICAL_TRIGGER) { + return true; + } + for (auto &child : op.children) { + if (child && BoundBodyContainsTrigger(*child)) { + return true; + } + } + return false; +} + +unique_ptr Binder::TryExpandRowTriggers(QueryNode &node, + vector> &returning_list, + TableCatalogEntry &table, TriggerEventType event_type) { + auto &expanded_tables = global_binder_state->trigger_expanded_tables; + if (expanded_tables.find(table) != expanded_tables.end()) { + return nullptr; + } + auto triggers = table.GetTriggersForEvent(table.ParentCatalog().GetCatalogTransaction(context), + TriggerTiming::AFTER, event_type, TriggerForEach::ROW); + if (triggers.empty()) { + return nullptr; + } + if (node.type == QueryNodeType::INSERT_QUERY_NODE && node.Cast().on_conflict_info) { + // Updated rows via ON CONFLICT currently appear in the INSERT affected-row set. This would fire INSERT triggers + // for those rows. Therefore, the combination is currently rejected. + throw NotImplementedException("ON CONFLICT is not yet supported on tables with FOR EACH ROW triggers"); + } + if (!returning_list.empty()) { + throw NotImplementedException("RETURNING is not yet supported on tables with FOR EACH ROW triggers"); + } + expanded_tables.insert(table); + auto bound = ExpandRowTriggers(node, returning_list, table, triggers); + expanded_tables.erase(table); + return make_uniq(std::move(bound)); +} + +unique_ptr Binder::SetupNewRowScope(TableIndex table_index, const vector &col_names, + const vector &col_types) { + bind_context.AddGenericBinding(table_index, "new", col_names, col_types); + auto scope_binder = make_uniq(*this, context); + GetActiveBinders().push_back(*scope_binder); + return scope_binder; +} + +BoundStatement Binder::ExpandRowTriggers(QueryNode &node, vector> &returning_list, + const TableCatalogEntry &table, + const vector> &triggers) { + D_ASSERT(!triggers.empty()); + D_ASSERT(returning_list.empty()); + returning_list.push_back(make_uniq()); + + auto uuid_suffix = UUID::ToString(UUID::GenerateRandomUUID()); + Identifier base_cte_name(TRIGGER_BASE_CTE_PREFIX + uuid_suffix); + + auto base_cte = make_uniq(); + base_cte->query_node = node.Copy(); + base_cte->materialized = CTEMaterialize::CTE_MATERIALIZE_ALWAYS; + base_cte->is_trigger_generated = true; + + auto outer = make_uniq(); + outer->select_list.push_back(make_uniq("count_star", vector>())); + auto from_ref = make_uniq(); + from_ref->table_name = base_cte_name; + outer->from_table = std::move(from_ref); + outer->cte_map.map[base_cte_name] = std::move(base_cte); + + auto bound = Bind(*outer); + auto &base_mat_cte = bound.plan->Cast(); + auto cte_table_idx = base_mat_cte.table_index; + + // proj_idx is the binding source for NEW.col refs in trigger bodies. + vector col_names; + vector col_types; + for (auto &col : table.GetColumns().Physical()) { + col_names.push_back(col.GetName()); + col_types.push_back(col.GetType()); + } + + auto cte_ref_idx = GenerateTableIndex(); + auto cte_ref = make_uniq(cte_ref_idx, cte_table_idx, col_types, col_names, false); + cte_ref->ResolveOperatorTypes(); + + auto proj_idx = GenerateTableIndex(); // the table_index for NEW bindings + vector> proj_exprs; + for (idx_t i = 0; i < col_types.size(); i++) { + proj_exprs.push_back(make_uniq(col_names[i], col_types[i], + ColumnBinding(cte_ref_idx, ProjectionIndex(i)))); + } + auto new_rows_proj = make_uniq(proj_idx, std::move(proj_exprs)); + new_rows_proj->children.push_back(std::move(cte_ref)); + new_rows_proj->ResolveOperatorTypes(); + + auto new_scope_binder = SetupNewRowScope(proj_idx, col_names, col_types); + + unique_ptr trigger_plan = std::move(new_rows_proj); + for (idx_t i = 0; i < triggers.size(); i++) { + auto &trigger = triggers[i].get(); + + auto child_binder = Binder::CreateBinder(context, this); + auto body_copy = trigger.trigger_action->Copy(); + auto bound_body = child_binder->Bind(*body_copy); + + CorrelatedColumns corr_cols = std::move(child_binder->correlated_columns); + if (corr_cols.empty()) { + throw BinderException("FOR EACH ROW trigger \"%s\" on table \"%s\" must reference at least one NEW " + "column in the trigger body (use FOR EACH STATEMENT if row data is not needed)", + trigger.name, table.name); + } + + if (BoundBodyTargetsTable(*bound_body.plan, table)) { + throw BinderException("FOR EACH ROW trigger \"%s\" on table \"%s\" writes to the trigger table " + "(self-referential triggers are not supported)", + trigger.name, table.name); + } + + if (BoundBodyContainsTrigger(*bound_body.plan)) { + throw NotImplementedException("FOR EACH ROW trigger \"%s\" on table \"%s\" writes to a table that has its " + "own FOR EACH ROW trigger (cascading row triggers are not yet supported)", + trigger.name, table.name); + } + + auto logi_trig = make_uniq(trigger.name.GetIdentifierName(), trigger.timing, trigger.event_type, + std::move(corr_cols)); + logi_trig->children.push_back(std::move(trigger_plan)); + logi_trig->children.push_back(std::move(bound_body.plan)); + logi_trig->ResolveOperatorTypes(); + trigger_plan = std::move(logi_trig); + } + // remove new_scope_binder + GetActiveBinders().pop_back(); + + Identifier trigger_cte_name(string(TRIGGER_BODY_CTE_PREFIX) + "row_" + uuid_suffix); + auto trigger_cte_idx = GenerateTableIndex(); + auto outer_query = std::move(bound.plan->children[1]); + + auto trigger_mat_cte = + make_uniq(trigger_cte_name, trigger_cte_idx, col_types.size(), std::move(trigger_plan), + std::move(outer_query), CTEMaterialize::CTE_MATERIALIZE_DEFAULT); + trigger_mat_cte->ResolveOperatorTypes(); + + bound.plan->children[1] = std::move(trigger_mat_cte); + bound.plan->ResolveOperatorTypes(); + + auto &properties = GetStatementProperties(); + properties.return_type = StatementReturnType::CHANGED_ROWS; + properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + return bound; +} + } // namespace duckdb diff --git a/src/duckdb/src/planner/binder/statement/bind_create.cpp b/src/duckdb/src/planner/binder/statement/bind_create.cpp index cdb6f4824..84416f015 100644 --- a/src/duckdb/src/planner/binder/statement/bind_create.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_create.cpp @@ -476,6 +476,8 @@ void Binder::BindLogicalType(LogicalType &type) { }); } +bool BoundBodyContainsTrigger(const LogicalOperator &op); + SchemaCatalogEntry &Binder::BindCreateTriggerInfo(CreateTriggerInfo &create_trigger_info) { // Resolve the base table first — triggers inherit catalog/schema from their table (like Postgres) TableDescription table_description(create_trigger_info.base_table->catalog_name, @@ -521,8 +523,16 @@ SchemaCatalogEntry &Binder::BindCreateTriggerInfo(CreateTriggerInfo &create_trig } } } - if (create_trigger_info.for_each == TriggerForEach::ROW) { - throw NotImplementedException("FOR EACH ROW triggers are not yet supported"); + if (create_trigger_info.for_each == TriggerForEach::ROW && + (!create_trigger_info.referencing_new_table.empty() || !create_trigger_info.referencing_old_table.empty())) { + throw BinderException("REFERENCING is not valid for FOR EACH ROW triggers"); + } + if (create_trigger_info.for_each == TriggerForEach::ROW && create_trigger_info.timing != TriggerTiming::AFTER) { + throw NotImplementedException("BEFORE FOR EACH ROW triggers are not yet supported"); + } + if (create_trigger_info.for_each == TriggerForEach::ROW && + create_trigger_info.event_type != TriggerEventType::INSERT_EVENT) { + throw NotImplementedException("UPDATE and DELETE FOR EACH ROW triggers are not yet supported"); } if ((!create_trigger_info.referencing_new_table.empty() || !create_trigger_info.referencing_old_table.empty()) && create_trigger_info.timing != TriggerTiming::AFTER) { @@ -545,6 +555,23 @@ SchemaCatalogEntry &Binder::BindCreateTriggerInfo(CreateTriggerInfo &create_trig throw BinderException("REFERENCING NEW TABLE AS is not valid for AFTER DELETE triggers"); } + auto opposite_for_each = + create_trigger_info.for_each == TriggerForEach::ROW ? TriggerForEach::STATEMENT : TriggerForEach::ROW; + // Statement and row triggers use separate expansion paths that don't compose, so reject mixing them per event. + // CREATE OR REPLACE that targets the same-named trigger is allowed: that trigger is atomically replaced, + // so the final catalog contains only the new one and there is no mixing. + auto conflicting = table.GetTriggersForEvent(table.ParentCatalog().GetCatalogTransaction(context), + create_trigger_info.event_type, opposite_for_each); + bool is_replace = create_trigger_info.on_conflict == OnCreateConflict::REPLACE_ON_CONFLICT; + auto has_real_conflict = + std::any_of(conflicting.begin(), conflicting.end(), [&](const_reference t) { + return !(is_replace && t.get().name == create_trigger_info.trigger_name); + }); + if (has_real_conflict) { + throw NotImplementedException( + "Mixing FOR EACH STATEMENT and FOR EACH ROW triggers on the same table is not yet supported"); + } + // Validate the trigger body using an isolated binder (own GlobalBinderState). // Set up trigger_expanded_tables to match runtime behavior. // Set up trigger_creation_table to detect recursive triggers during the validation. @@ -559,7 +586,44 @@ SchemaCatalogEntry &Binder::BindCreateTriggerInfo(CreateTriggerInfo &create_trig body_copy->cte_map.map[alias] = MakeTriggerValidationCTE(table); } } - validation_binder->Bind(*body_copy); + // For FOR EACH ROW: register NEW as a generic binding so BindCorrelatedColumns can resolve NEW.col column + // references. + unique_ptr row_scope_binder; + if (create_trigger_info.for_each == TriggerForEach::ROW) { + if (table.HasGeneratedColumns()) { + throw NotImplementedException( + "FOR EACH ROW triggers on tables with generated columns are not yet supported"); + } + if (create_trigger_info.trigger_action->type == QueryNodeType::UPDATE_QUERY_NODE) { + throw NotImplementedException("UPDATE trigger bodies in FOR EACH ROW triggers are not yet supported"); + } + vector col_names; + vector col_types; + for (auto &col : table.GetColumns().Physical()) { + col_names.push_back(col.GetName()); + col_types.push_back(col.GetType()); + } + auto new_idx = validation_binder->GenerateTableIndex(); + row_scope_binder = validation_binder->SetupNewRowScope(new_idx, col_names, col_types); + } + if (row_scope_binder) { + auto body_binder = Binder::CreateBinder(context, validation_binder.get()); + auto bound_body = body_binder->Bind(*body_copy); + validation_binder->GetActiveBinders().pop_back(); + if (body_binder->correlated_columns.empty()) { + throw BinderException("FOR EACH ROW trigger \"%s\" on table \"%s\" must reference at least one NEW " + "column in the trigger body (use FOR EACH STATEMENT if row data is not needed)", + create_trigger_info.trigger_name, table.name); + } + if (BoundBodyContainsTrigger(*bound_body.plan)) { + throw NotImplementedException( + "FOR EACH ROW trigger \"%s\" on table \"%s\" writes to a table that has its own FOR EACH ROW " + "trigger (cascading row triggers are not yet supported)", + create_trigger_info.trigger_name, table.name); + } + } else { + validation_binder->Bind(*body_copy); + } // Add table dependency create_trigger_info.dependencies.AddDependency(table); diff --git a/src/duckdb/src/planner/binder/statement/bind_insert.cpp b/src/duckdb/src/planner/binder/statement/bind_insert.cpp index 08ed07a24..227b398db 100644 --- a/src/duckdb/src/planner/binder/statement/bind_insert.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_insert.cpp @@ -604,6 +604,9 @@ BoundStatement Binder::BindNode(InsertQueryNode &node) { if (auto expanded = TryExpandTriggers(node, table, TriggerEventType::INSERT_EVENT)) { return std::move(*expanded); } + if (auto expanded = TryExpandRowTriggers(node, node.returning_list, table, TriggerEventType::INSERT_EVENT)) { + return std::move(*expanded); + } if (node.on_conflict_info) { // generate a MERGE INTO statement and bind it instead diff --git a/src/duckdb/src/planner/expression/bound_function_expression.cpp b/src/duckdb/src/planner/expression/bound_function_expression.cpp index a33689666..f8b080b15 100644 --- a/src/duckdb/src/planner/expression/bound_function_expression.cpp +++ b/src/duckdb/src/planner/expression/bound_function_expression.cpp @@ -39,12 +39,10 @@ bool BoundFunctionExpression::IsFoldable() const { if (function.HasBindLambdaCallback()) { // This is a lambda function D_ASSERT(bind_info); - auto &lambda_bind_data = bind_info->Cast(); - if (lambda_bind_data.lambda_expr) { - auto &expr = *lambda_bind_data.lambda_expr; - if (expr.IsVolatile()) { - return false; - } + auto &lambda_bind_data = bind_info->Cast(); + auto lambda_expr = lambda_bind_data.GetLambdaExpression(); + if (lambda_expr && lambda_expr->IsVolatile()) { + return false; } } return function.GetStability() == FunctionStability::VOLATILE ? false : Expression::IsFoldable(); diff --git a/src/duckdb/src/planner/expression_binder/table_function_binder.cpp b/src/duckdb/src/planner/expression_binder/table_function_binder.cpp index b9941868c..be88ceb52 100644 --- a/src/duckdb/src/planner/expression_binder/table_function_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/table_function_binder.cpp @@ -1,4 +1,8 @@ #include "duckdb/planner/expression_binder/table_function_binder.hpp" +#include "duckdb/common/enums/table_function_identifier_conversion.hpp" +#include "duckdb/common/sql_identifier.hpp" +#include "duckdb/logging/logger.hpp" +#include "duckdb/main/settings.hpp" #include "duckdb/parser/expression/columnref_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/table_binding.hpp" @@ -72,6 +76,21 @@ BindResult TableFunctionBinder::BindColumnReference(unique_ptr result_name); } + auto setting = Settings::Get(context); + auto implicit_conversion_disabled = setting == TableFunctionIdentifierConversion::DISABLE_IMPLICIT_STRING; + auto warn_implicit_conversion = setting == TableFunctionIdentifierConversion::DEFAULT; + const auto msg = + StringUtil::Format("Deprecated implicit conversion of unbound identifiers to strings in table function " + "arguments detected. Please use a string literal instead, e.g. %s.\n" + "Use SET table_function_identifier_conversion='ENABLE_IMPLICIT_STRING' to revert to the " + "deprecated behavior.", + SQLString::ToString(result_name)); + if (implicit_conversion_disabled) { + throw BinderException(query_location, msg); + } + if (warn_implicit_conversion) { + DUCKDB_LOG_WARNING(context, msg); + } return BindResult(make_uniq(Value(result_name))); } diff --git a/src/duckdb/src/planner/operator/logical_trigger.cpp b/src/duckdb/src/planner/operator/logical_trigger.cpp new file mode 100644 index 000000000..945c14947 --- /dev/null +++ b/src/duckdb/src/planner/operator/logical_trigger.cpp @@ -0,0 +1,19 @@ +#include "duckdb/planner/operator/logical_trigger.hpp" + +namespace duckdb { + +LogicalTrigger::LogicalTrigger(string trigger_name_p, TriggerTiming timing_p, TriggerEventType event_type_p, + CorrelatedColumns correlated_columns_p) + : LogicalOperator(LogicalOperatorType::LOGICAL_TRIGGER), trigger_name(std::move(trigger_name_p)), timing(timing_p), + event_type(event_type_p), correlated_columns(std::move(correlated_columns_p)) { +} + +vector LogicalTrigger::GetColumnBindings() { + return children[0]->GetColumnBindings(); +} + +void LogicalTrigger::ResolveTypes() { + types = children[0]->types; +} + +} // namespace duckdb diff --git a/src/duckdb/src/planner/planner.cpp b/src/duckdb/src/planner/planner.cpp index 3ef0da32f..e57f97d0f 100644 --- a/src/duckdb/src/planner/planner.cpp +++ b/src/duckdb/src/planner/planner.cpp @@ -17,6 +17,8 @@ #include "duckdb/main/attached_database.hpp" #include "duckdb/parser/statement/multi_statement.hpp" #include "duckdb/planner/subquery/flatten_dependent_join.hpp" +#include "duckdb/planner/operator/logical_dependent_join.hpp" +#include "duckdb/planner/operator/logical_trigger.hpp" #include "duckdb/planner/operator_extension.hpp" #include "duckdb/planner/planner_extension.hpp" #include "duckdb/optimizer/optimizer.hpp" @@ -26,6 +28,35 @@ namespace duckdb { Planner::Planner(ClientContext &context) : binder(Binder::CreateBinder(context)), context(context) { } +// Pre-decorrelation pass: replace LogicalTrigger with LogicalDependentJoin so the standard +// FlattenDependentJoins machinery can decorrelate the trigger body. +static void RewriteTriggersToDependent(Binder &binder, LogicalOperator &op) { + for (auto &child : op.children) { + if (child) { + RewriteTriggersToDependent(binder, *child); + } + } + for (idx_t i = 0; i < op.children.size(); i++) { + if (!op.children[i] || op.children[i]->type != LogicalOperatorType::LOGICAL_TRIGGER) { + continue; + } + auto &trig = op.children[i]->Cast(); + auto dep_join = make_uniq(JoinType::INNER); + dep_join->correlated_columns = std::move(trig.correlated_columns); + // Trigger bodies have side effects and must fire once per row. Dedup on a synthetic per-row + // row_number() key instead of the NEW columns (mirrors PerformDuplicateElimination's + // perform_delim=false path). otherwise rows with identical NEW values would underfire. + auto binding = ColumnBinding(binder.GenerateTableIndex(), ProjectionIndex(0)); + CorrelatedColumnInfo info(binding, LogicalType::BIGINT, "delim_index", 0); + dep_join->correlated_columns.AddColumn(std::move(info)); + dep_join->correlated_columns.SetDelimIndexToZero(); + dep_join->perform_delim = false; + dep_join->children.push_back(std::move(trig.children[0])); + dep_join->children.push_back(std::move(trig.children[1])); + op.children[i] = std::move(dep_join); + } +} + static void CheckTreeDepth(const LogicalOperator &op, idx_t max_depth, idx_t depth = 0) { if (depth >= max_depth) { throw ParserException("Maximum tree depth of %lld exceeded in logical planner", max_depth); @@ -96,6 +127,7 @@ void Planner::CreatePlan(SQLStatement &statement) { auto max_tree_depth = Settings::Get(context); CheckTreeDepth(*plan, max_tree_depth); + RewriteTriggersToDependent(*this->binder, *this->plan); this->plan = FlattenDependentJoins::DecorrelateIndependent(*this->binder, std::move(this->plan)); } this->properties = binder->GetStatementProperties(); diff --git a/src/duckdb/src/planner/subquery/delim_join_cte_rewriter.cpp b/src/duckdb/src/planner/subquery/delim_join_cte_rewriter.cpp index 1364e04ab..7bd2f4856 100644 --- a/src/duckdb/src/planner/subquery/delim_join_cte_rewriter.cpp +++ b/src/duckdb/src/planner/subquery/delim_join_cte_rewriter.cpp @@ -772,23 +772,11 @@ bool GeneratedDedupRefEliminator::CoversAllDedupColumns(const GeneratedDedupRef optional_idx GeneratedDedupRefEliminator::FindGeneratedOutputBinding(const Expression &expr, const GeneratedDedupRef &dedup_ref) const { - optional_idx result; - bool unsupported = false; - ExpressionIterator::VisitExpression(expr, [&](const BoundColumnRefExpression &colref) { - if (unsupported || colref.Depth() != 0) { - return; - } - auto binding_idx = FindBindingIndex(dedup_ref.output_bindings, colref.Binding()); - if (!binding_idx.IsValid()) { - return; - } - if (result.IsValid() && result.GetIndex() != binding_idx.GetIndex()) { - unsupported = true; - return; - } - result = binding_idx; - }); - return unsupported ? optional_idx() : result; + ColumnBinding binding; + if (!GetBoundColumnRefBinding(expr, binding)) { + return optional_idx(); + } + return FindBindingIndex(dedup_ref.output_bindings, binding); } bool GeneratedDedupRefEliminator::ExpressionReferencesGeneratedSide(const Expression &expr, @@ -1782,23 +1770,11 @@ unique_ptr GeneratedDomainJoinEliminator::GetGeneratedDomain optional_idx GeneratedDomainJoinEliminator::FindOutputBinding(Expression &expr, const vector &bindings) const { - optional_idx result; - bool unsupported = false; - ExpressionIterator::VisitExpression(expr, [&](const BoundColumnRefExpression &colref) { - if (unsupported || colref.Depth() != 0) { - return; - } - auto binding_idx = FindBindingIndex(bindings, colref.Binding()); - if (!binding_idx.IsValid()) { - return; - } - if (result.IsValid() && result.GetIndex() != binding_idx.GetIndex()) { - unsupported = true; - return; - } - result = binding_idx; - }); - return unsupported ? optional_idx() : result; + ColumnBinding binding; + if (!GetBoundColumnRefBinding(expr, binding)) { + return optional_idx(); + } + return FindBindingIndex(bindings, binding); } bool GeneratedDomainJoinEliminator::ContainsRecursiveCTERef(LogicalOperator &op) const { diff --git a/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp b/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp index 0388e83b3..f1e584f9f 100644 --- a/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp +++ b/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp @@ -855,6 +855,10 @@ vector FlattenDependentJoins::PushDownExpressionGet(unique_ptr state) { state = PushDownChild(plan, propagate_null_values, std::move(state)); + // Rewrite any depth>0 correlated refs already in the expression lists (VALUES (NEW.col)) + // before appending the delim-get bindings for the join condition. + RewriteCorrelatedExpressions::Rewrite(*plan, GetCurrentBindings(state), correlated_aliases); + auto &expr_get = plan->Cast(); for (auto &expr_list : expr_get.expressions) { AppendCorrelatedColumns(expr_list, state, false); @@ -866,6 +870,29 @@ vector FlattenDependentJoins::PushDownExpressionGet(unique_ptr FlattenDependentJoins::PushDownDML(unique_ptr &plan, bool propagate_null_values, + vector state) { + state = PushDownChild(plan, propagate_null_values, std::move(state)); + if (plan->type == LogicalOperatorType::LOGICAL_INSERT || plan->type == LogicalOperatorType::LOGICAL_UPDATE) { + // PushDownChild appended the correlated columns to the child projection. + // PhysicalInsert requires an exact column count and PhysicalUpdate requires the row-id last, + // so remove the appended columns. The DELIM_GET below re-supplies them. + // DELETE is skipped because PhysicalDelete reads the row-id by a fixed index. + auto &child = *plan->children[0]; + if (child.type == LogicalOperatorType::LOGICAL_PROJECTION && + child.expressions.size() > correlated_columns.size()) { + child.expressions.resize(child.expressions.size() - correlated_columns.size()); + child.ResolveOperatorTypes(); + } + } + // DML output does not carry the child columns, so re-expose the correlation keys in a separate DELIM_GET that + // the parent DelimJoin can reference. + auto expose_idx = binder.GenerateTableIndex(); + unique_ptr expose_delim = make_uniq(expose_idx, delim_types); + plan = LogicalCrossProduct::Create(std::move(plan), std::move(expose_delim)); + return CreateContiguousState(ColumnBinding(expose_idx, ProjectionIndex(0))); +} + vector FlattenDependentJoins::PushDownGet(unique_ptr &plan, vector state) { auto &get = plan->Cast(); @@ -1012,6 +1039,11 @@ vector FlattenDependentJoins::PushDownCorrelatedNode(unique_ptr TryBuildShreddingStats(const LogicalType &type, const BaseStatistics &input) { + switch (type.id()) { + case LogicalTypeId::STRUCT: { + auto &fields = StructType::GetChildTypes(type); + if (fields.empty()) { + // an empty object has no shredded fields to push down + return nullptr; + } + child_list_t typed_children; + vector> child_stats; + child_stats.reserve(fields.size()); + for (idx_t i = 0; i < fields.size(); i++) { + auto child_result = TryBuildShreddingStats(fields[i].second, StructStats::GetChildStats(input, i)); + if (!child_result) { + return nullptr; + } + typed_children.emplace_back(fields[i].first, child_result->GetType()); + child_stats.push_back(std::move(child_result)); + } + auto typed_value = BaseStatistics::CreateEmpty(LogicalType::STRUCT(std::move(typed_children))); + for (idx_t i = 0; i < child_stats.size(); i++) { + StructStats::SetChildStats(typed_value, i, *child_stats[i]); + } + typed_value.CopyValidity(input); + return WrapTypedValue(typed_value, nullptr).ToUnique(); + } + case LogicalTypeId::LIST: + case LogicalTypeId::ARRAY: { + const bool is_list = type.id() == LogicalTypeId::LIST; + auto &child_type = is_list ? ListType::GetChildType(type) : ArrayType::GetChildType(type); + auto &input_child = is_list ? ListStats::GetChildStats(input) : ArrayStats::GetChildStats(input); + auto child_result = TryBuildShreddingStats(child_type, input_child); + if (!child_result) { + return nullptr; + } + // a variant stores both LISTs and fixed-size ARRAYs as a (variable length) array + auto typed_value = BaseStatistics::CreateEmpty(LogicalType::LIST(child_result->GetType())); + ListStats::SetChildStats(typed_value, std::move(child_result)); + typed_value.CopyValidity(input); + return WrapTypedValue(typed_value, nullptr).ToUnique(); + } + default: + if (type.IsNested() || type.id() == LogicalTypeId::ENUM) { + // MAP / UNION / ENUM etc. are not stored in their source representation in the variant + return nullptr; + } + return input.ToUnique(); + } +} + +unique_ptr VariantStats::StatisticsPropagateToVariant(const LogicalType &source_type, + const BaseStatistics &child_stats) { + if (source_type.id() == LogicalTypeId::VARIANT) { + return nullptr; + } + auto shredding = TryBuildShreddingStats(source_type, child_stats); + if (!shredding) { + return nullptr; + } + auto result = VariantStats::CreateShredded(shredding->GetType()); + VariantStats::SetShreddedStats(result, *shredding); + // the cast preserves NULLs exactly, so the top-level variant validity matches the input validity + result.CopyBase(child_stats); + return result.ToUnique(); +} + unique_ptr VariantStats::WrapExtractedFieldAsVariant(const BaseStatistics &base_variant, const BaseStatistics &extracted_field) { D_ASSERT(base_variant.type.id() == LogicalTypeId::VARIANT); diff --git a/src/duckdb/src/storage/table/struct_column_data.cpp b/src/duckdb/src/storage/table/struct_column_data.cpp index 1c1898aa2..c435fbc21 100644 --- a/src/duckdb/src/storage/table/struct_column_data.cpp +++ b/src/duckdb/src/storage/table/struct_column_data.cpp @@ -16,8 +16,7 @@ StructColumnData::StructColumnData(BlockManager &block_manager, DataTableInfo &i : ColumnData(block_manager, info, column_index, std::move(type_p), data_type, parent) { D_ASSERT(type.InternalType() == PhysicalType::STRUCT); auto &child_types = StructType::GetChildTypes(type); - D_ASSERT(!child_types.empty()); - if (type.id() != LogicalTypeId::UNION && StructType::IsUnnamed(type)) { + if (type.id() != LogicalTypeId::UNION && !child_types.empty() && StructType::IsUnnamed(type)) { throw InvalidInputException("A table cannot be created from an unnamed struct"); } if (type.id() == LogicalTypeId::VARIANT) { diff --git a/src/duckdb/src/storage/table/variant/variant_shredding.cpp b/src/duckdb/src/storage/table/variant/variant_shredding.cpp index 8cd55ae33..25545aadc 100644 --- a/src/duckdb/src/storage/table/variant/variant_shredding.cpp +++ b/src/duckdb/src/storage/table/variant/variant_shredding.cpp @@ -10,6 +10,7 @@ #include "duckdb/function/variant/variant_shredding.hpp" #include "duckdb/function/variant/variant_normalize.hpp" #include "duckdb/common/serializer/varint.hpp" +#include "duckdb/common/vector_operations/vector_operations.hpp" #ifdef DEBUG #include "duckdb/common/value_operations/value_operations.hpp" #endif @@ -403,7 +404,7 @@ static LogicalType SetShreddedType(const LogicalType &typed_value, bool fully_co } bool VariantShreddingStats::GetShreddedTypeInternal(const VariantColumnStatsData &column, LogicalType &out_type, - optional_idx parent_count) const { + optional_idx parent_count, bool force_partial) const { if (parent_count.IsValid() && column.total_count > parent_count.GetIndex()) { throw InternalException("Column count is larger than parent count - this should not be possible"); } @@ -411,7 +412,7 @@ bool VariantShreddingStats::GetShreddedTypeInternal(const VariantColumnStatsData const auto null_count = column.type_counts[0]; if (null_count == column.total_count) { //! All NULL, emit INT32 - auto fully_consistent = null_count == total_value_count; + auto fully_consistent = !force_partial && null_count == total_value_count; out_type = SetShreddedType(LogicalTypeId::INTEGER, fully_consistent); return true; } @@ -435,13 +436,13 @@ bool VariantShreddingStats::GetShreddedTypeInternal(const VariantColumnStatsData return false; } - bool fully_consistent = max_count == total_value_count; + bool fully_consistent = !force_partial && max_count == total_value_count; if (type_index == static_cast(VariantLogicalType::OBJECT)) { child_list_t child_types; for (auto &entry : column.field_stats) { auto &child_column = GetColumnStats(entry.second); LogicalType child_type; - if (GetShreddedTypeInternal(child_column, child_type, total_value_count)) { + if (GetShreddedTypeInternal(child_column, child_type, total_value_count, force_partial)) { child_types.emplace_back(entry.first, child_type); } } @@ -458,7 +459,7 @@ bool VariantShreddingStats::GetShreddedTypeInternal(const VariantColumnStatsData D_ASSERT(column.element_stats != DConstants::INVALID_INDEX); auto &element_column = GetColumnStats(column.element_stats); LogicalType element_type; - if (!GetShreddedTypeInternal(element_column, element_type)) { + if (!GetShreddedTypeInternal(element_column, element_type, optional_idx(), force_partial)) { return false; } auto shredded_type = LogicalType::LIST(element_type); @@ -478,13 +479,13 @@ bool VariantShreddingStats::GetShreddedTypeInternal(const VariantColumnStatsData return true; } -LogicalType VariantShreddingStats::GetShreddedType() const { +LogicalType VariantShreddingStats::GetShreddedType(bool force_partial) const { auto &root_column = GetColumnStats(0); child_list_t child_types; child_types.emplace_back("unshredded", VariantShredding::GetUnshreddedType()); LogicalType shredded_type; - if (GetShreddedTypeInternal(root_column, shredded_type)) { + if (GetShreddedTypeInternal(root_column, shredded_type, optional_idx(), force_partial)) { child_types.emplace_back("shredded", shredded_type); } return LogicalType::STRUCT(child_types); @@ -804,4 +805,34 @@ void VariantColumnData::ShredVariantData(const Vector &input, Vector &output, id #endif } +void VariantColumnData::DebugShred(Vector &variant, idx_t count) { + D_ASSERT(variant.GetType().id() == LogicalTypeId::VARIANT); + if (count == 0 || variant.GetVectorType() == VectorType::SHREDDED_VECTOR) { + //! nothing to do (already shredded, or empty) + return; + } + + Vector materialized(LogicalType::VARIANT(), count); + VectorOperations::Copy(variant, materialized, count, 0, 0); + variant.Reference(materialized); + + //! Derive the shredding schema from the *first* value only - subsequent values that don't match it + //! will be partially shredded (i.e. fall back to the unshredded/overlay component) + VariantShreddingStats stats; + stats.Update(materialized, 1); + //! force_partial keeps the overlay columns so that later values that don't match the first value's + //! schema are partially shredded instead of failing to shred + auto shredded_struct_type = stats.GetShreddedType(true); + if (StructType::GetChildCount(shredded_struct_type) < 2) { + //! the first value did not yield a shreddable type (only the 'unshredded' component) - leave as-is + return; + } + + //! Shred into a STRUCT(unshredded, shredded) and wrap it as a SHREDDED_VECTOR + Vector shredded_struct(shredded_struct_type, count); + ShredVariantData(variant, shredded_struct, count); + FlatVector::SetSize(shredded_struct, count_t(count)); + variant.Shred(shredded_struct, count); +} + } // namespace duckdb diff --git a/src/duckdb/src/storage/table/variant/variant_unshredding.cpp b/src/duckdb/src/storage/table/variant/variant_unshredding.cpp index 054c77222..cfbccf964 100644 --- a/src/duckdb/src/storage/table/variant/variant_unshredding.cpp +++ b/src/duckdb/src/storage/table/variant/variant_unshredding.cpp @@ -224,8 +224,9 @@ static vector Unshred(UnifiedVariantVectorData &variant, Vector &s auto row = row_sel ? static_cast(row_sel->get_index(i)) : i; auto unshredded = UnshreddedVariantValue(variant, row, value_index); - if (res[i].IsNull()) { - //! Unshredded, has no shredded value + if (res[i].IsNull() || res[i].IsMissing()) { + //! No shredded value was produced for this row - either the value was not shredded at all, or it is a + //! shredded object none of whose fields are present in the typed schema. Take the overlay value as-is. res[i] = std::move(unshredded); } else { //! Partial shredding, already has a shredded value that this has to be combined into diff --git a/src/duckdb/src/transaction/duck_transaction.cpp b/src/duckdb/src/transaction/duck_transaction.cpp index cad22b173..f815357b2 100644 --- a/src/duckdb/src/transaction/duck_transaction.cpp +++ b/src/duckdb/src/transaction/duck_transaction.cpp @@ -2,6 +2,8 @@ #include "duckdb/transaction/commit_state.hpp" #include "duckdb/transaction/duck_transaction_manager.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb/main/settings.hpp" +#include "duckdb/main/valid_checker.hpp" #include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/common/exception.hpp" @@ -261,12 +263,14 @@ ErrorData DuckTransaction::Commit(AttachedDatabase &db, CommitInfo &commit_info, } CommitDropState drop_state(block_manager); commit_info.drop_state = &drop_state; + + ErrorData error_data; try { storage->Commit(commit_state.get()); undo_buffer.Commit(iterator_state, commit_info); - // if (DebugForceAbortCommit()) { - // throw InvalidInputException("Force revert"); - // } + if (!db.IsSystem() && !db.IsTemporary() && Settings::Get(db.GetDatabase())) { + throw InvalidInputException("Forced commit failure (debug_force_commit_failure)"); + } if (commit_state) { // if we have written to the WAL - flush after the commit has been successful commit_state->FlushCommit(); @@ -274,13 +278,37 @@ ErrorData DuckTransaction::Commit(AttachedDatabase &db, CommitInfo &commit_info, drop_state.FinalizeCommit(); return ErrorData(); } catch (std::exception &ex) { + // Record the error and run RevertCommit() outside this try-catch: RevertCommit() iterates the + // undo buffer and may itself throw (e.g. Pin() failing under memory pressure), which would + // escape this noexcept function and trigger std::terminate. + error_data = ErrorData(ex); + } + + try { undo_buffer.RevertCommit(iterator_state, this->transaction_id); + if (!db.IsSystem() && !db.IsTemporary() && + Settings::Get(db.GetDatabase())) { + throw IOException("Forced RevertCommit failure (debug_force_commit_revert_failure)"); + } if (commit_state) { // if we have written to the WAL - truncate the WAL on failure commit_state->RevertCommit(); } - return ErrorData(ex); + } catch (std::exception &ex) { + // If we fail to revert the commit, the database is left in an undefined state - invalidate it. + // Record both the original commit error and the revert error so the root cause stays visible. + ValidChecker::Invalidate(db.GetDatabase(), + "Failed to revert transaction commit, database is in an undefined state. " + "Original commit error: " + + error_data.RawMessage() + ". RevertCommit error: " + ErrorData(ex).RawMessage()); + } catch (...) { + // last line of defense: this is a noexcept function, nothing may escape + ValidChecker::Invalidate(db.GetDatabase(), + "Failed to revert transaction commit (unknown error), database is in an " + "undefined state. Original commit error: " + + error_data.RawMessage()); } + return error_data; } ErrorData DuckTransaction::Rollback() { diff --git a/src/duckdb/src/transaction/transaction_context.cpp b/src/duckdb/src/transaction/transaction_context.cpp index 33884d6f0..80f3d8274 100644 --- a/src/duckdb/src/transaction/transaction_context.cpp +++ b/src/duckdb/src/transaction/transaction_context.cpp @@ -12,7 +12,8 @@ namespace duckdb { TransactionContext::TransactionContext(ClientContext &context) - : context(context), auto_commit(true), current_transaction(nullptr) { + : context(context), auto_commit(true), invalidation_policy(TransactionInvalidationPolicy::STANDARD_POLICY), + auto_rollback(false), current_transaction(nullptr) { } TransactionContext::~TransactionContext() { diff --git a/src/duckdb/ub_src_function_scalar_generic.cpp b/src/duckdb/ub_src_function_scalar_generic.cpp index 926f394d3..6b6bbcb09 100644 --- a/src/duckdb/ub_src_function_scalar_generic.cpp +++ b/src/duckdb/ub_src_function_scalar_generic.cpp @@ -4,3 +4,5 @@ #include "src/function/scalar/generic/getvariable.cpp" +#include "src/function/scalar/generic/invoke.cpp" + diff --git a/src/duckdb/ub_src_planner_operator.cpp b/src/duckdb/ub_src_planner_operator.cpp index 7ef28c1eb..c697a6e9b 100644 --- a/src/duckdb/ub_src_planner_operator.cpp +++ b/src/duckdb/ub_src_planner_operator.cpp @@ -80,6 +80,8 @@ #include "src/planner/operator/logical_top_n.cpp" +#include "src/planner/operator/logical_trigger.cpp" + #include "src/planner/operator/logical_unconditional_join.cpp" #include "src/planner/operator/logical_unnest.cpp" From 320fd3172cdced296ae631275fb825dd10895daa Mon Sep 17 00:00:00 2001 From: DuckDB Labs GitHub Bot Date: Wed, 24 Jun 2026 13:52:20 +0000 Subject: [PATCH 2/3] Update vendored DuckDB sources to 3518c60315 --- .../aggregate/holistic/approx_top_k.cpp | 41 +- .../holistic/approximate_quantile.cpp | 3 +- .../core_functions/aggregate/nested/list.cpp | 3 +- .../core_functions/scalar/date/date_part.cpp | 54 +- src/duckdb/extension/json/json_extension.cpp | 5 +- .../json/json_functions/json_create.cpp | 6 +- .../json/json_functions/json_table_in_out.cpp | 42 +- .../json/json_functions/json_transform.cpp | 2 +- .../extension/json/json_multi_file_info.cpp | 4 +- src/duckdb/extension/json/json_reader.cpp | 17 +- .../extension/parquet/column_reader.cpp | 33 +- .../parquet/include/decode_utils.hpp | 11 + .../parquet/include/parquet_writer.hpp | 3 + .../variant/parquet_variant_iterator.hpp | 256 ++ .../reader/variant/variant_binary_decoder.hpp | 43 +- .../variant/variant_shredded_conversion.hpp | 22 - .../extension/parquet/parquet_extension.cpp | 14 + .../extension/parquet/parquet_reader.cpp | 10 +- .../extension/parquet/parquet_writer.cpp | 28 +- .../variant/parquet_variant_iterator.cpp | 736 ++++++ .../reader/variant/variant_binary_decoder.cpp | 300 +-- .../variant/variant_shredded_conversion.cpp | 579 ----- .../parquet/reader/variant_column_reader.cpp | 8 +- src/duckdb/src/catalog/catalog_set.cpp | 4 +- src/duckdb/src/catalog/dependency_manager.cpp | 36 +- src/duckdb/src/common/adbc/adbc.cpp | 385 ++- src/duckdb/src/common/allocator/allocator.cpp | 26 + .../common/allocator/allocator_jemalloc.cpp | 8 + .../common/allocator/allocator_standard.cpp | 29 - .../src/common/arrow/appender/append_data.cpp | 14 + .../arrow/appender/fixed_size_list_data.cpp | 2 +- .../src/common/arrow/appender/struct_data.cpp | 2 +- .../src/common/arrow/appender/union_data.cpp | 2 +- .../src/common/arrow/arrow_appender.cpp | 47 +- .../src/common/arrow/arrow_type_extension.cpp | 4 +- src/duckdb/src/common/enum_util.cpp | 33 +- src/duckdb/src/common/gzip_file_system.cpp | 9 +- src/duckdb/src/common/local_file_system.cpp | 50 +- .../multi_file/multi_file_column_mapper.cpp | 17 + .../common/serializer/async_file_writer.cpp | 4 + .../serializer/async_memory_governor.cpp | 93 + .../common/serializer/async_task_queue.cpp | 634 +++++ .../common/serializer/async_write_queue.cpp | 94 +- src/duckdb/src/common/thread_util.cpp | 8 +- src/duckdb/src/common/types/list_segment.cpp | 215 +- .../src/common/types/selection_vector.cpp | 2 + src/duckdb/src/common/types/value.cpp | 9 + .../common/types/variant/variant_iterator.cpp | 5 + .../common/types/variant/variant_value.cpp | 849 ------- src/duckdb/src/common/types/vector.cpp | 3 + .../src/common/vector/dictionary_vector.cpp | 7 + src/duckdb/src/execution/index/art/art.cpp | 6 - .../src/execution/index/art/art_index.cpp | 1 - src/duckdb/src/execution/join_hashtable.cpp | 272 +- .../physical_materialized_collector.cpp | 11 +- .../join/perfect_hash_join_executor.cpp | 3 +- .../operator/join/physical_cross_product.cpp | 8 + .../operator/join/physical_hash_join.cpp | 292 ++- .../persistent/physical_copy_to_file.cpp | 2252 ++++++++++++----- .../src/execution/physical_operator.cpp | 158 +- .../execution/physical_plan/plan_window.cpp | 18 +- .../aggregate/sorted_aggregate_function.cpp | 232 +- .../function/cast/variant/from_variant.cpp | 4 +- src/duckdb/src/function/function_binder.cpp | 49 +- src/duckdb/src/function/function_list.cpp | 2 + .../compress_geometry.cpp | 134 + .../scalar/geometry/geometry_functions.cpp | 76 + .../scalar/list/contains_or_position.cpp | 7 +- .../function/scalar/operator/arithmetic.cpp | 6 +- .../src/function/scalar/string/like.cpp | 116 +- .../scalar/system/aggregate_export.cpp | 252 +- .../scalar/variant/variant_normalize.cpp | 8 +- .../function/scalar/variant/variant_utils.cpp | 32 + .../table/system/duckdb_variables.cpp | 2 +- .../function/table/system/logging_utils.cpp | 18 + .../function/table/version/pragma_version.cpp | 6 +- src/duckdb/src/include/duckdb.h | 2 +- .../duckdb/catalog/dependency_manager.hpp | 3 + .../src/include/duckdb/common/adbc/adbc.hpp | 10 +- .../src/include/duckdb/common/allocator.hpp | 4 + .../common/arrow/appender/append_data.hpp | 7 + .../common/arrow/appender/list_data.hpp | 2 +- .../common/arrow/appender/list_view_data.hpp | 2 +- .../duckdb/common/arrow/appender/map_data.hpp | 4 +- .../src/include/duckdb/common/enum_util.hpp | 8 - .../src/include/duckdb/common/http_util.hpp | 4 + .../common/multi_file/multi_file_function.hpp | 2 +- .../serializer/async_memory_governor.hpp | 60 + .../common/serializer/async_task_queue.hpp | 207 ++ .../common/serializer/async_write_queue.hpp | 18 +- .../include/duckdb/common/types/geometry.hpp | 23 +- .../duckdb/common/types/list_segment.hpp | 9 +- .../duckdb/common/types/selection_vector.hpp | 5 + .../duckdb/common/types/string_type.hpp | 1 + .../src/include/duckdb/common/types/value.hpp | 5 + .../include/duckdb/common/types/variant.hpp | 10 + .../common/types/variant/variant_builder.hpp | 590 +++++ .../duckdb/common/types/variant_iterator.hpp | 19 +- .../duckdb/common/types/variant_value.hpp | 80 - .../common/vector/dictionary_vector.hpp | 12 + .../duckdb/execution/index/art/art.hpp | 2 +- .../execution/index/art/art_operator.hpp | 19 +- .../duckdb/execution/join_hashtable.hpp | 38 +- .../persistent/physical_copy_to_file.hpp | 15 +- .../duckdb/execution/physical_operator.hpp | 20 + .../function/aggregate/list_aggregate.hpp | 42 +- .../function/cast/variant/json_to_variant.hpp | 23 +- .../cast/variant/variant_to_variant.hpp | 25 +- .../duckdb/function/function_binder.hpp | 23 + .../compressed_materialization_functions.hpp | 20 + .../compressed_materialization_utils.hpp | 8 + .../duckdb/function/scalar/variant_utils.hpp | 5 + .../table/arrow/arrow_duck_schema.hpp | 2 +- .../duckdb/function/table_function.hpp | 4 + .../function/variant/variant_shredding.hpp | 2 +- .../src/include/duckdb/main/client_config.hpp | 4 +- .../include/duckdb/main/extension_entries.hpp | 1 + .../duckdb/main/extension_load_options.hpp | 5 + .../duckdb/main/prepared_statement.hpp | 50 +- .../duckdb/main/prepared_statement_data.hpp | 4 +- .../src/include/duckdb/main/settings.hpp | 11 + .../optimizer/compressed_materialization.hpp | 25 + .../duckdb/optimizer/projection_pullup.hpp | 3 +- .../include/duckdb/parallel/async_result.hpp | 4 +- .../duckdb/parser/peg/inlined_grammar.hpp | 3 +- .../peg/transformer/peg_transformer.hpp | 44 +- .../peg/transformer/transform_enum_result.hpp | 17 - .../include/duckdb/parser/sql_statement.hpp | 2 + .../src/include/duckdb/planner/binder.hpp | 1 + .../duckdb/planner/collation_binding.hpp | 11 +- .../storage/compression/alp/alp_scan.hpp | 64 +- .../storage/compression/alprd/alprd_scan.hpp | 95 +- .../storage/compression/patas/patas_scan.hpp | 15 +- .../storage/statistics/variant_stats.hpp | 8 + .../transaction/duck_transaction_manager.hpp | 2 +- .../transaction/transaction_context.hpp | 4 +- src/duckdb/src/logging/log_manager.cpp | 9 + src/duckdb/src/logging/log_types.cpp | 6 +- src/duckdb/src/main/client_config.cpp | 14 +- src/duckdb/src/main/client_context.cpp | 29 +- src/duckdb/src/main/config.cpp | 11 +- .../src/main/extension/extension_helper.cpp | 2 +- src/duckdb/src/main/pending_query_result.cpp | 11 +- src/duckdb/src/main/prepared_statement.cpp | 8 +- .../src/main/prepared_statement_data.cpp | 91 +- .../main/settings/autogenerated_settings.cpp | 57 + src/duckdb/src/main/stream_query_result.cpp | 8 + .../optimizer/aggregate_function_rewriter.cpp | 5 + .../optimizer/common_aggregate_optimizer.cpp | 9 + .../optimizer/compressed_materialization.cpp | 263 +- .../compress_aggregate.cpp | 28 +- .../compress_comparison_join.cpp | 137 +- .../join_filter_pushdown_optimizer.cpp | 63 +- .../join_order/relation_statistics_helper.cpp | 2 +- .../src/optimizer/projection_pullup.cpp | 21 +- src/duckdb/src/parallel/async_result.cpp | 3 + .../peg/transformer/peg_transformer.cpp | 2 + .../transformer/peg_transformer_factory.cpp | 32 +- .../peg/transformer/transform_expression.cpp | 8 +- .../peg/transformer/transform_generated.cpp | 12 + .../peg/transformer/transform_select.cpp | 27 + src/duckdb/src/parser/tableref/joinref.cpp | 22 +- src/duckdb/src/parser/tableref/pivotref.cpp | 2 +- .../expression/bind_aggregate_expression.cpp | 3 +- .../expression/bind_parameter_expression.cpp | 48 +- .../planner/binder/statement/bind_call.cpp | 21 + .../planner/binder/statement/bind_execute.cpp | 8 +- .../planner/binder/tableref/bind_joinref.cpp | 2 +- .../planner/binder/tableref/bind_pivot.cpp | 18 + src/duckdb/src/planner/collation_binding.cpp | 176 +- .../src/planner/filter/expression_filter.cpp | 76 +- src/duckdb/src/planner/planner.cpp | 3 + src/duckdb/src/storage/compression/rle.cpp | 119 +- .../caching_file_system.cpp | 7 +- .../src/storage/metadata/metadata_reader.cpp | 2 +- .../storage/statistics/base_statistics.cpp | 6 +- .../src/storage/statistics/geometry_stats.cpp | 36 +- .../src/storage/statistics/variant_stats.cpp | 1 + .../src/storage/table/geo_column_data.cpp | 26 +- src/duckdb/src/storage/table/row_group.cpp | 3 +- .../src/storage/table/row_group_reorderer.cpp | 7 + .../table/variant/variant_shredding.cpp | 17 +- .../table/variant/variant_unshredding.cpp | 250 +- .../src/storage/temporary_memory_manager.cpp | 29 +- .../transaction/duck_transaction_manager.cpp | 2 +- .../thrift/thrift/protocol/TProtocol.h | 19 + .../ub_extension_parquet_reader_variant.cpp | 4 +- src/duckdb/ub_src_common_serializer.cpp | 4 + src/duckdb/ub_src_common_types_variant.cpp | 2 - ...tion_scalar_compressed_materialization.cpp | 2 + 190 files changed, 8395 insertions(+), 3862 deletions(-) create mode 100644 src/duckdb/extension/parquet/include/reader/variant/parquet_variant_iterator.hpp delete mode 100644 src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp create mode 100644 src/duckdb/extension/parquet/reader/variant/parquet_variant_iterator.cpp delete mode 100644 src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp create mode 100644 src/duckdb/src/common/serializer/async_memory_governor.cpp create mode 100644 src/duckdb/src/common/serializer/async_task_queue.cpp delete mode 100644 src/duckdb/src/common/types/variant/variant_value.cpp create mode 100644 src/duckdb/src/function/scalar/compressed_materialization/compress_geometry.cpp create mode 100644 src/duckdb/src/include/duckdb/common/serializer/async_memory_governor.hpp create mode 100644 src/duckdb/src/include/duckdb/common/serializer/async_task_queue.hpp create mode 100644 src/duckdb/src/include/duckdb/common/types/variant/variant_builder.hpp delete mode 100644 src/duckdb/src/include/duckdb/common/types/variant_value.hpp delete mode 100644 src/duckdb/src/include/duckdb/parser/peg/transformer/transform_enum_result.hpp diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp index af8426984..c3a7d35bf 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp @@ -354,7 +354,7 @@ void ApproxTopKFinalize(Vector &state_vector, AggregateFinalizeInputData &, Vect } // reserve space in the list vector ListVector::Reserve(result, old_len + new_entries); - auto list_entries = FlatVector::Writer(result, offset + count, offset); + auto list_entries = FlatVector::Writer(result, count, offset); auto &child_data = ListVector::GetChildMutable(result); idx_t current_offset = old_len; @@ -404,7 +404,6 @@ AggregateStateLayout ApproxTopKGetStateType(AggregateLayoutInput &input) { template void ApproxTopKExportState(Vector &state_vector, AggregateFinalizeInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { - D_ASSERT(offset == 0); auto states = state_vector.Values(); auto &mask = FlatVector::ValidityMutable(result); @@ -420,23 +419,24 @@ void ApproxTopKExportState(Vector &state_vector, AggregateFinalizeInputData &agg idx_t total_values = ListVector::GetListSize(value_lists); idx_t total_filters = ListVector::GetListSize(filter_lists); for (idx_t i = 0; i < count; i++) { + const idx_t row = offset + i; auto state_ptr = states[i].GetValue()->state; - value_entries[i].offset = total_values; - filter_entries[i].offset = total_filters; + value_entries[row].offset = total_values; + filter_entries[row].offset = total_filters; if (!state_ptr || state_ptr->values.empty()) { // no values have been added to this state - export NULL (children of a NULL struct must also be NULL) - mask.SetInvalid(i); - k_validity.SetInvalid(i); - value_validity.SetInvalid(i); - filter_validity.SetInvalid(i); - value_entries[i].length = 0; - filter_entries[i].length = 0; - k_data[i] = 0; + mask.SetInvalid(row); + k_validity.SetInvalid(row); + value_validity.SetInvalid(row); + filter_validity.SetInvalid(row); + value_entries[row].length = 0; + filter_entries[row].length = 0; + k_data[row] = 0; continue; } - k_data[i] = state_ptr->k; - value_entries[i].length = state_ptr->values.size(); - filter_entries[i].length = state_ptr->filter.size(); + k_data[row] = state_ptr->k; + value_entries[row].length = state_ptr->values.size(); + filter_entries[row].length = state_ptr->filter.size(); total_values += state_ptr->values.size(); total_filters += state_ptr->filter.size(); } @@ -449,13 +449,14 @@ void ApproxTopKExportState(Vector &state_vector, AggregateFinalizeInputData &agg auto count_data = FlatVector::GetDataMutable(value_fields[1]); auto filter_data = FlatVector::GetDataMutable(ListVector::GetChildMutable(filter_lists)); for (idx_t i = 0; i < count; i++) { + const idx_t row = offset + i; auto state_ptr = states[i].GetValue()->state; if (!state_ptr || state_ptr->values.empty()) { continue; } auto &state = *state_ptr; // write the values (in descending count order) - decoding them back to the input type - idx_t value_offset = value_entries[i].offset; + idx_t value_offset = value_entries[row].offset; for (auto &val_ref : state.values) { auto &val = val_ref.get(); OP::template HistogramFinalize(val.str_val.str, value_child, value_offset); @@ -463,15 +464,15 @@ void ApproxTopKExportState(Vector &state_vector, AggregateFinalizeInputData &agg value_offset++; } for (idx_t filter_idx = 0; filter_idx < state.filter.size(); filter_idx++) { - filter_data[filter_entries[i].offset + filter_idx] = state.filter[filter_idx]; + filter_data[filter_entries[row].offset + filter_idx] = state.filter[filter_idx]; } } ListVector::SetListSize(value_lists, total_values); ListVector::SetListSize(filter_lists, total_filters); - FlatVector::SetSize(fields[0], count); - FlatVector::SetSize(value_lists, count); - FlatVector::SetSize(filter_lists, count); - FlatVector::SetSize(result, count); + FlatVector::SetSize(fields[0], offset + count); + FlatVector::SetSize(value_lists, offset + count); + FlatVector::SetSize(filter_lists, offset + count); + FlatVector::SetSize(result, offset + count); } template diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp index e5fa53052..69b79bb77 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp @@ -219,9 +219,8 @@ using APPROX_QUANTILE_EXPORT_TYPE = void ApproxQuantileExportState(Vector &state_vector, AggregateFinalizeInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { - D_ASSERT(offset == 0); auto states = state_vector.Values(); - auto writer = FlatVector::Writer(result, count); + auto writer = FlatVector::Writer(result, count, offset); for (idx_t i = 0; i < count; i++) { auto &state = *states[i].GetValue(); if (!state.h || state.pos == 0) { diff --git a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp index 79d7f5bbd..25c634089 100644 --- a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp +++ b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp @@ -28,7 +28,8 @@ AggregateFunction ListFun::GetFunction() { auto func = AggregateFunction({LogicalType::TEMPLATE("T")}, LogicalType::LIST(LogicalType::TEMPLATE("T")), AggregateFunction::StateSize, AggregateFunction::StateInitialize, ListUpdateFunction<>, - ListCombineFunction, ListFinalize, nullptr, nullptr, nullptr, nullptr); + ListCombineFunction, ListFinalize, ListClusterUpdate<>, nullptr, + nullptr, nullptr); AggregateFunction::WireStructStateType(func); return func; diff --git a/src/duckdb/extension/core_functions/scalar/date/date_part.cpp b/src/duckdb/extension/core_functions/scalar/date/date_part.cpp index e1f7f7bdd..44280a2bc 100644 --- a/src/duckdb/extension/core_functions/scalar/date/date_part.cpp +++ b/src/duckdb/extension/core_functions/scalar/date/date_part.cpp @@ -90,10 +90,24 @@ DatePartSpecifier GetDateTypePartSpecifier(const string &specifier, const Logica throw NotImplementedException("\"%s\" units \"%s\" not recognized", EnumUtil::ToString(type.id()), specifier); } -template +template unique_ptr PropagateSimpleDatePartStatistics(vector &child_stats) { - // we can always propagate simple date part statistics - // since the min and max can never exceed these bounds + // we can only propagate simple date part statistics if the child has stats + auto &nstats = child_stats[0]; + if (!NumericStats::HasMinMax(nstats)) { + return nullptr; + } + auto min = NumericStats::GetMin(nstats); + auto max = NumericStats::GetMax(nstats); + if (min > max) { + return nullptr; + } + // Infinities produce a NULL date part even though the input is not NULL, + // so we cannot propagate the validity (and thus the stats) in that case + if (!Value::IsFinite(min) || !Value::IsFinite(max)) { + return nullptr; + } + // the min and max can never exceed these bounds auto result = NumericStats::CreateEmpty(LogicalType::BIGINT); result.CopyValidity(child_stats[0]); NumericStats::SetMin(result, Value::BIGINT(MIN)); @@ -185,7 +199,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { // min/max of month operator is [1, 12] - return PropagateSimpleDatePartStatistics<1, 12>(input.child_stats); + return PropagateSimpleDatePartStatistics<1, 12, T>(input.child_stats); } }; @@ -198,7 +212,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { // min/max of day operator is [1, 31] - return PropagateSimpleDatePartStatistics<1, 31>(input.child_stats); + return PropagateSimpleDatePartStatistics<1, 31, T>(input.child_stats); } }; @@ -284,7 +298,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { // min/max of quarter operator is [1, 4] - return PropagateSimpleDatePartStatistics<1, 4>(input.child_stats); + return PropagateSimpleDatePartStatistics<1, 4, T>(input.child_stats); } }; @@ -303,7 +317,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 6>(input.child_stats); + return PropagateSimpleDatePartStatistics<0, 6, T>(input.child_stats); } }; @@ -316,7 +330,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<1, 7>(input.child_stats); + return PropagateSimpleDatePartStatistics<1, 7, T>(input.child_stats); } }; @@ -328,7 +342,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<1, 366>(input.child_stats); + return PropagateSimpleDatePartStatistics<1, 366, T>(input.child_stats); } }; @@ -340,7 +354,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<1, 53>(input.child_stats); + return PropagateSimpleDatePartStatistics<1, 53, T>(input.child_stats); } }; @@ -429,7 +443,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 59999999999>(input.child_stats); + return PropagateSimpleDatePartStatistics<0, 59999999999, T>(input.child_stats); } }; @@ -441,7 +455,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 59999999>(input.child_stats); + return PropagateSimpleDatePartStatistics<0, 59999999, T>(input.child_stats); } }; @@ -453,7 +467,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 59999>(input.child_stats); + return PropagateSimpleDatePartStatistics<0, 59999, T>(input.child_stats); } }; @@ -465,7 +479,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 59>(input.child_stats); + return PropagateSimpleDatePartStatistics<0, 59, T>(input.child_stats); } }; @@ -477,7 +491,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 59>(input.child_stats); + return PropagateSimpleDatePartStatistics<0, 59, T>(input.child_stats); } }; @@ -489,7 +503,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 24>(input.child_stats); + return PropagateSimpleDatePartStatistics<0, 24, T>(input.child_stats); } }; @@ -518,7 +532,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 1>(input.child_stats); + return PropagateSimpleDatePartStatistics<0, 1, T>(input.child_stats); } }; @@ -550,7 +564,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 0>(input.child_stats); + return PropagateSimpleDatePartStatistics<0, 0, T>(input.child_stats); } }; @@ -563,7 +577,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 0>(input.child_stats); + return PropagateSimpleDatePartStatistics<0, 0, T>(input.child_stats); } }; @@ -576,7 +590,7 @@ struct DatePart { template static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 0>(input.child_stats); + return PropagateSimpleDatePartStatistics<0, 0, T>(input.child_stats); } }; diff --git a/src/duckdb/extension/json/json_extension.cpp b/src/duckdb/extension/json/json_extension.cpp index ca9e922ef..af99c6505 100644 --- a/src/duckdb/extension/json/json_extension.cpp +++ b/src/duckdb/extension/json/json_extension.cpp @@ -16,8 +16,9 @@ static const DefaultMacro JSON_MACROS[] = { {DEFAULT_SCHEMA, "json_group_array", "(x) AS CAST('[' || string_agg(CASE WHEN x IS NULL THEN 'null'::JSON ELSE to_json(x) END, ',') || ']' AS JSON)"}, {DEFAULT_SCHEMA, "json_group_object", - "(n, v) AS CAST('{' || string_agg(to_json(n::VARCHAR) || ':' || CASE WHEN v IS NULL THEN 'null'::JSON ELSE " - "to_json(v) END, ',') || '}' AS JSON)"}, + "(n, v) AS CAST('{' || string_agg(CASE WHEN n IS NULL THEN error('json_group_object key cannot be NULL') ELSE " + "to_json(n::VARCHAR) END || ':' || CASE WHEN v IS NULL THEN 'null'::JSON ELSE to_json(v) END, ',') || '}' AS " + "JSON)"}, {DEFAULT_SCHEMA, "json_group_structure", "(x) AS json_structure(json_group_array(x))->0"}, {DEFAULT_SCHEMA, "json", "(x) AS json_extract(x, '$')"}, {DEFAULT_SCHEMA, "json_copy_strftime_if_date", "(x, format) AS x, (x DATE, format) AS strftime(x, format);"}, diff --git a/src/duckdb/extension/json/json_functions/json_create.cpp b/src/duckdb/extension/json/json_functions/json_create.cpp index 5cfefab60..affb9f7c9 100644 --- a/src/duckdb/extension/json/json_functions/json_create.cpp +++ b/src/duckdb/extension/json/json_functions/json_create.cpp @@ -186,8 +186,8 @@ static unique_ptr ArrayToJSONBind(BindScalarFunctionInput &input) if (arguments[0]->HasParameter()) { throw ParameterNotResolvedException(); } - if (arg_id != LogicalTypeId::LIST && arg_id != LogicalTypeId::SQLNULL) { - throw BinderException("array_to_json() argument type must be LIST"); + if (arg_id != LogicalTypeId::LIST && arg_id != LogicalTypeId::ARRAY && arg_id != LogicalTypeId::SQLNULL) { + throw BinderException("array_to_json() argument type must be LIST or ARRAY"); } return JSONCreateBindParams(bound_function, arguments, false); } @@ -288,7 +288,7 @@ static void AddKeyValuePairs(yyjson_mut_doc *doc, yyjson_mut_val *objs[], const for (idx_t i = 0; i < count; i++) { auto key_entry = keys[i]; if (!key_entry.IsValid()) { - continue; + throw InvalidInputException("JSON key cannot be NULL"); } auto key = CreateJSONValue::Operation(doc, key_entry.GetValue()); yyjson_mut_obj_add(objs[i], key, vals[i]); diff --git a/src/duckdb/extension/json/json_functions/json_table_in_out.cpp b/src/duckdb/extension/json/json_functions/json_table_in_out.cpp index b2c6364f2..a2507276d 100644 --- a/src/duckdb/extension/json/json_functions/json_table_in_out.cpp +++ b/src/duckdb/extension/json/json_functions/json_table_in_out.cpp @@ -1,5 +1,6 @@ #include "json_common.hpp" #include "json_functions.hpp" +#include "duckdb/common/string_util.hpp" #include "duckdb/function/table_function.hpp" namespace duckdb { @@ -100,6 +101,40 @@ static unique_ptr JSONTableInOutInitGlobal(ClientConte return std::move(result); } +//! Whether an object key can appear unquoted in a JSON path (as in "$.key"). Mirrors SQLite's json_each/json_tree: +//! only keys consisting of an ASCII letter followed by ASCII alphanumerics/underscores go unquoted +static bool JSONPathKeyNeedsQuoting(const char *data, const idx_t len) { + if (len == 0 || !StringUtil::CharacterIsAlpha(data[0])) { + return true; + } + for (idx_t i = 1; i < len; i++) { + if (!StringUtil::CharacterIsAlphaNumeric(data[i]) && data[i] != '_') { + return true; + } + } + return false; +} + +//! Appends ".key" to the path, quoting the key (as in "$.\"a.b\"") when needed so that the resulting path +//! round-trips through JSON path extraction (issue #23148) +static void AppendObjectPathElement(yyjson_val *vkey, string &path) { + const auto data = unsafe_yyjson_get_str(vkey); + const auto len = unsafe_yyjson_get_len(vkey); + path += '.'; + if (!JSONPathKeyNeedsQuoting(data, len)) { + path.append(data, len); + return; + } + path += '"'; + for (idx_t i = 0; i < len; i++) { + if (data[i] == '"' || data[i] == '\\') { + path += '\\'; + } + path += data[i]; + } + path += '"'; +} + struct JSONTableInOutRecursionNode { JSONTableInOutRecursionNode(string path_p, yyjson_val *parent_val_p) : path(std::move(path_p)), parent_val(parent_val_p), child_index(0) { @@ -127,7 +162,7 @@ struct JSONTableInOutLocalState : LocalTableFunctionState { void AddRecursionNode(yyjson_val *val, optional_ptr vkey, const optional_idx arr_index) { string str; if (vkey) { - str = "." + string(unsafe_yyjson_get_str(vkey.get()), unsafe_yyjson_get_len(vkey.get())); + AppendObjectPathElement(vkey.get(), str); } else if (arr_index.IsValid()) { str = "[" + to_string(arr_index.GetIndex()) + "]"; } @@ -213,8 +248,9 @@ struct JSONTableInOutResult { const auto path_str = lstate.GetPath(); if (fullkey.enabled) { if (vkey) { // Object field - const auto vkey_str = string(unsafe_yyjson_get_str(vkey.get()), unsafe_yyjson_get_len(vkey.get())); - fullkey.data[count] = StringVector::AddString(fullkey.vector, path_str + "." + vkey_str); + auto fullkey_str = path_str; + AppendObjectPathElement(vkey.get(), fullkey_str); + fullkey.data[count] = StringVector::AddString(fullkey.vector, fullkey_str); } else if (arr_el) { // Array element const auto arr_path = "[" + to_string(recursion_nodes.back().child_index) + "]"; fullkey.data[count] = StringVector::AddString(fullkey.vector, path_str + arr_path); diff --git a/src/duckdb/extension/json/json_functions/json_transform.cpp b/src/duckdb/extension/json/json_functions/json_transform.cpp index 298e690e7..c2e2523b9 100644 --- a/src/duckdb/extension/json/json_functions/json_transform.cpp +++ b/src/duckdb/extension/json/json_functions/json_transform.cpp @@ -46,7 +46,7 @@ static LogicalType StructureToTypeObject(yyjson_val *obj, ClientContext &context yyjson_val *key, *val; yyjson_obj_foreach(obj, idx, max, key, val) { val = yyjson_obj_iter_get_val(key); - auto key_str = unsafe_yyjson_get_str(key); + string key_str(unsafe_yyjson_get_str(key), unsafe_yyjson_get_len(key)); if (names.find(key_str) != names.end()) { JSONCommon::ThrowValFormatError("Duplicate keys in object in JSON structure: %s", val); } diff --git a/src/duckdb/extension/json/json_multi_file_info.cpp b/src/duckdb/extension/json/json_multi_file_info.cpp index 607bf62b9..6b7db052c 100644 --- a/src/duckdb/extension/json/json_multi_file_info.cpp +++ b/src/duckdb/extension/json/json_multi_file_info.cpp @@ -232,7 +232,9 @@ bool JSONMultiFileInfo::ParseCopyOption(ClientContext &context, const string &ke } else { JSONCheckSingleParameter(key, values); options.auto_detect = BooleanValue::Get(values.back().DefaultCastAs(LogicalTypeId::BOOLEAN)); - options.format = JSONFormat::NEWLINE_DELIMITED; + if (options.format == JSONFormat::AUTO_DETECT) { + options.format = JSONFormat::NEWLINE_DELIMITED; + } } return true; } diff --git a/src/duckdb/extension/json/json_reader.cpp b/src/duckdb/extension/json/json_reader.cpp index 25feff588..832301836 100644 --- a/src/duckdb/extension/json/json_reader.cpp +++ b/src/duckdb/extension/json/json_reader.cpp @@ -658,16 +658,21 @@ bool JSONReader::ParseJSON(JSONReaderScanState &scan_state, char *const json_sta err.pos = json_size; AddParseError(scan_state, scan_state.lines_or_objects_in_buffer, err, "Try auto-detecting the JSON format"); return false; - } else if (!options.ignore_errors && read_size < json_size) { + } + if (read_size < json_size) { idx_t off = read_size; idx_t rem = json_size; SkipWhitespace(json_start, off, rem); if (off != rem) { // Between end of document and boundary should be whitespace only - err.code = YYJSON_READ_ERROR_UNEXPECTED_CONTENT; - err.msg = "unexpected content after document"; - err.pos = read_size; - AddParseError(scan_state, scan_state.lines_or_objects_in_buffer, err, "Try auto-detecting the JSON format"); - return false; + if (!options.ignore_errors) { + err.code = YYJSON_READ_ERROR_UNEXPECTED_CONTENT; + err.msg = "unexpected content after document"; + err.pos = read_size; + AddParseError(scan_state, scan_state.lines_or_objects_in_buffer, err, + "Try auto-detecting the JSON format"); + return false; + } + doc = nullptr; } } diff --git a/src/duckdb/extension/parquet/column_reader.cpp b/src/duckdb/extension/parquet/column_reader.cpp index eedda2cad..f82f8f238 100644 --- a/src/duckdb/extension/parquet/column_reader.cpp +++ b/src/duckdb/extension/parquet/column_reader.cpp @@ -351,7 +351,7 @@ void ColumnReader::PrepareRead(optional_ptr filter, optional_ } // some basic sanity check if (page_hdr.compressed_page_size < 0 || page_hdr.uncompressed_page_size < 0) { - throw InvalidInputException("Failed to read file \"%s\": Page sizes can't be < 0", Reader().GetFileName()); + throw InvalidInputException("Failed to read file \"%s\": Page sizes must be >= 0", Reader().GetFileName()); } if (PageIsFilteredOut(page_hdr, filter)) { @@ -420,20 +420,43 @@ void ColumnReader::PreparePageV2(PageHeader &page_hdr) { return; } - // copy repeats & defines as-is because FOR SOME REASON they are uncompressed - auto uncompressed_bytes = page_hdr.data_page_header_v2.repetition_levels_byte_length + - page_hdr.data_page_header_v2.definition_levels_byte_length; - if (uncompressed_bytes > page_hdr.uncompressed_page_size) { + // copy repeats & defines as-is because FOR SOME REASON they are uncompressed. + // the page sizes are already validated >= 0 by the caller, but the level lengths are not, so guard + // them here. with all four i32 header fields non-negative, the sum below cannot overflow once widened + // to uint64_t, and the uint64_t casts in the comparisons below are safe. + if (page_hdr.data_page_header_v2.repetition_levels_byte_length < 0 || + page_hdr.data_page_header_v2.definition_levels_byte_length < 0) { + throw InvalidInputException( + "Failed to read file \"%s\": header inconsistency, repetition_levels_byte_length and " + "definition_levels_byte_length must be >= 0", + Reader().GetFileName()); + } + uint64_t uncompressed_bytes = static_cast(page_hdr.data_page_header_v2.repetition_levels_byte_length) + + page_hdr.data_page_header_v2.definition_levels_byte_length; + if (uncompressed_bytes > static_cast(page_hdr.uncompressed_page_size)) { throw InvalidInputException( "Failed to read file \"%s\": header inconsistency, uncompressed_page_size needs to be larger than " "repetition_levels_byte_length + definition_levels_byte_length", Reader().GetFileName()); } + if (static_cast(page_hdr.compressed_page_size) < uncompressed_bytes) { + throw InvalidInputException( + "Failed to read file \"%s\": header inconsistency, compressed_page_size is smaller than " + "repetition_levels_byte_length + definition_levels_byte_length", + Reader().GetFileName()); + } ReadData(block->ptr, uncompressed_bytes, page_hdr.type); auto compressed_bytes = page_hdr.compressed_page_size - uncompressed_bytes; + if (compressed_bytes == 0 && static_cast(page_hdr.uncompressed_page_size) > uncompressed_bytes) { + throw InvalidInputException( + "Failed to read file \"%s\": header inconsistency, compressed_page_size is too small for the " + "declared value region", + Reader().GetFileName()); + } + if (compressed_bytes > 0) { ResizeableBuffer compressed_buffer; compressed_buffer.resize(GetAllocator(), compressed_bytes); diff --git a/src/duckdb/extension/parquet/include/decode_utils.hpp b/src/duckdb/extension/parquet/include/decode_utils.hpp index 6f72e2912..ca6107820 100644 --- a/src/duckdb/extension/parquet/include/decode_utils.hpp +++ b/src/duckdb/extension/parquet/include/decode_utils.hpp @@ -36,6 +36,11 @@ class ParquetDecodeUtils { static void BitUnpack(ByteBuffer &src, bitpacking_width_t &bitpack_pos, T *dst, idx_t count, const bitpacking_width_t width) { CheckWidth(width); + if (width > sizeof(T) * BITPACK_DLEN) { + throw IOException("The width (%d) of the bitpacked data exceeds the maximum width (%d) for " + "the target type, the file might be corrupted.", + width, sizeof(T) * BITPACK_DLEN); + } const auto mask = BITPACK_MASKS[width]; src.available(count * width / BITPACK_DLEN); // check if buffer has enough space available once if (bitpack_pos == 0 && count >= BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE) { @@ -88,6 +93,12 @@ class ParquetDecodeUtils { template static void BitUnpackAlignedInternal(ByteBuffer &src, T *dst, const idx_t count, const bitpacking_width_t width) { D_ASSERT(count % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE == 0); + if (width > sizeof(T) * BITPACK_DLEN) { + throw IOException("The width (%d) of the bitpacked data exceeds the maximum width (%d) for " + "the target type, the file might be corrupted.", + width, sizeof(T) * BITPACK_DLEN); + } + if (cast_pointer_to_uint64(src.ptr) % sizeof(T) == 0) { // Fast path: aligned BitpackingPrimitives::UnPackBuffer(data_ptr_cast(dst), src.ptr, count, width); diff --git a/src/duckdb/extension/parquet/include/parquet_writer.hpp b/src/duckdb/extension/parquet/include/parquet_writer.hpp index 2f4646eb7..c4ef18adc 100644 --- a/src/duckdb/extension/parquet/include/parquet_writer.hpp +++ b/src/duckdb/extension/parquet/include/parquet_writer.hpp @@ -109,10 +109,13 @@ class ParquetWriteTransformData { public: ColumnDataCollection &ApplyTransform(ColumnDataCollection &input); + bool MatchesTypes(const vector &other_types) const; private: //! The buffer to store the transformed chunks of a rowgroup ColumnDataCollection buffer; + //! The types used to bind the expressions and initialize the buffer + vector types; //! The expression(s) to apply to the input chunk vector> expressions; //! The expression executor used to transform the input chunk diff --git a/src/duckdb/extension/parquet/include/reader/variant/parquet_variant_iterator.hpp b/src/duckdb/extension/parquet/include/reader/variant/parquet_variant_iterator.hpp new file mode 100644 index 000000000..28393e5cb --- /dev/null +++ b/src/duckdb/extension/parquet/include/reader/variant/parquet_variant_iterator.hpp @@ -0,0 +1,256 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// reader/variant/parquet_variant_iterator.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/types/variant.hpp" +#include "duckdb/common/types/variant_iterator.hpp" +#include "duckdb/common/vector/vector_iterator.hpp" +#include "duckdb/common/vector/unified_vector_format.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/unique_ptr.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "reader/variant/variant_binary_decoder.hpp" + +namespace duckdb { + +struct VariantBuilder; + +//! A "group" in the Parquet shredded VARIANT layout: STRUCT(value BLOB, [typed_value ]). +//! The recursive view builds the (type-specific) VectorIterators of the group tree once, so each layer +//! can be navigated by index during the (per-row) single-pass conversion. +enum class ParquetGroupKind : uint8_t { LEAF, OBJECT, ARRAY }; + +struct ShreddedGroupView { + //! 'value' (the binary-encoded fallback / overlay) + unique_ptr> value; + + //! Whether the group has a 'typed_value' (a shredded representation) + bool has_typed_value = false; + ParquetGroupKind kind = ParquetGroupKind::LEAF; + LogicalType typed_type; + + //! LEAF: the typed primitive values (type-erased; read typed via GetData where T is known) + UnifiedVectorFormat leaf_format; + + //! ARRAY: the list entries + the element (group) sub-view + unique_ptr> list; + unique_ptr element; + + //! OBJECT: the struct's validity + the field names + their (group) sub-views + unique_ptr typed_validity; + vector field_names; + vector> fields; + + //! Build the view recursively from a group Vector + void Build(Vector &group); +}; + +class ParquetVariantIterator; +class ParquetObjectIterator; +class ParquetArrayIterator; + +//! A lightweight cursor pointing at a single logical Parquet VARIANT value. A value is either SHREDDED +//! (a position in a typed vector tree) or BINARY (a Spark variant-encoded value at a byte offset in a +//! 'value' blob). Both expose the same node concept, so the shared EmitIterator traverses them uniformly +//! and emits the correct VariantLogicalType. +class ParquetVariantNode { +public: + enum class Kind : uint8_t { NULL_VALUE, MISSING, SHREDDED, BINARY }; + +public: + ParquetVariantNode() : kind(Kind::NULL_VALUE) { + } + +public: + static ParquetVariantNode MakeNull() { + return ParquetVariantNode(Kind::NULL_VALUE); + } + static ParquetVariantNode MakeMissing() { + return ParquetVariantNode(Kind::MISSING); + } + //! A position in the shredded (typed) tree. 'overlay' is the binary OBJECT holding the leftover fields + //! of a partially-shredded object (or null when fully shredded / not an object); 'overlay_end' is one + //! past the end of the overlay's 'value' blob (for bounds-checking the binary reads). + static ParquetVariantNode MakeShredded(const ParquetVariantIterator &state, const ShreddedGroupView &view, + idx_t index, const_data_ptr_t overlay = nullptr, + const_data_ptr_t overlay_end = nullptr) { + ParquetVariantNode result(Kind::SHREDDED); + result.state = &state; + result.view = &view; + result.index = index; + result.binary = overlay; + result.binary_end = overlay_end; + return result; + } + //! A Spark variant-encoded value starting at 'data' (its header byte); 'end' is one past the end of the + //! 'value' blob the value lives in (for bounds-checking the binary reads). + static ParquetVariantNode MakeBinary(const ParquetVariantIterator &state, const_data_ptr_t data, + const_data_ptr_t end) { + ParquetVariantNode result(Kind::BINARY); + result.state = &state; + result.binary = data; + result.binary_end = end; + return result; + } + +public: + bool IsNull() const { + return kind == Kind::NULL_VALUE; + } + bool IsMissing() const { + return kind == Kind::MISSING; + } + + VariantLogicalType GetTypeId() const; + //! Returns the fixed-width primitive payload (loaded / re-encoded as T) + template + T GetData() const; + string_t GetString() const; + VariantDecimalProperties GetDecimalProperties() const; + ParquetObjectIterator GetObjectChildren(VariantIterationOrder order) const; + ParquetArrayIterator GetArrayChildren() const; + +private: + explicit ParquetVariantNode(Kind kind) : kind(kind) { + } + +private: + Kind kind; + optional_ptr state; + //! SHREDDED: the shredded position + optional_ptr view; + idx_t index = 0; + //! SHREDDED OBJECT: the (binary) overlay object holding leftover fields, or null. + //! BINARY: the value's start (header byte). + const_data_ptr_t binary = nullptr; + //! One past the end of the 'value' blob 'binary' points into (for bounds-checking the binary reads) + const_data_ptr_t binary_end = nullptr; +}; + +//! A single (key, value) entry of an OBJECT +struct ParquetObjectEntry { + string_t key; + ParquetVariantNode value; +}; + +//! Iterates the (key, value) children of an OBJECT: for a shredded object, the typed struct fields merged +//! with the leftover (binary overlay) fields - deduplicated (typed fields win) and sorted lexicographically; +//! for a binary object, the encoded fields. +class ParquetObjectIterator { +public: + //! Shredded object (typed fields + optional binary overlay) + ParquetObjectIterator(const ParquetVariantIterator &state, const ShreddedGroupView &view, idx_t index, + const_data_ptr_t overlay, const_data_ptr_t overlay_end); + //! Binary object + ParquetObjectIterator(const ParquetVariantIterator &state, const VariantMetadata &metadata, const_data_ptr_t data, + const_data_ptr_t end); + +public: + const ParquetObjectEntry *begin() const { // NOLINT: match stl API + return ordered_entries.data(); + } + const ParquetObjectEntry *end() const { // NOLINT: match stl API + return ordered_entries.data() + ordered_entries.size(); + } + +private: + void Finalize(); + +private: + vector ordered_entries; +}; + +//! Iterates the element values of an ARRAY (shredded list or binary array). Random-access. +class ParquetArrayIterator { +public: + //! Shredded array + ParquetArrayIterator(const ParquetVariantIterator &state, const ShreddedGroupView &view, idx_t index); + //! Binary array + ParquetArrayIterator(const ParquetVariantIterator &state, const VariantMetadata &metadata, const_data_ptr_t data, + const_data_ptr_t end); + +public: + idx_t size() const { + return length; + } + ParquetVariantNode operator[](idx_t i) const; + +private: + reference state; + bool shredded; + idx_t length; + + //! SHREDDED: the element (group) sub-view + the list offset of the elements + optional_ptr element; + idx_t base = 0; + + //! BINARY: where to read each element's value offset, and the values base + const_data_ptr_t field_offsets = nullptr; + const_data_ptr_t values = nullptr; + //! BINARY: one past the end of the 'value' blob (for bounds-checking the binary reads) + const_data_ptr_t binary_end = nullptr; + uint32_t field_offset_size = 0; +}; + +//! Iterates a shredded Parquet VARIANT column (metadata + group), handing out ParquetVariantNode cursors +//! per row. Binary (Spark-encoded) values are read directly from the 'value' blobs; fixed-width payloads +//! are fetched by value (no materialization). +class ParquetVariantIterator { +public: + ParquetVariantIterator(Vector &metadata, Vector &group); + //! Binary-only: each row is a full Spark variant-encoded value (the metadata blob immediately followed + //! by the value blob). There is no shredded group - the value is read right after the metadata. + explicit ParquetVariantIterator(Vector &metadata); + +public: + //! Reset the per-row state (lazily-decoded metadata) for a new row + void BeginRow(idx_t row); + //! Resolve the root value of 'row' (a missing root is promoted to a SQL NULL) + ParquetVariantNode Root(idx_t row) const; + //! Resolve the root of a binary-only row: the value blob starts right after the metadata + ParquetVariantNode BinaryRoot() const; + //! Resolve the value of the group 'view' at logical position 'index' + ParquetVariantNode ResolveGroup(const ShreddedGroupView &view, idx_t index) const; + + //! The (lazily-decoded) Variant metadata of the current row + const VariantMetadata &GetMetadata() const; + +private: + ShreddedGroupView root_view; + + VectorIterator metadata; + + idx_t current_row = 0; + mutable unique_ptr current_metadata; +}; + +//! BuildVariant source wrapping a ParquetVariantIterator (mirrors VariantIteratorSource in core) +struct ParquetVariantIteratorSource { + explicit ParquetVariantIteratorSource(ParquetVariantIterator &iterator) : iterator(iterator) { + } + bool Emit(idx_t row, VariantBuilder &builder); + + ParquetVariantIterator &iterator; +}; + +//! Convert a shredded Parquet VARIANT (metadata + group) into the canonical VARIANT 'result' in a single +//! pass through the shared VariantBuilder +class ParquetVariantConversion { +public: + static void Convert(Vector &metadata, Vector &group, Vector &result, idx_t count); + //! Convert binary Variant values (each row being the metadata blob followed by the value blob) into the + //! canonical VARIANT 'result' in a single pass + static void ConvertBinary(Vector &metadata_and_value, Vector &result, idx_t count); + //! 'variant_bytes_to_variant': decode a binary Variant value (metadata followed by value) into a VARIANT. + //! The inverse of 'variant_to_parquet_variant'. + static ScalarFunction GetBytesToVariantFunction(); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp b/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp index d30507f34..d9bfacdd7 100644 --- a/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp +++ b/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp @@ -5,16 +5,11 @@ #include #include "duckdb/common/types/string_type.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/common/types/variant_value.hpp" -#include "yyjson.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/string.hpp" #include "duckdb/common/typedefs.hpp" #include "duckdb/common/vector.hpp" -using namespace duckdb_yyjson; - namespace duckdb { //! ------------ Metadata ------------ @@ -41,8 +36,9 @@ struct VariantMetadata { public: VariantMetadataHeader header; - const_data_ptr_t offsets; - const_data_ptr_t bytes; + + //! Total byte length of the metadata region. + idx_t total_size = 0; //! The json object keys have to be null-terminated //! But we don't receive them null-terminated @@ -121,37 +117,4 @@ struct VariantValueMetadata { bool is_large; }; -struct VariantDecodeResult { -public: - VariantDecodeResult() = default; - ~VariantDecodeResult() { - if (doc) { - yyjson_mut_doc_free(doc); - } - if (data) { - free(data); - } - } - -public: - yyjson_mut_doc *doc = nullptr; - char *data = nullptr; -}; - -class VariantBinaryDecoder { -public: - VariantBinaryDecoder() = delete; - -public: - static VariantValue Decode(const VariantMetadata &metadata, const_data_ptr_t data); - -public: - static VariantValue PrimitiveTypeDecode(const VariantValueMetadata &value_metadata, const_data_ptr_t data); - static VariantValue ShortStringDecode(const VariantValueMetadata &value_metadata, const_data_ptr_t data); - static VariantValue ObjectDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, - const_data_ptr_t data); - static VariantValue ArrayDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, - const_data_ptr_t data); -}; - } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp b/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp deleted file mode 100644 index 8c7462b75..000000000 --- a/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#include "duckdb/common/types/variant_value.hpp" -#include "reader/variant/variant_binary_decoder.hpp" - -namespace duckdb { - -class VariantShreddedConversion { -public: - VariantShreddedConversion() = delete; - -public: - static vector Convert(Vector &metadata, Vector &group, idx_t offset, idx_t length, idx_t total_size); - static vector ConvertShreddedLeaf(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, - idx_t length, idx_t total_size); - static vector ConvertShreddedArray(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, - idx_t length, idx_t total_size); - static vector ConvertShreddedObject(Vector &metadata, Vector &value, Vector &typed_value, - idx_t offset, idx_t length, idx_t total_size); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_extension.cpp b/src/duckdb/extension/parquet/parquet_extension.cpp index 961390d04..a727c1913 100644 --- a/src/duckdb/extension/parquet/parquet_extension.cpp +++ b/src/duckdb/extension/parquet/parquet_extension.cpp @@ -17,6 +17,16 @@ #include "zstd_file_system.hpp" #include "writer/primitive_column_writer.hpp" #include "writer/variant_column_writer.hpp" +#include "reader/variant_column_reader.hpp" + +#include +#include +#include +#include +#include +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" +#include "duckdb/common/constants.hpp" #include "duckdb/common/enums/file_compression_type.hpp" #include "duckdb/common/file_system.hpp" #include "duckdb/common/helper.hpp" @@ -70,6 +80,7 @@ #include "duckdb/storage/storage_info.hpp" #include "parquet_field_id.hpp" #include "parquet_types.h" +#include "reader/variant/parquet_variant_iterator.hpp" namespace duckdb { class ClientContext; @@ -961,6 +972,9 @@ static void LoadInternal(ExtensionLoader &loader) { // variant_to_parquet_variant loader.RegisterFunction(VariantColumnWriter::GetTransformFunction()); + // bytes_to_variant + loader.RegisterFunction(ParquetVariantConversion::GetBytesToVariantFunction()); + CopyFunction function("parquet"); function.copy_to_select = ParquetWriteSelect; function.copy_to_bind = ParquetWriteBind; diff --git a/src/duckdb/extension/parquet/parquet_reader.cpp b/src/duckdb/extension/parquet/parquet_reader.cpp index 04357972c..facb371e5 100644 --- a/src/duckdb/extension/parquet/parquet_reader.cpp +++ b/src/duckdb/extension/parquet/parquet_reader.cpp @@ -986,8 +986,14 @@ MultiFileColumnDefinition ParquetReader::ParseColumnDefinition(const FileMetaDat result.identifier = Value::INTEGER(parent_column_schema.field_id); } } - for (auto &child : element.children) { - result.children.push_back(ParseColumnDefinition(file_meta_data, child)); + // A GEOMETRY column is a leaf at the logical level - it only wraps an inner BLOB child internally so that the + // reader can validate/transform the WKB. Exposing that child here would make the column definition diverge from + // the (childless) global GEOMETRY column, breaking trivial column mapping and disabling row group pruning for + // spatial predicates. Treat it as a leaf. + if (element.schema_type != ParquetColumnSchemaType::GEOMETRY) { + for (auto &child : element.children) { + result.children.push_back(ParseColumnDefinition(file_meta_data, child)); + } } return result; } diff --git a/src/duckdb/extension/parquet/parquet_writer.cpp b/src/duckdb/extension/parquet/parquet_writer.cpp index b8a282beb..7dd974e71 100644 --- a/src/duckdb/extension/parquet/parquet_writer.cpp +++ b/src/duckdb/extension/parquet/parquet_writer.cpp @@ -471,9 +471,13 @@ class ParquetStatsAccumulator { ParquetWriteTransformData::ParquetWriteTransformData(ClientContext &context, const vector &types, vector> expressions_p) - : buffer(context, types, ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR), expressions(std::move(expressions_p)), - executor(context, expressions) { - chunk.Initialize(buffer.GetAllocator(), types); + : buffer(context, types, ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR), types(std::move(types)), + expressions(std::move(expressions_p)), executor(context, expressions) { + chunk.Initialize(buffer.GetAllocator(), this->types); +} + +bool ParquetWriteTransformData::MatchesTypes(const vector &other_types) const { + return types == other_types; } //! TODO: this doesnt work.. the ParquetWriteTransformData is shared with all threads, the method is stateful, but has @@ -640,22 +644,28 @@ void ParquetWriter::VerifyPreparedRowGroup(const PreparedRowGroup &prepared) con } void ParquetWriter::InitializePreprocessing(unique_ptr &transform_data) { - if (transform_data) { + vector transformed_types; + for (idx_t col_idx = 0; col_idx < column_writers.size(); col_idx++) { + auto &column_writer = *column_writers[col_idx]; + auto &original_type = options.sql_types[col_idx]; + if (!column_writer.HasTransform()) { + transformed_types.push_back(original_type); + continue; + } + transformed_types.push_back(column_writer.TransformedType()); + } + if (transform_data && transform_data->MatchesTypes(transformed_types)) { return; } - vector transformed_types; vector> transform_expressions; for (idx_t col_idx = 0; col_idx < column_writers.size(); col_idx++) { auto &column_writer = *column_writers[col_idx]; - auto &original_type = options.sql_types[col_idx]; - auto expr = make_uniq(original_type, col_idx); + auto expr = make_uniq(options.sql_types[col_idx], col_idx); if (!column_writer.HasTransform()) { - transformed_types.push_back(original_type); transform_expressions.push_back(std::move(expr)); continue; } - transformed_types.push_back(column_writer.TransformedType()); transform_expressions.push_back(column_writer.TransformExpression(std::move(expr))); } transform_data = make_uniq(context, transformed_types, std::move(transform_expressions)); diff --git a/src/duckdb/extension/parquet/reader/variant/parquet_variant_iterator.cpp b/src/duckdb/extension/parquet/reader/variant/parquet_variant_iterator.cpp new file mode 100644 index 000000000..cad60478b --- /dev/null +++ b/src/duckdb/extension/parquet/reader/variant/parquet_variant_iterator.cpp @@ -0,0 +1,736 @@ +#include "reader/variant/parquet_variant_iterator.hpp" + +#include "duckdb/common/types/variant/variant_builder.hpp" +#include "duckdb/common/vector/struct_vector.hpp" +#include "duckdb/common/vector/list_vector.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/hugeint.hpp" +#include "duckdb/common/hugeint.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/numeric_utils.hpp" +#include "reader/uuid_column_reader.hpp" +#include "utf8proc_wrapper.hpp" + +#include +#include +#include +#include + +namespace duckdb { + +namespace { + +//! Throw if reading 'size' bytes starting at 'ptr' would read past the end of the value buffer +void CheckBinaryRead(const_data_ptr_t ptr, idx_t size, const_data_ptr_t end) { + if (ptr + size > end) { + throw IOException("Data corruption detected, read of length_in_bytes (%d) would exceed buffer capacity", size); + } +} + +//! Read a little-endian unsigned integer of 'size' bytes (without advancing), bounds-checked against 'end' +idx_t ReadVarLE(idx_t size, const_data_ptr_t ptr, const_data_ptr_t end) { + D_ASSERT(size <= sizeof(idx_t)); + CheckBinaryRead(ptr, size, end); + idx_t result = 0; + memcpy(&result, ptr, size); + return result; +} + +//! Read a fixed-width little-endian value, bounds-checked against 'end' +template +T LoadChecked(const_data_ptr_t ptr, const_data_ptr_t end) { + CheckBinaryRead(ptr, sizeof(T), end); + return Load(ptr); +} + +//! Lazy reader over a Spark variant-encoded OBJECT (the value starts at its header byte). All structural +//! reads are bounds-checked against 'end' (one past the end of the 'value' blob). +struct BinaryObjectReader { + BinaryObjectReader(const VariantMetadata &metadata, const_data_ptr_t value_start, const_data_ptr_t end) + : metadata(metadata), end(end) { + auto value_metadata = VariantValueMetadata::FromHeaderByte(value_start[0]); + D_ASSERT(value_metadata.basic_type == VariantBasicType::OBJECT); + field_id_size = value_metadata.field_id_size; + field_offset_size = value_metadata.field_offset_size; + auto data = value_start + 1; + if (value_metadata.is_large) { + count = LoadChecked(data, end); + data += sizeof(uint32_t); + } else { + count = LoadChecked(data, end); + data += sizeof(uint8_t); + } + field_ids = data; + field_offsets = data + (count * field_id_size); + values = field_offsets + (NumericCast(count + 1) * field_offset_size); + } + + string_t Key(idx_t i) const { + auto field_id = ReadVarLE(field_id_size, field_ids + (i * field_id_size), end); + if (field_id >= metadata.strings.size()) { + throw IOException("Corrupted VARIANT 'value' buffer"); + } + auto &key = metadata.strings[field_id]; + return string_t(key.c_str(), NumericCast(key.size())); + } + const_data_ptr_t Child(idx_t i) const { + auto offset = ReadVarLE(field_offset_size, field_offsets + (i * field_offset_size), end); + auto child = values + offset; + //! The child's header byte must be readable + CheckBinaryRead(child, 1, end); + return child; + } + + const VariantMetadata &metadata; + const_data_ptr_t end; + const_data_ptr_t field_ids; + const_data_ptr_t field_offsets; + const_data_ptr_t values; + uint32_t field_id_size; + uint32_t field_offset_size; + idx_t count; +}; + +//! Lazy reader over a Spark variant-encoded ARRAY (the value starts at its header byte). All structural +//! reads are bounds-checked against 'end' (one past the end of the 'value' blob). +struct BinaryArrayReader { + BinaryArrayReader(const_data_ptr_t value_start, const_data_ptr_t end) : end(end) { + auto value_metadata = VariantValueMetadata::FromHeaderByte(value_start[0]); + D_ASSERT(value_metadata.basic_type == VariantBasicType::ARRAY); + field_offset_size = value_metadata.field_offset_size; + auto data = value_start + 1; + if (value_metadata.is_large) { + count = LoadChecked(data, end); + data += sizeof(uint32_t); + } else { + count = LoadChecked(data, end); + data += sizeof(uint8_t); + } + field_offsets = data; + values = field_offsets + (NumericCast(count + 1) * field_offset_size); + } + + const_data_ptr_t end; + const_data_ptr_t field_offsets; + const_data_ptr_t values; + uint32_t field_offset_size; + idx_t count; +}; + +//! The VariantLogicalType of a (valid) shredded leaf at logical position 'index'. Mirrors the writer's +//! type mapping; note BINARY is kept as a BLOB so the type is preserved (the base64 conversion happens +//! only when serializing the VARIANT to JSON). +VariantLogicalType ShreddedLeafTypeId(const ShreddedGroupView &view, idx_t index) { + switch (view.typed_type.id()) { + case LogicalTypeId::BOOLEAN: { + auto leaf_index = view.leaf_format.sel->get_index(index); + return UnifiedVectorFormat::GetData(view.leaf_format)[leaf_index] ? VariantLogicalType::BOOL_TRUE + : VariantLogicalType::BOOL_FALSE; + } + case LogicalTypeId::TINYINT: + return VariantLogicalType::INT8; + case LogicalTypeId::SMALLINT: + return VariantLogicalType::INT16; + case LogicalTypeId::INTEGER: + return VariantLogicalType::INT32; + case LogicalTypeId::BIGINT: + return VariantLogicalType::INT64; + case LogicalTypeId::FLOAT: + return VariantLogicalType::FLOAT; + case LogicalTypeId::DOUBLE: + return VariantLogicalType::DOUBLE; + case LogicalTypeId::DECIMAL: + return VariantLogicalType::DECIMAL; + case LogicalTypeId::DATE: + return VariantLogicalType::DATE; + case LogicalTypeId::TIME: + return VariantLogicalType::TIME_MICROS; + case LogicalTypeId::TIMESTAMP_TZ: + return VariantLogicalType::TIMESTAMP_MICROS_TZ; + case LogicalTypeId::TIMESTAMP_TZ_NS: + return VariantLogicalType::TIMESTAMP_NANOS_TZ; + case LogicalTypeId::TIMESTAMP: + return VariantLogicalType::TIMESTAMP_MICROS; + case LogicalTypeId::TIMESTAMP_NS: + return VariantLogicalType::TIMESTAMP_NANOS; + case LogicalTypeId::BLOB: + return VariantLogicalType::BLOB; + case LogicalTypeId::VARCHAR: + return VariantLogicalType::VARCHAR; + case LogicalTypeId::UUID: + return VariantLogicalType::UUID; + default: + throw NotImplementedException("Variant shredding on type: '%s' is not implemented", view.typed_type.ToString()); + } +} + +//! The VariantLogicalType of a Spark variant-encoded value (the value starts at its header byte) +VariantLogicalType BinaryTypeId(const_data_ptr_t data) { + auto value_metadata = VariantValueMetadata::FromHeaderByte(data[0]); + switch (value_metadata.basic_type) { + case VariantBasicType::SHORT_STRING: + return VariantLogicalType::VARCHAR; + case VariantBasicType::OBJECT: + return VariantLogicalType::OBJECT; + case VariantBasicType::ARRAY: + return VariantLogicalType::ARRAY; + case VariantBasicType::PRIMITIVE: + break; + default: + throw InternalException("Unexpected VariantBasicType"); + } + switch (value_metadata.primitive_type) { + case VariantPrimitiveType::NULL_TYPE: + return VariantLogicalType::VARIANT_NULL; + case VariantPrimitiveType::BOOLEAN_TRUE: + return VariantLogicalType::BOOL_TRUE; + case VariantPrimitiveType::BOOLEAN_FALSE: + return VariantLogicalType::BOOL_FALSE; + case VariantPrimitiveType::INT8: + return VariantLogicalType::INT8; + case VariantPrimitiveType::INT16: + return VariantLogicalType::INT16; + case VariantPrimitiveType::INT32: + return VariantLogicalType::INT32; + case VariantPrimitiveType::INT64: + return VariantLogicalType::INT64; + case VariantPrimitiveType::DOUBLE: + return VariantLogicalType::DOUBLE; + case VariantPrimitiveType::FLOAT: + return VariantLogicalType::FLOAT; + case VariantPrimitiveType::DECIMAL4: + case VariantPrimitiveType::DECIMAL8: + case VariantPrimitiveType::DECIMAL16: + return VariantLogicalType::DECIMAL; + case VariantPrimitiveType::DATE: + return VariantLogicalType::DATE; + case VariantPrimitiveType::TIMESTAMP_MICROS: + return VariantLogicalType::TIMESTAMP_MICROS_TZ; + case VariantPrimitiveType::TIMESTAMP_NTZ_MICROS: + return VariantLogicalType::TIMESTAMP_MICROS; + case VariantPrimitiveType::BINARY: + //! Keep the raw bytes as a BLOB so the type is preserved (base64 conversion happens at JSON time) + return VariantLogicalType::BLOB; + case VariantPrimitiveType::STRING: + return VariantLogicalType::VARCHAR; + case VariantPrimitiveType::TIME_NTZ_MICROS: + return VariantLogicalType::TIME_MICROS; + case VariantPrimitiveType::TIMESTAMP_NANOS: + return VariantLogicalType::TIMESTAMP_NANOS_TZ; + case VariantPrimitiveType::TIMESTAMP_NTZ_NANOS: + return VariantLogicalType::TIMESTAMP_NANOS; + case VariantPrimitiveType::UUID: + return VariantLogicalType::UUID; + default: + throw NotImplementedException("Variant PrimitiveType (%d) is not supported", + static_cast(value_metadata.primitive_type)); + } +} + +//! The implied precision of a decimal value: floor(log10(val)) + 1 +template +uint32_t ComputeDecimalWidth(T value) { + if (value == 0) { + return 1; + } + auto abs_val = value; + if (abs_val < 0) { + abs_val = -abs_val; + } + return static_cast(floor(log10(static_cast(abs_val))) + 1); +} + +//! Whether T is a physical storage type a DECIMAL value can be re-encoded into +template +struct IsDecimalStorage : std::false_type {}; +template <> +struct IsDecimalStorage : std::true_type {}; +template <> +struct IsDecimalStorage : std::true_type {}; +template <> +struct IsDecimalStorage : std::true_type {}; +template <> +struct IsDecimalStorage : std::true_type {}; + +template +T CastDecimalValue(SRC value) { + if constexpr (std::is_same::value) { + return hugeint_t(value); + } else if constexpr (std::is_same::value) { + //! Only reachable for DECIMAL16, which always re-encodes to hugeint (the branch above) + return Hugeint::Cast(value); + } else { + return NumericCast(value); + } +} + +//! Read the (Parquet UUID-ordered) value as T. UUID is always read as hugeint_t. +template +T ReadBinaryUUID(const_data_ptr_t payload) { + if constexpr (std::is_same::value) { + return UUIDValueConversion::ReadParquetUUID(payload); + } else { + throw InternalException("Variant UUID must be read as hugeint_t"); + } +} + +//! Read the decimal value (re-encoded as T = the physical type implied by the width). 'payload' points at +//! the scale byte. T must be one of int16/int32/int64/hugeint (see VariantDecimalPhysicalType). +template +T ReadBinaryDecimalValue(VariantPrimitiveType primitive_type, const_data_ptr_t payload) { + if constexpr (IsDecimalStorage::value) { + auto value_data = payload + sizeof(uint8_t); // skip the scale byte + switch (primitive_type) { + case VariantPrimitiveType::DECIMAL4: + return CastDecimalValue(Load(value_data)); + case VariantPrimitiveType::DECIMAL8: + return CastDecimalValue(Load(value_data)); + default: { + D_ASSERT(primitive_type == VariantPrimitiveType::DECIMAL16); + hugeint_t value; + value.lower = Load(value_data); + value.upper = Load(value_data + sizeof(uint64_t)); + return CastDecimalValue(value); + } + } + } else { + throw InternalException("Variant DECIMAL must be read as int16/int32/int64/hugeint"); + } +} + +} // namespace + +//===--------------------------------------------------------------------===// +// ShreddedGroupView +//===--------------------------------------------------------------------===// +void ShreddedGroupView::Build(Vector &group) { + D_ASSERT(group.GetType().id() == LogicalTypeId::STRUCT); + auto &entries = StructVector::GetEntries(group); + auto &child_types = StructType::GetChildTypes(group.GetType()); + D_ASSERT(entries.size() == child_types.size()); + + //! From the spec: the Parquet columns storing variant metadata and values must be accessed by name + optional_ptr value_vec; + optional_ptr typed_vec; + for (idx_t i = 0; i < entries.size(); i++) { + auto &name = child_types[i].first; + if (name == "value") { + value_vec = entries[i]; + } else if (name == "typed_value") { + typed_vec = entries[i]; + } else { + throw InvalidInputException("Variant group can only contain 'value'/'typed_value', not: %s", name); + } + } + if (!value_vec) { + throw InvalidInputException("Required column 'value' not found in Variant group"); + } + + value = make_uniq>(*value_vec); + + if (!typed_vec) { + has_typed_value = false; + return; + } + has_typed_value = true; + typed_type = typed_vec->GetType(); + + switch (typed_type.id()) { + case LogicalTypeId::STRUCT: { + kind = ParquetGroupKind::OBJECT; + typed_validity = make_uniq(*typed_vec); + auto &fields_meta = StructType::GetChildTypes(typed_type); + auto &field_entries = StructVector::GetEntries(*typed_vec); + for (idx_t i = 0; i < field_entries.size(); i++) { + field_names.push_back(fields_meta[i].first.GetIdentifierName()); + auto field_view = make_uniq(); + field_view->Build(field_entries[i]); + fields.push_back(std::move(field_view)); + } + break; + } + case LogicalTypeId::LIST: { + kind = ParquetGroupKind::ARRAY; + list = make_uniq>(*typed_vec); + element = make_uniq(); + element->Build(ListVector::GetChildMutable(*typed_vec)); + break; + } + default: + kind = ParquetGroupKind::LEAF; + typed_vec->ToUnifiedFormat(leaf_format); + break; + } +} + +//===--------------------------------------------------------------------===// +// ParquetVariantIterator +//===--------------------------------------------------------------------===// +ParquetVariantIterator::ParquetVariantIterator(Vector &metadata_vec, Vector &group) : metadata(metadata_vec) { + root_view.Build(group); +} + +ParquetVariantIterator::ParquetVariantIterator(Vector &metadata_vec) : metadata(metadata_vec) { +} + +void ParquetVariantIterator::BeginRow(idx_t row) { + current_row = row; + current_metadata.reset(); +} + +const VariantMetadata &ParquetVariantIterator::GetMetadata() const { + if (!current_metadata) { + current_metadata = make_uniq(metadata[current_row].GetValueUnsafe()); + } + return *current_metadata; +} + +ParquetVariantNode ParquetVariantIterator::ResolveGroup(const ShreddedGroupView &view, idx_t index) const { + if (view.has_typed_value) { + bool typed_valid = false; + switch (view.kind) { + case ParquetGroupKind::LEAF: + typed_valid = view.leaf_format.validity.RowIsValid(view.leaf_format.sel->get_index(index)); + break; + case ParquetGroupKind::ARRAY: + typed_valid = (*view.list)[index].IsValid(); + break; + case ParquetGroupKind::OBJECT: + typed_valid = view.typed_validity->IsValid(index); + break; + } + if (typed_valid) { + if (view.kind == ParquetGroupKind::OBJECT) { + //! (Partially) shredded object - the binary 'value', if present, holds the leftover fields + const_data_ptr_t overlay = nullptr; + const_data_ptr_t overlay_end = nullptr; + auto value_entry = (*view.value)[index]; + if (value_entry.IsValid()) { + auto &overlay_blob = value_entry.GetValueUnsafe(); + auto overlay_data = const_data_ptr_cast(overlay_blob.GetData()); + overlay_end = overlay_data + overlay_blob.GetSize(); + CheckBinaryRead(overlay_data, 1, overlay_end); + if (VariantValueMetadata::FromHeaderByte(overlay_data[0]).basic_type != VariantBasicType::OBJECT) { + throw InvalidInputException( + "Partially shredded objects have to encode Object Variants in the 'value'"); + } + overlay = overlay_data; + } + return ParquetVariantNode::MakeShredded(*this, view, index, overlay, overlay_end); + } + //! LEAF or ARRAY - the binary 'value' is irrelevant (a leaf is never partially shredded) + return ParquetVariantNode::MakeShredded(*this, view, index); + } + } + + //! No (valid) shredded value - fall back to the binary 'value' + auto value_entry = (*view.value)[index]; + if (value_entry.IsValid()) { + auto &value_blob = value_entry.GetValueUnsafe(); + auto data = const_data_ptr_cast(value_blob.GetData()); + auto end = data + value_blob.GetSize(); + CheckBinaryRead(data, 1, end); + if (view.has_typed_value && view.kind == ParquetGroupKind::OBJECT && + VariantValueMetadata::FromHeaderByte(data[0]).basic_type == VariantBasicType::OBJECT) { + throw InvalidInputException( + "When 'typed_value' for a shredded Object is NULL, 'value' can not contain an Object value"); + } + return ParquetVariantNode::MakeBinary(*this, data, end); + } + return ParquetVariantNode::MakeMissing(); +} + +ParquetVariantNode ParquetVariantIterator::Root(idx_t row) const { + auto root = ResolveGroup(root_view, row); + //! A root value is never "missing" - treat any such case as a SQL NULL + return root.IsMissing() ? ParquetVariantNode::MakeNull() : root; +} + +ParquetVariantNode ParquetVariantIterator::BinaryRoot() const { + //! The metadata and the value share the same blob: the value bytes start right after the metadata + auto &variant_metadata = GetMetadata(); + auto blob_start = const_data_ptr_cast(variant_metadata.metadata.GetData()); + auto blob_end = blob_start + variant_metadata.metadata.GetSize(); + auto value_start = blob_start + variant_metadata.total_size; + //! The value's header byte must be readable + CheckBinaryRead(value_start, 1, blob_end); + return ParquetVariantNode::MakeBinary(*this, value_start, blob_end); +} + +//===--------------------------------------------------------------------===// +// ParquetVariantNode +//===--------------------------------------------------------------------===// +VariantLogicalType ParquetVariantNode::GetTypeId() const { + switch (kind) { + case Kind::NULL_VALUE: + return VariantLogicalType::VARIANT_NULL; + case Kind::SHREDDED: + switch (view->kind) { + case ParquetGroupKind::OBJECT: + return VariantLogicalType::OBJECT; + case ParquetGroupKind::ARRAY: + return VariantLogicalType::ARRAY; + default: + return ShreddedLeafTypeId(*view, index); + } + case Kind::BINARY: + return BinaryTypeId(binary); + default: + throw InternalException("ParquetVariantNode::GetTypeId on a MISSING value"); + } +} + +template +T ParquetVariantNode::GetData() const { + if (kind == Kind::SHREDDED) { + return UnifiedVectorFormat::GetData(view->leaf_format)[view->leaf_format.sel->get_index(index)]; + } + D_ASSERT(kind == Kind::BINARY); + auto value_metadata = VariantValueMetadata::FromHeaderByte(binary[0]); + auto payload = binary + 1; + switch (value_metadata.primitive_type) { + case VariantPrimitiveType::UUID: + CheckBinaryRead(payload, sizeof(hugeint_t), binary_end); + return ReadBinaryUUID(payload); + case VariantPrimitiveType::DECIMAL4: + CheckBinaryRead(payload, sizeof(uint8_t) + sizeof(int32_t), binary_end); + return ReadBinaryDecimalValue(value_metadata.primitive_type, payload); + case VariantPrimitiveType::DECIMAL8: + CheckBinaryRead(payload, sizeof(uint8_t) + sizeof(int64_t), binary_end); + return ReadBinaryDecimalValue(value_metadata.primitive_type, payload); + case VariantPrimitiveType::DECIMAL16: + CheckBinaryRead(payload, sizeof(uint8_t) + sizeof(hugeint_t), binary_end); + return ReadBinaryDecimalValue(value_metadata.primitive_type, payload); + default: + //! Fixed-width primitives are stored in the canonical little-endian layout + return LoadChecked(payload, binary_end); + } +} + +string_t ParquetVariantNode::GetString() const { + if (kind == Kind::SHREDDED) { + auto str = UnifiedVectorFormat::GetData(view->leaf_format)[view->leaf_format.sel->get_index(index)]; + if (view->typed_type.id() == LogicalTypeId::BLOB) { + //! Keep the raw bytes - the value is emitted as a BLOB (base64 conversion happens at JSON time) + return str; + } + if (!Utf8Proc::IsValid(str.GetData(), str.GetSize())) { + throw InternalException("Can't decode Variant string, it isn't valid UTF8"); + } + return str; + } + D_ASSERT(kind == Kind::BINARY); + auto value_metadata = VariantValueMetadata::FromHeaderByte(binary[0]); + auto payload = binary + 1; + if (value_metadata.basic_type == VariantBasicType::SHORT_STRING) { + auto string_data = const_char_ptr_cast(payload); + CheckBinaryRead(payload, value_metadata.string_size, binary_end); + if (!Utf8Proc::IsValid(string_data, value_metadata.string_size)) { + throw InternalException("Can't decode Variant short-string, string isn't valid UTF8"); + } + return string_t(string_data, value_metadata.string_size); + } + auto size = LoadChecked(payload, binary_end); + auto string_data = const_char_ptr_cast(payload + sizeof(uint32_t)); + CheckBinaryRead(payload + sizeof(uint32_t), size, binary_end); + if (value_metadata.primitive_type == VariantPrimitiveType::BINARY) { + //! Keep the raw bytes - the value is emitted as a BLOB (base64 conversion happens at JSON time) + return string_t(string_data, size); + } + if (!Utf8Proc::IsValid(string_data, size)) { + throw InternalException("Can't decode Variant string, it isn't valid UTF8"); + } + return string_t(string_data, size); +} + +VariantDecimalProperties ParquetVariantNode::GetDecimalProperties() const { + if (kind == Kind::SHREDDED) { + uint8_t width; + uint8_t scale; + view->typed_type.GetDecimalProperties(width, scale); + return VariantDecimalProperties(width, scale); + } + D_ASSERT(kind == Kind::BINARY); + auto value_metadata = VariantValueMetadata::FromHeaderByte(binary[0]); + auto payload = binary + 1; + uint8_t scale = LoadChecked(payload, binary_end); + auto value_data = payload + sizeof(uint8_t); + switch (value_metadata.primitive_type) { + case VariantPrimitiveType::DECIMAL4: + return VariantDecimalProperties(ComputeDecimalWidth(LoadChecked(value_data, binary_end)), + scale); + case VariantPrimitiveType::DECIMAL8: + return VariantDecimalProperties(ComputeDecimalWidth(LoadChecked(value_data, binary_end)), + scale); + default: + D_ASSERT(value_metadata.primitive_type == VariantPrimitiveType::DECIMAL16); + return VariantDecimalProperties(DecimalWidth::max, scale); + } +} + +ParquetObjectIterator ParquetVariantNode::GetObjectChildren(VariantIterationOrder order) const { + (void)order; + if (kind == Kind::SHREDDED) { + return ParquetObjectIterator(*state, *view, index, binary, binary_end); + } + D_ASSERT(kind == Kind::BINARY); + return ParquetObjectIterator(*state, state->GetMetadata(), binary, binary_end); +} + +ParquetArrayIterator ParquetVariantNode::GetArrayChildren() const { + if (kind == Kind::SHREDDED) { + return ParquetArrayIterator(*state, *view, index); + } + D_ASSERT(kind == Kind::BINARY); + return ParquetArrayIterator(*state, state->GetMetadata(), binary, binary_end); +} + +//===--------------------------------------------------------------------===// +// ParquetObjectIterator +//===--------------------------------------------------------------------===// +void ParquetObjectIterator::Finalize() { + std::sort(ordered_entries.begin(), ordered_entries.end(), + [](const ParquetObjectEntry &a, const ParquetObjectEntry &b) { return a.key < b.key; }); +} + +ParquetObjectIterator::ParquetObjectIterator(const ParquetVariantIterator &state, const ShreddedGroupView &view, + idx_t index, const_data_ptr_t overlay, const_data_ptr_t overlay_end) { + //! Typed (shredded) fields - skipping the ones that are missing for this row + std::set typed_keys; + for (idx_t i = 0; i < view.fields.size(); i++) { + auto node = state.ResolveGroup(*view.fields[i], index); + if (node.IsMissing()) { + continue; + } + auto &name = view.field_names[i]; + typed_keys.insert(name); + ordered_entries.push_back( + ParquetObjectEntry {string_t(name.c_str(), NumericCast(name.size())), node}); + } + //! Leftover (overlay) fields from the binary 'value' - typed fields win on key collisions + if (overlay) { + BinaryObjectReader reader(state.GetMetadata(), overlay, overlay_end); + for (idx_t i = 0; i < reader.count; i++) { + auto key = reader.Key(i); + if (typed_keys.count(key.GetString())) { + continue; + } + ordered_entries.push_back( + ParquetObjectEntry {key, ParquetVariantNode::MakeBinary(state, reader.Child(i), overlay_end)}); + } + } + Finalize(); +} + +ParquetObjectIterator::ParquetObjectIterator(const ParquetVariantIterator &state, const VariantMetadata &metadata, + const_data_ptr_t data, const_data_ptr_t end) { + BinaryObjectReader reader(metadata, data, end); + for (idx_t i = 0; i < reader.count; i++) { + ordered_entries.push_back( + ParquetObjectEntry {reader.Key(i), ParquetVariantNode::MakeBinary(state, reader.Child(i), end)}); + } + Finalize(); +} + +//===--------------------------------------------------------------------===// +// ParquetArrayIterator +//===--------------------------------------------------------------------===// +ParquetArrayIterator::ParquetArrayIterator(const ParquetVariantIterator &state, const ShreddedGroupView &view, + idx_t index) + : state(state), shredded(true), element(view.element.get()) { + auto entry = (*view.list)[index].GetValueUnsafe(); + base = entry.offset; + length = entry.length; +} + +ParquetArrayIterator::ParquetArrayIterator(const ParquetVariantIterator &state, const VariantMetadata &metadata, + const_data_ptr_t data, const_data_ptr_t end) + : state(state), shredded(false) { + (void)metadata; + BinaryArrayReader reader(data, end); + length = reader.count; + field_offsets = reader.field_offsets; + values = reader.values; + binary_end = end; + field_offset_size = reader.field_offset_size; +} + +ParquetVariantNode ParquetArrayIterator::operator[](idx_t i) const { + if (shredded) { + return state.get().ResolveGroup(*element, base + i); + } + auto offset = ReadVarLE(field_offset_size, field_offsets + (i * field_offset_size), binary_end); + auto child = values + offset; + //! The child's header byte must be readable + CheckBinaryRead(child, 1, binary_end); + return ParquetVariantNode::MakeBinary(state.get(), child, binary_end); +} + +//===--------------------------------------------------------------------===// +// ParquetVariantIteratorSource +//===--------------------------------------------------------------------===// +bool ParquetVariantIteratorSource::Emit(idx_t row, VariantBuilder &builder) { + iterator.BeginRow(row); + auto root = iterator.Root(row); + if (root.IsNull()) { + return true; + } + //! A root that resolves to a (variant) NULL is a genuine SQL NULL row + if (root.GetTypeId() == VariantLogicalType::VARIANT_NULL) { + return true; + } + EmitIterator(root, builder); + return false; +} + +//===--------------------------------------------------------------------===// +// ParquetVariantConversion +//===--------------------------------------------------------------------===// +void ParquetVariantConversion::Convert(Vector &metadata, Vector &group, Vector &result, idx_t count) { + ParquetVariantIterator iterator(metadata, group); + ParquetVariantIteratorSource source(iterator); + BuildVariant(source, count, result); +} + +namespace { + +//! BuildVariant source over binary Variant blobs (each row being the metadata followed by the value) +struct ParquetBinaryVariantSource { + ParquetBinaryVariantSource(ParquetVariantIterator &iterator, Vector &blob) : iterator(iterator) { + blob.ToUnifiedFormat(blob_format); + } + + bool Emit(idx_t row, VariantBuilder &builder) { + if (!blob_format.validity.RowIsValid(blob_format.sel->get_index(row))) { + return true; + } + iterator.BeginRow(row); + EmitIterator(iterator.BinaryRoot(), builder); + return false; + } + + ParquetVariantIterator &iterator; + UnifiedVectorFormat blob_format; +}; + +} // namespace + +void ParquetVariantConversion::ConvertBinary(Vector &metadata_and_value, Vector &result, idx_t count) { + ParquetVariantIterator iterator(metadata_and_value); + ParquetBinaryVariantSource source(iterator, metadata_and_value); + BuildVariant(source, count, result); +} + +static void VariantBytesToVariantFunction(DataChunk &input, ExpressionState &state, Vector &result) { + ParquetVariantConversion::ConvertBinary(input.data[0], result, input.size()); +} + +ScalarFunction ParquetVariantConversion::GetBytesToVariantFunction() { + ScalarFunction function("variant_bytes_to_variant", {LogicalType::BLOB}, LogicalType::VARIANT(), + VariantBytesToVariantFunction); + function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + return function; +} + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp b/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp index 3a52ef09f..8d2dcac4c 100644 --- a/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp +++ b/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp @@ -1,21 +1,6 @@ #include "reader/variant/variant_binary_decoder.hpp" #include -#include -#include - -#include "utf8proc_wrapper.hpp" -#include "reader/uuid_column_reader.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/types/decimal.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/blob.hpp" -#include "duckdb/common/assert.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/hugeint.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/types/datetime.hpp" -#include "duckdb/common/types/value.hpp" static constexpr uint8_t VERSION_MASK = 0xF; static constexpr uint8_t SORTED_STRINGS_MASK = 0x1; @@ -44,13 +29,18 @@ namespace duckdb { namespace { -static idx_t ReadVariableLengthLittleEndian(idx_t length_in_bytes, const_data_ptr_t &ptr) { +static idx_t ReadVariableLengthLittleEndian(idx_t length_in_bytes, const_data_ptr_t ptr, idx_t &offset, + const idx_t capacity) { if (length_in_bytes > sizeof(idx_t)) { throw NotImplementedException("Can't read little-endian value of %d bytes", length_in_bytes); } + if (offset + length_in_bytes > capacity) { + throw IOException("Data corruption detected, read of length_in_bytes (%d) would exceed buffer capacity", + length_in_bytes); + } idx_t result = 0; - memcpy(reinterpret_cast(&result), ptr, length_in_bytes); - ptr += length_in_bytes; + memcpy(reinterpret_cast(&result), ptr + offset, length_in_bytes); + offset += length_in_bytes; return result; } @@ -71,21 +61,34 @@ VariantMetadataHeader VariantMetadataHeader::FromHeaderByte(uint8_t byte) { } VariantMetadata::VariantMetadata(const string_t &metadata) : metadata(metadata) { - auto metadata_data = metadata.GetData(); + auto metadata_data = reinterpret_cast(metadata.GetData()); + const auto metadata_buffer_capacity = metadata.GetSize(); + if (!metadata_data || metadata.GetSize() < 1) { + throw IOException("Corrupted VARIANT 'metadata' buffer, empty or nullptr"); + } - header = VariantMetadataHeader::FromHeaderByte(metadata_data[0]); + idx_t metadata_offset = 0; + header = VariantMetadataHeader::FromHeaderByte(metadata_data[metadata_offset]); + metadata_offset += sizeof(uint8_t); - const_data_ptr_t ptr = reinterpret_cast(metadata_data + sizeof(uint8_t)); - idx_t dictionary_size = ReadVariableLengthLittleEndian(header.offset_size, ptr); + idx_t dictionary_size = + ReadVariableLengthLittleEndian(header.offset_size, metadata_data, metadata_offset, metadata_buffer_capacity); - auto offsets = ptr; - auto bytes = offsets + ((dictionary_size + 1) * header.offset_size); - idx_t last_offset = ReadVariableLengthLittleEndian(header.offset_size, ptr); + auto data_start = metadata_offset + ((dictionary_size + 1) * header.offset_size); + idx_t last_offset = + ReadVariableLengthLittleEndian(header.offset_size, metadata_data, metadata_offset, metadata_buffer_capacity); for (idx_t i = 0; i < dictionary_size; i++) { - auto next_offset = ReadVariableLengthLittleEndian(header.offset_size, ptr); - strings.emplace_back(reinterpret_cast(bytes + last_offset), next_offset - last_offset); + auto next_offset = ReadVariableLengthLittleEndian(header.offset_size, metadata_data, metadata_offset, + metadata_buffer_capacity); + const idx_t string_size = next_offset - last_offset; + if (data_start + last_offset + string_size > metadata_buffer_capacity) { + throw IOException("Corrupted VARIANT 'metadata' buffer"); + } + strings.emplace_back(reinterpret_cast(metadata_data + data_start + last_offset), string_size); last_offset = next_offset; } + //! header byte + offsets region + string bytes + total_size = metadata_offset + last_offset; } VariantValueMetadata VariantValueMetadata::FromHeaderByte(uint8_t byte) { @@ -118,247 +121,4 @@ VariantValueMetadata VariantValueMetadata::FromHeaderByte(uint8_t byte) { return result; } -template -static T DecodeDecimal(const_data_ptr_t data, uint8_t &scale, uint8_t &width) { - scale = Load(data); - data++; - - auto result = Load(data); - auto abs_val = result; - if (abs_val < 0) { - abs_val = -abs_val; - } - uint8_t digits = floor(log10(abs_val)) + 1; - width = digits; - return result; -} - -template <> -hugeint_t DecodeDecimal(const_data_ptr_t data, uint8_t &scale, uint8_t &width) { - scale = Load(data); - data++; - - hugeint_t result; - result.lower = Load(data); - result.upper = Load(data + sizeof(uint64_t)); - //! FIXME: The spec says: - //! The implied precision of a decimal value is `floor(log_10(val)) + 1` - width = DecimalWidth::max; - return result; -} - -VariantValue VariantBinaryDecoder::PrimitiveTypeDecode(const VariantValueMetadata &value_metadata, - const_data_ptr_t data) { - switch (value_metadata.primitive_type) { - case VariantPrimitiveType::NULL_TYPE: { - return VariantValue::NullValue(); - } - case VariantPrimitiveType::BOOLEAN_TRUE: { - return VariantValue(Value::BOOLEAN(true)); - } - case VariantPrimitiveType::BOOLEAN_FALSE: { - return VariantValue(Value::BOOLEAN(false)); - } - case VariantPrimitiveType::INT8: { - auto value = Load(data); - return VariantValue(Value::TINYINT(value)); - } - case VariantPrimitiveType::INT16: { - auto value = Load(data); - return VariantValue(Value::SMALLINT(value)); - } - case VariantPrimitiveType::INT32: { - auto value = Load(data); - return VariantValue(Value::INTEGER(value)); - } - case VariantPrimitiveType::INT64: { - auto value = Load(data); - return VariantValue(Value::BIGINT(value)); - } - case VariantPrimitiveType::DOUBLE: { - double value = Load(data); - return VariantValue(Value::DOUBLE(value)); - } - case VariantPrimitiveType::FLOAT: { - float value = Load(data); - return VariantValue(Value::FLOAT(value)); - } - case VariantPrimitiveType::DECIMAL4: { - uint8_t scale; - uint8_t width; - - auto value = DecodeDecimal(data, scale, width); - return VariantValue(Value::DECIMAL(value, width, scale)); - } - case VariantPrimitiveType::DECIMAL8: { - uint8_t scale; - uint8_t width; - - auto value = DecodeDecimal(data, scale, width); - return VariantValue(Value::DECIMAL(value, width, scale)); - } - case VariantPrimitiveType::DECIMAL16: { - uint8_t scale; - uint8_t width; - - auto value = DecodeDecimal(data, scale, width); - return VariantValue(Value::DECIMAL(value, width, scale)); - } - case VariantPrimitiveType::DATE: { - date_t value; - value.days = Load(data); - return VariantValue(Value::DATE(value)); - } - case VariantPrimitiveType::TIMESTAMP_MICROS: { - timestamp_tz_t micros_ts_tz; - micros_ts_tz.value = Load(data); - return VariantValue(Value::TIMESTAMPTZ(micros_ts_tz)); - } - case VariantPrimitiveType::TIMESTAMP_NTZ_MICROS: { - timestamp_t micros_ts; - micros_ts.value = Load(data); - - auto value = Value::TIMESTAMP(micros_ts); - return VariantValue(std::move(value)); - } - case VariantPrimitiveType::BINARY: { - //! Follow the JSON serialization guide by converting BINARY to Base64: - //! For example: `"dmFyaWFudAo="` - auto size = Load(data); - auto string_data = reinterpret_cast(data + sizeof(uint32_t)); - auto base64_string = Blob::ToBase64(string_t(string_data, size)); - return VariantValue(Value(base64_string)); - } - case VariantPrimitiveType::STRING: { - auto size = Load(data); - auto string_data = reinterpret_cast(data + sizeof(uint32_t)); - if (!Utf8Proc::IsValid(string_data, size)) { - throw InternalException("Can't decode Variant short-string, string isn't valid UTF8"); - } - return VariantValue(Value(string(string_data, size))); - } - case VariantPrimitiveType::TIME_NTZ_MICROS: { - dtime_t micros_time; - micros_time.micros = Load(data); - return VariantValue(Value::TIME(micros_time)); - } - case VariantPrimitiveType::TIMESTAMP_NANOS: { - timestamp_tz_ns_t nanos_ts; - nanos_ts.value = Load(data); - - return VariantValue(Value::TIMESTAMPTZNS(nanos_ts)); - } - case VariantPrimitiveType::TIMESTAMP_NTZ_NANOS: { - timestamp_ns_t nanos_ts; - nanos_ts.value = Load(data); - - auto value = Value::TIMESTAMPNS(nanos_ts); - return VariantValue(std::move(value)); - } - case VariantPrimitiveType::UUID: { - auto uuid_value = UUIDValueConversion::ReadParquetUUID(data); - return VariantValue(Value::UUID(uuid_value)); - } - default: - throw NotImplementedException("Variant PrimitiveTypeDecode not implemented for type (%d)", - static_cast(value_metadata.primitive_type)); - } -} - -VariantValue VariantBinaryDecoder::ShortStringDecode(const VariantValueMetadata &value_metadata, - const_data_ptr_t data) { - D_ASSERT(value_metadata.string_size < 64); - auto string_data = reinterpret_cast(data); - if (!Utf8Proc::IsValid(string_data, value_metadata.string_size)) { - throw InternalException("Can't decode Variant short-string, string isn't valid UTF8"); - } - return VariantValue(Value(string(string_data, value_metadata.string_size))); -} - -VariantValue VariantBinaryDecoder::ObjectDecode(const VariantMetadata &metadata, - const VariantValueMetadata &value_metadata, const_data_ptr_t data) { - VariantValue ret(VariantValueType::OBJECT); - - auto field_offset_size = value_metadata.field_offset_size; - auto field_id_size = value_metadata.field_id_size; - auto is_large = value_metadata.is_large; - - idx_t num_elements; - if (is_large) { - num_elements = Load(data); - data += sizeof(uint32_t); - } else { - num_elements = Load(data); - data += sizeof(uint8_t); - } - - auto field_ids = data; - auto field_offsets = data + (num_elements * field_id_size); - auto values = field_offsets + (NumericCast(num_elements + 1) * field_offset_size); - - idx_t last_offset = ReadVariableLengthLittleEndian(field_offset_size, field_offsets); - for (idx_t i = 0; i < num_elements; i++) { - auto field_id = ReadVariableLengthLittleEndian(field_id_size, field_ids); - auto next_offset = ReadVariableLengthLittleEndian(field_offset_size, field_offsets); - - auto value = Decode(metadata, values + last_offset); - auto &key = metadata.strings[field_id]; - - ret.AddChild(key, std::move(value)); - last_offset = next_offset; - } - return ret; -} - -VariantValue VariantBinaryDecoder::ArrayDecode(const VariantMetadata &metadata, - const VariantValueMetadata &value_metadata, const_data_ptr_t data) { - VariantValue ret(VariantValueType::ARRAY); - - auto field_offset_size = value_metadata.field_offset_size; - auto is_large = value_metadata.is_large; - - uint32_t num_elements; - if (is_large) { - num_elements = Load(data); - data += sizeof(uint32_t); - } else { - num_elements = Load(data); - data += sizeof(uint8_t); - } - - auto field_offsets = data; - auto values = field_offsets + (NumericCast(num_elements) + 1) * field_offset_size; - - idx_t last_offset = ReadVariableLengthLittleEndian(field_offset_size, field_offsets); - for (idx_t i = 0; i < num_elements; i++) { - auto next_offset = ReadVariableLengthLittleEndian(field_offset_size, field_offsets); - - ret.AddItem(Decode(metadata, values + last_offset)); - last_offset = next_offset; - } - return ret; -} - -VariantValue VariantBinaryDecoder::Decode(const VariantMetadata &variant_metadata, const_data_ptr_t data) { - auto value_metadata = VariantValueMetadata::FromHeaderByte(data[0]); - - data++; - switch (value_metadata.basic_type) { - case VariantBasicType::PRIMITIVE: { - return PrimitiveTypeDecode(value_metadata, data); - } - case VariantBasicType::SHORT_STRING: { - return ShortStringDecode(value_metadata, data); - } - case VariantBasicType::OBJECT: { - return ObjectDecode(variant_metadata, value_metadata, data); - } - case VariantBasicType::ARRAY: { - return ArrayDecode(variant_metadata, value_metadata, data); - } - default: - throw InternalException("Unexpected value for VariantBasicType"); - } -} - } // namespace duckdb diff --git a/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp b/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp deleted file mode 100644 index c6aff53e4..000000000 --- a/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp +++ /dev/null @@ -1,579 +0,0 @@ -#include -#include -#include -#include - -#include "duckdb/common/vector/list_vector.hpp" -#include "duckdb/common/vector/struct_vector.hpp" -#include "reader/variant/variant_shredded_conversion.hpp" -#include "utf8proc_wrapper.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/blob.hpp" -#include "duckdb/common/assert.hpp" -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/hugeint.hpp" -#include "duckdb/common/optional_ptr.hpp" -#include "duckdb/common/string.hpp" -#include "duckdb/common/typedefs.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/types/datetime.hpp" -#include "duckdb/common/types/selection_vector.hpp" -#include "duckdb/common/types/string_type.hpp" -#include "duckdb/common/types/validity_mask.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/common/types/variant_value.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/vector.hpp" -#include "duckdb/common/vector/unified_vector_format.hpp" -#include "reader/variant/variant_binary_decoder.hpp" - -namespace duckdb { - -template -struct ConvertShreddedValue { - static VariantValue Convert(T val); - static VariantValue ConvertDecimal(T val, uint8_t width, uint8_t scale) { - throw InternalException("ConvertShreddedValue::ConvertDecimal not implemented for type"); - } - static VariantValue ConvertBlob(T val) { - throw InternalException("ConvertShreddedValue::ConvertBlob not implemented for type"); - } -}; - -//! boolean -template <> -VariantValue ConvertShreddedValue::Convert(bool val) { - return VariantValue(Value::BOOLEAN(val)); -} -//! int8 -template <> -VariantValue ConvertShreddedValue::Convert(int8_t val) { - return VariantValue(Value::TINYINT(val)); -} -//! int16 -template <> -VariantValue ConvertShreddedValue::Convert(int16_t val) { - return VariantValue(Value::SMALLINT(val)); -} -//! int32 -template <> -VariantValue ConvertShreddedValue::Convert(int32_t val) { - return VariantValue(Value::INTEGER(val)); -} -//! int64 -template <> -VariantValue ConvertShreddedValue::Convert(int64_t val) { - return VariantValue(Value::BIGINT(val)); -} -//! float -template <> -VariantValue ConvertShreddedValue::Convert(float val) { - return VariantValue(Value::FLOAT(val)); -} -//! double -template <> -VariantValue ConvertShreddedValue::Convert(double val) { - return VariantValue(Value::DOUBLE(val)); -} -//! NOTE: decimal2 - not in the spec, but some writers create this regardless -template <> -VariantValue ConvertShreddedValue::ConvertDecimal(int16_t val, uint8_t width, uint8_t scale) { - return VariantValue(Value::DECIMAL(val, width, scale)); -} -//! decimal4/decimal8/decimal16 -template <> -VariantValue ConvertShreddedValue::ConvertDecimal(int32_t val, uint8_t width, uint8_t scale) { - return VariantValue(Value::DECIMAL(val, width, scale)); -} -template <> -VariantValue ConvertShreddedValue::ConvertDecimal(int64_t val, uint8_t width, uint8_t scale) { - return VariantValue(Value::DECIMAL(val, width, scale)); -} -template <> -VariantValue ConvertShreddedValue::ConvertDecimal(hugeint_t val, uint8_t width, uint8_t scale) { - return VariantValue(Value::DECIMAL(val, width, scale)); -} -//! date -template <> -VariantValue ConvertShreddedValue::Convert(date_t val) { - return VariantValue(Value::DATE(val)); -} -//! time -template <> -VariantValue ConvertShreddedValue::Convert(dtime_t val) { - return VariantValue(Value::TIME(val)); -} -//! timestamptz(6) -template <> -VariantValue ConvertShreddedValue::Convert(timestamp_tz_t val) { - return VariantValue(Value::TIMESTAMPTZ(val)); -} -//! timestamptz(9) -template <> -VariantValue ConvertShreddedValue::Convert(timestamp_tz_ns_t val) { - return VariantValue(Value::TIMESTAMPTZNS(val)); -} -//! timestampntz(6) -template <> -VariantValue ConvertShreddedValue::Convert(timestamp_t val) { - return VariantValue(Value::TIMESTAMP(val)); -} -//! timestampntz(9) -template <> -VariantValue ConvertShreddedValue::Convert(timestamp_ns_t val) { - return VariantValue(Value::TIMESTAMPNS(val)); -} -//! binary -template <> -VariantValue ConvertShreddedValue::ConvertBlob(string_t val) { - return VariantValue(Value(Blob::ToBase64(val))); -} -//! string -template <> -VariantValue ConvertShreddedValue::Convert(string_t val) { - if (!Utf8Proc::IsValid(val.GetData(), val.GetSize())) { - throw InternalException("Can't decode Variant string, it isn't valid UTF8"); - } - return VariantValue(Value(val.GetString())); -} -//! uuid -template <> -VariantValue ConvertShreddedValue::Convert(hugeint_t val) { - return VariantValue(Value::UUID(val)); -} - -template -vector ConvertTypedValues(Vector &vec, Vector &metadata, Vector &blob, idx_t offset, idx_t length, - idx_t total_size) { - UnifiedVectorFormat metadata_format; - metadata.ToUnifiedFormat(metadata_format); - auto metadata_data = metadata_format.GetData(metadata_format); - - UnifiedVectorFormat typed_format; - vec.ToUnifiedFormat(typed_format); - auto data = typed_format.GetData(typed_format); - - UnifiedVectorFormat value_format; - blob.ToUnifiedFormat(value_format); - auto value_data = value_format.GetData(value_format); - - auto &validity = typed_format.validity; - auto &value_validity = value_format.validity; - auto &type = vec.GetType(); - - //! Values only used for Decimal conversion - uint8_t width; - uint8_t scale; - if (TYPE_ID == LogicalTypeId::DECIMAL) { - type.GetDecimalProperties(width, scale); - } - - vector ret(length); - if (validity.CannotHaveNull()) { - for (idx_t i = 0; i < length; i++) { - auto index = typed_format.sel->get_index(i + offset); - if (TYPE_ID == LogicalTypeId::DECIMAL) { - ret[i] = OP::ConvertDecimal(data[index], width, scale); - } else if (TYPE_ID == LogicalTypeId::BLOB) { - ret[i] = OP::ConvertBlob(data[index]); - } else { - ret[i] = OP::Convert(data[index]); - } - } - } else { - for (idx_t i = 0; i < length; i++) { - auto typed_index = typed_format.sel->get_index(i + offset); - auto value_index = value_format.sel->get_index(i + offset); - if (validity.RowIsValid(typed_index)) { - //! This is a leaf, partially shredded values aren't possible here - D_ASSERT(!value_validity.RowIsValid(value_index)); - if (TYPE_ID == LogicalTypeId::DECIMAL) { - ret[i] = OP::ConvertDecimal(data[typed_index], width, scale); - } else if (TYPE_ID == LogicalTypeId::BLOB) { - ret[i] = OP::ConvertBlob(data[typed_index]); - } else { - ret[i] = OP::Convert(data[typed_index]); - } - } else { - if (!value_validity.RowIsValid(value_index)) { - //! Value is missing for this field - continue; - } - D_ASSERT(value_validity.RowIsValid(value_index)); - auto metadata_value = metadata_data[metadata_format.sel->get_index(i)]; - VariantMetadata variant_metadata(metadata_value); - ret[i] = VariantBinaryDecoder::Decode(variant_metadata, - const_data_ptr_cast(value_data[value_index].GetData())); - } - } - } - return ret; -} - -vector VariantShreddedConversion::ConvertShreddedLeaf(Vector &metadata, Vector &value, - Vector &typed_value, idx_t offset, idx_t length, - idx_t total_size) { - D_ASSERT(!typed_value.GetType().IsNested()); - vector result; - - auto &type = typed_value.GetType(); - switch (type.id()) { - //! boolean - case LogicalTypeId::BOOLEAN: { - return ConvertTypedValues, LogicalTypeId::BOOLEAN>( - typed_value, metadata, value, offset, length, total_size); - } - //! int8 - case LogicalTypeId::TINYINT: { - return ConvertTypedValues, LogicalTypeId::TINYINT>( - typed_value, metadata, value, offset, length, total_size); - } - //! int16 - case LogicalTypeId::SMALLINT: { - return ConvertTypedValues, LogicalTypeId::SMALLINT>( - typed_value, metadata, value, offset, length, total_size); - } - //! int32 - case LogicalTypeId::INTEGER: { - return ConvertTypedValues, LogicalTypeId::INTEGER>( - typed_value, metadata, value, offset, length, total_size); - } - //! int64 - case LogicalTypeId::BIGINT: { - return ConvertTypedValues, LogicalTypeId::BIGINT>( - typed_value, metadata, value, offset, length, total_size); - } - //! float - case LogicalTypeId::FLOAT: { - return ConvertTypedValues, LogicalTypeId::FLOAT>( - typed_value, metadata, value, offset, length, total_size); - } - //! double - case LogicalTypeId::DOUBLE: { - return ConvertTypedValues, LogicalTypeId::DOUBLE>( - typed_value, metadata, value, offset, length, total_size); - } - //! decimal4/decimal8/decimal16 - case LogicalTypeId::DECIMAL: { - auto physical_type = type.InternalType(); - switch (physical_type) { - case PhysicalType::INT16: { - //! NOTE: This is not spec compliant, but some writers shred DECIMAL2 - return ConvertTypedValues, LogicalTypeId::DECIMAL>( - typed_value, metadata, value, offset, length, total_size); - } - case PhysicalType::INT32: { - return ConvertTypedValues, LogicalTypeId::DECIMAL>( - typed_value, metadata, value, offset, length, total_size); - } - case PhysicalType::INT64: { - return ConvertTypedValues, LogicalTypeId::DECIMAL>( - typed_value, metadata, value, offset, length, total_size); - } - case PhysicalType::INT128: { - return ConvertTypedValues, LogicalTypeId::DECIMAL>( - typed_value, metadata, value, offset, length, total_size); - } - default: - throw NotImplementedException("Decimal with PhysicalType (%s) not implemented for shredded Variant", - EnumUtil::ToString(physical_type)); - } - } - //! date - case LogicalTypeId::DATE: { - return ConvertTypedValues, LogicalTypeId::DATE>( - typed_value, metadata, value, offset, length, total_size); - } - //! time - case LogicalTypeId::TIME: { - return ConvertTypedValues, LogicalTypeId::TIME>( - typed_value, metadata, value, offset, length, total_size); - } - //! timestamptz(6) - case LogicalTypeId::TIMESTAMP_TZ: { - return ConvertTypedValues, LogicalTypeId::TIMESTAMP_TZ>( - typed_value, metadata, value, offset, length, total_size); - } - //! timestamptz(9) - case LogicalTypeId::TIMESTAMP_TZ_NS: { - return ConvertTypedValues, - LogicalTypeId::TIMESTAMP_TZ_NS>(typed_value, metadata, value, offset, length, - total_size); - } - //! timestampntz(6) - case LogicalTypeId::TIMESTAMP: { - return ConvertTypedValues, LogicalTypeId::TIMESTAMP>( - typed_value, metadata, value, offset, length, total_size); - } - //! timestampntz(9) - case LogicalTypeId::TIMESTAMP_NS: { - return ConvertTypedValues, LogicalTypeId::TIMESTAMP_NS>( - typed_value, metadata, value, offset, length, total_size); - } - //! binary - case LogicalTypeId::BLOB: { - return ConvertTypedValues, LogicalTypeId::BLOB>( - typed_value, metadata, value, offset, length, total_size); - } - //! string - case LogicalTypeId::VARCHAR: { - return ConvertTypedValues, LogicalTypeId::VARCHAR>( - typed_value, metadata, value, offset, length, total_size); - } - //! uuid - case LogicalTypeId::UUID: { - return ConvertTypedValues, LogicalTypeId::UUID>( - typed_value, metadata, value, offset, length, total_size); - } - default: - throw NotImplementedException("Variant shredding on type: '%s' is not implemented", type.ToString()); - } -} - -namespace { - -struct ShreddedVariantField { -public: - explicit ShreddedVariantField(const Identifier &field_name) : field_name(field_name.GetIdentifierName()) { - } - -public: - string field_name; - //! Values for the field, for all rows - vector values; -}; - -} // namespace - -static vector ConvertBinaryEncoding(Vector &metadata, Vector &value, idx_t offset, idx_t length, - idx_t total_size) { - UnifiedVectorFormat value_format; - value.ToUnifiedFormat(value_format); - auto value_data = value_format.GetData(value_format); - auto &validity = value_format.validity; - - UnifiedVectorFormat metadata_format; - metadata.ToUnifiedFormat(metadata_format); - auto metadata_data = metadata_format.GetData(metadata_format); - auto metadata_validity = metadata_format.validity; - - //! Fills every row with MISSING, turned into NULL later if this is not in an OBJECT field - vector ret(length); - for (idx_t i = 0; i < length; i++) { - auto index = value_format.sel->get_index(i + offset); - if (validity.RowIsValid(index)) { - auto &metadata_value = metadata_data[metadata_format.sel->get_index(i)]; - VariantMetadata variant_metadata(metadata_value); - auto binary_value = value_data[index].GetData(); - ret[i] = VariantBinaryDecoder::Decode(variant_metadata, const_data_ptr_cast(binary_value)); - } - } - return ret; -} - -static VariantValue ConvertPartiallyShreddedObject(vector &shredded_fields, - const UnifiedVectorFormat &metadata_format, - const UnifiedVectorFormat &value_format, idx_t i, idx_t offset) { - auto ret = VariantValue(VariantValueType::OBJECT); - auto index = value_format.sel->get_index(i + offset); - auto value_data = value_format.GetData(value_format); - auto metadata_data = metadata_format.GetData(metadata_format); - auto &value_validity = value_format.validity; - - for (idx_t field_index = 0; field_index < shredded_fields.size(); field_index++) { - auto &shredded_field = shredded_fields[field_index]; - auto &field_value = shredded_field.values[i]; - ret.AddChild(shredded_field.field_name, std::move(field_value)); - } - - if (value_validity.RowIsValid(index)) { - //! Object is partially shredded, decode the object and merge the values - auto &metadata_value = metadata_data[metadata_format.sel->get_index(i)]; - VariantMetadata variant_metadata(metadata_value); - auto binary_value = value_data[index].GetData(); - auto unshredded = VariantBinaryDecoder::Decode(variant_metadata, const_data_ptr_cast(binary_value)); - if (unshredded.value_type != VariantValueType::OBJECT) { - throw InvalidInputException("Partially shredded objects have to encode Object Variants in the 'value'"); - } - auto object_children = unshredded.TakeObjectChildren(); - for (auto &item : object_children) { - ret.AddChild(item.first, std::move(item.second)); - } - } - return ret; -} - -vector VariantShreddedConversion::ConvertShreddedObject(Vector &metadata, Vector &value, - Vector &typed_value, idx_t offset, idx_t length, - idx_t total_size) { - auto &type = typed_value.GetType(); - D_ASSERT(type.id() == LogicalTypeId::STRUCT); - auto &fields = StructType::GetChildTypes(type); - auto &entries = StructVector::GetEntries(typed_value); - D_ASSERT(entries.size() == fields.size()); - - //! 'value' - UnifiedVectorFormat value_format; - value.ToUnifiedFormat(value_format); - auto value_data = value_format.GetData(value_format); - auto &validity = value_format.validity; - (void)validity; - - //! 'metadata' - UnifiedVectorFormat metadata_format; - metadata.ToUnifiedFormat(metadata_format); - auto metadata_data = metadata_format.GetData(metadata_format); - - //! 'typed_value' - UnifiedVectorFormat typed_format; - typed_value.ToUnifiedFormat(typed_format); - auto &typed_validity = typed_format.validity; - - //! Process all fields to get the shredded field values - vector shredded_fields; - shredded_fields.reserve(fields.size()); - for (idx_t i = 0; i < fields.size(); i++) { - auto &field = fields[i]; - auto &field_name = field.first; - auto &field_vec = entries[i]; - - shredded_fields.emplace_back(field_name); - auto &shredded_field = shredded_fields.back(); - shredded_field.values = Convert(metadata, field_vec, offset, length, total_size); - } - - vector ret(length); - if (typed_validity.CannotHaveNull()) { - for (idx_t i = 0; i < length; i++) { - ret[i] = ConvertPartiallyShreddedObject(shredded_fields, metadata_format, value_format, i, offset); - } - } else { - //! For some of the rows, the value is not an object - for (idx_t i = 0; i < length; i++) { - auto typed_index = typed_format.sel->get_index(i + offset); - auto value_index = value_format.sel->get_index(i + offset); - if (typed_validity.RowIsValid(typed_index)) { - ret[i] = ConvertPartiallyShreddedObject(shredded_fields, metadata_format, value_format, i, offset); - } else { - if (!validity.RowIsValid(value_index)) { - //! This object is a field in the parent object, the value is missing, skip it - continue; - } - D_ASSERT(validity.RowIsValid(value_index)); - auto &metadata_value = metadata_data[metadata_format.sel->get_index(i)]; - VariantMetadata variant_metadata(metadata_value); - auto binary_value = value_data[value_index].GetData(); - ret[i] = VariantBinaryDecoder::Decode(variant_metadata, const_data_ptr_cast(binary_value)); - if (ret[i].value_type == VariantValueType::OBJECT) { - throw InvalidInputException( - "When 'typed_value' for a shredded Object is NULL, 'value' can not contain an Object value"); - } - } - } - } - return ret; -} - -vector VariantShreddedConversion::ConvertShreddedArray(Vector &metadata, Vector &value, - Vector &typed_value, idx_t offset, idx_t length, - idx_t total_size) { - auto &child = ListVector::GetChildMutable(typed_value); - auto list_size = ListVector::GetListSize(typed_value); - - //! 'value' - UnifiedVectorFormat value_format; - value.ToUnifiedFormat(value_format); - auto value_data = value_format.GetData(value_format); - - //! 'metadata' - UnifiedVectorFormat metadata_format; - metadata.ToUnifiedFormat(metadata_format); - auto metadata_data = metadata_format.GetData(metadata_format); - - //! 'typed_value' - UnifiedVectorFormat list_format; - typed_value.ToUnifiedFormat(list_format); - auto list_data = list_format.GetData(list_format); - auto &validity = list_format.validity; - auto &value_validity = value_format.validity; - - vector ret(length); - if (validity.CannotHaveNull()) { - //! We can be sure that none of the values are binary encoded - for (idx_t i = 0; i < length; i++) { - auto typed_index = list_format.sel->get_index(i + offset); - auto entry = list_data[typed_index]; - Vector child_metadata(metadata.GetValue(i), count_t(entry.length)); - ret[i] = VariantValue(VariantValueType::ARRAY); - ret[i].SetItems(Convert(child_metadata, child, entry.offset, entry.length, list_size)); - } - } else { - for (idx_t i = 0; i < length; i++) { - auto typed_index = list_format.sel->get_index(i + offset); - auto value_index = value_format.sel->get_index(i + offset); - if (validity.RowIsValid(typed_index)) { - auto entry = list_data[typed_index]; - Vector child_metadata(metadata.GetValue(i), count_t(entry.length)); - ret[i] = VariantValue(VariantValueType::ARRAY); - ret[i].SetItems(Convert(child_metadata, child, entry.offset, entry.length, list_size)); - } else { - if (!value_validity.RowIsValid(value_index)) { - //! Value is missing for this field - continue; - } - D_ASSERT(value_validity.RowIsValid(value_index)); - auto metadata_value = metadata_data[metadata_format.sel->get_index(i)]; - VariantMetadata variant_metadata(metadata_value); - ret[i] = VariantBinaryDecoder::Decode(variant_metadata, - const_data_ptr_cast(value_data[value_index].GetData())); - } - } - } - return ret; -} - -vector VariantShreddedConversion::Convert(Vector &metadata, Vector &group, idx_t offset, idx_t length, - idx_t total_size) { - D_ASSERT(group.GetType().id() == LogicalTypeId::STRUCT); - - auto &group_entries = StructVector::GetEntries(group); - auto &group_type_children = StructType::GetChildTypes(group.GetType()); - D_ASSERT(group_type_children.size() == group_entries.size()); - - //! From the spec: - //! The Parquet columns used to store variant metadata and values must be accessed by name, not by position. - optional_ptr value; - optional_ptr typed_value; - for (idx_t i = 0; i < group_entries.size(); i++) { - auto &name = group_type_children[i].first; - auto &vec = group_entries[i]; - if (name == "value") { - value = &vec; - } else if (name == "typed_value") { - typed_value = &vec; - } else { - throw InvalidInputException("Variant group can only contain 'value'/'typed_value', not: %s", name); - } - } - if (!value) { - throw InvalidInputException("Required column 'value' not found in Variant group"); - } - - if (typed_value) { - auto &type = typed_value->GetType(); - vector ret; - if (type.id() == LogicalTypeId::STRUCT) { - return ConvertShreddedObject(metadata, *value, *typed_value, offset, length, total_size); - } else if (type.id() == LogicalTypeId::LIST) { - return ConvertShreddedArray(metadata, *value, *typed_value, offset, length, total_size); - } else { - return ConvertShreddedLeaf(metadata, *value, *typed_value, offset, length, total_size); - } - } else { - return ConvertBinaryEncoding(metadata, *value, offset, length, total_size); - } -} - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/reader/variant_column_reader.cpp b/src/duckdb/extension/parquet/reader/variant_column_reader.cpp index f0d450171..16a863fdd 100644 --- a/src/duckdb/extension/parquet/reader/variant_column_reader.cpp +++ b/src/duckdb/extension/parquet/reader/variant_column_reader.cpp @@ -6,7 +6,7 @@ #include "duckdb/common/vector/struct_vector.hpp" #include "reader/variant_column_reader.hpp" -#include "reader/variant/variant_shredded_conversion.hpp" +#include "reader/variant/parquet_variant_iterator.hpp" #include "column_reader.hpp" #include "duckdb/common/assert.hpp" #include "duckdb/common/constants.hpp" @@ -15,7 +15,6 @@ #include "duckdb/common/optional_ptr.hpp" #include "duckdb/common/typedefs.hpp" #include "duckdb/common/types.hpp" -#include "duckdb/common/types/variant_value.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/common/unique_ptr.hpp" #include "duckdb/common/vector.hpp" @@ -119,7 +118,6 @@ idx_t VariantColumnReader::Read(ColumnReaderInput &input, Vector &result) { "The Variant column did not contain the same amount of values for 'metadata' and 'value'"); } - vector intermediate; if (typed_value_reader) { ColumnReaderInput child_input(num_values, define_out, repeat_out); auto typed_values = typed_value_reader->Read(child_input, group_entries[1]); @@ -128,9 +126,7 @@ idx_t VariantColumnReader::Read(ColumnReaderInput &input, Vector &result) { "The shredded Variant column did not contain the same amount of values for 'typed_value' and 'value'"); } } - intermediate = - VariantShreddedConversion::Convert(metadata_intermediate, intermediate_group, 0, num_values, num_values); - VariantValue::ToVARIANT(intermediate, result); + ParquetVariantConversion::Convert(metadata_intermediate, intermediate_group, result, num_values); read_count = value_values; return read_count.GetIndex(); diff --git a/src/duckdb/src/catalog/catalog_set.cpp b/src/duckdb/src/catalog/catalog_set.cpp index f96b398ab..fd931168f 100644 --- a/src/duckdb/src/catalog/catalog_set.cpp +++ b/src/duckdb/src/catalog/catalog_set.cpp @@ -361,6 +361,8 @@ bool CatalogSet::AlterEntry(CatalogTransaction transaction, const Identifier &na // Mark this entry as being created by this transaction value->timestamp = transaction.transaction_id; value->set = this; + // Preserve the oid across the alter: an altered entry is the same logical object as before + value->oid = entry->oid; if (!(value->name == entry->name)) { if (!RenameEntryInternal(transaction, *entry, value->name, alter_info, read_lock)) { @@ -463,7 +465,7 @@ void CatalogSet::VerifyExistenceOfDependency(transaction_t commit_id, CatalogEnt // Make sure that we don't see any uncommitted changes auto transaction_id = MAX_TRANSACTION_ID; // This will allow us to see all committed changes made before this COMMIT happened - auto tx_start_time = commit_id; + auto tx_start_time = commit_id + 1; CatalogTransaction commit_transaction(duck_catalog.GetDatabase(), transaction_id, tx_start_time); D_ASSERT(entry.type == CatalogType::DEPENDENCY_ENTRY); diff --git a/src/duckdb/src/catalog/dependency_manager.cpp b/src/duckdb/src/catalog/dependency_manager.cpp index 540cb11ef..52c8d3a8d 100644 --- a/src/duckdb/src/catalog/dependency_manager.cpp +++ b/src/duckdb/src/catalog/dependency_manager.cpp @@ -220,6 +220,9 @@ void DependencyManager::CreateDependent(CatalogTransaction transaction, const De } void DependencyManager::CreateDependency(CatalogTransaction transaction, DependencyInfo &info) { + auto subject_entry = LookupEntry(transaction, info.subject.entry); + info.subject.oid = subject_entry ? subject_entry->oid : optional_idx(); + DependencyCatalogSet subjects(Subjects(), info.dependent.entry); DependencyCatalogSet dependents(Dependents(), info.subject.entry); @@ -277,8 +280,9 @@ void DependencyManager::CreateDependencies(CatalogTransaction transaction, const // add the object to the dependents_map of each object that it depends on for (auto &dependency : dependencies.Set()) { - DependencyInfo info {/*dependent = */ DependencyDependent {GetLookupProperties(object), dependency_flags}, - /*subject = */ DependencySubject {dependency.entry, DependencySubjectFlags()}}; + DependencyInfo info { + /*dependent = */ DependencyDependent {GetLookupProperties(object), dependency_flags}, + /*subject = */ DependencySubject {dependency.entry, DependencySubjectFlags(), optional_idx()}}; CreateDependency(transaction, info); } } @@ -315,12 +319,8 @@ CatalogEntryInfo DependencyManager::GetLookupProperties(const CatalogEntry &entr } } -optional_ptr DependencyManager::LookupEntry(CatalogTransaction transaction, CatalogEntry &dependency) { - if (dependency.type != CatalogType::DEPENDENCY_ENTRY) { - return &dependency; - } - auto info = GetLookupProperties(dependency); - +optional_ptr DependencyManager::LookupEntry(CatalogTransaction transaction, + const CatalogEntryInfo &info) { auto &type = info.type; auto &schema = info.schema; auto &name = info.name; @@ -331,8 +331,14 @@ optional_ptr DependencyManager::LookupEntry(CatalogTransaction tra // This is a schema entry, perform the callback only providing the schema return reinterpret_cast(schema_entry.get()); } - auto entry = schema_entry->GetEntry(transaction, type, name); - return entry; + return schema_entry->GetEntry(transaction, type, name); +} + +optional_ptr DependencyManager::LookupEntry(CatalogTransaction transaction, CatalogEntry &dependency) { + if (dependency.type != CatalogType::DEPENDENCY_ENTRY) { + return &dependency; + } + return LookupEntry(transaction, GetLookupProperties(dependency)); } void DependencyManager::CleanupDependencies(CatalogTransaction transaction, CatalogEntry &object) { @@ -471,6 +477,13 @@ void DependencyManager::VerifyExistence(CatalogTransaction transaction, Dependen throw DependencyException("Could not commit creation of dependency, subject \"%s\" has been deleted", object.SourceInfo().name); } + // The subject still exists by name - check if it is the same object the dependency was created against + if (!subject.flags.IsOwnership() && subject.oid.IsValid() && lookup_result.result && + lookup_result.result->oid != subject.oid.GetIndex()) { + throw DependencyException( + "Could not commit creation of dependency, subject \"%s\" was dropped and re-created by another transaction", + object.EntryInfo().name); + } } void DependencyManager::VerifyCommitDrop(CatalogTransaction transaction, transaction_t start_time, @@ -791,7 +804,8 @@ void DependencyManager::AddOwnership(CatalogTransaction transaction, CatalogEntr DependencyInfo info { /*dependent = */ DependencyDependent {GetLookupProperties(owner), DependencyDependentFlags().SetOwnedBy()}, - /*subject = */ DependencySubject {GetLookupProperties(entry), DependencySubjectFlags().SetOwnership()}}; + /*subject = */ DependencySubject {GetLookupProperties(entry), DependencySubjectFlags().SetOwnership(), + optional_idx()}}; CreateDependency(transaction, info); } diff --git a/src/duckdb/src/common/adbc/adbc.cpp b/src/duckdb/src/common/adbc/adbc.cpp index 36843340a..02eb883df 100644 --- a/src/duckdb/src/common/adbc/adbc.cpp +++ b/src/duckdb/src/common/adbc/adbc.cpp @@ -18,6 +18,15 @@ #include #include static void ReleaseError(struct AdbcError *error); +static void ReleaseErrorWithDuckDBDetails(struct AdbcError *error); +static void ReleaseStreamErrorDetails(struct AdbcError *error); +static const char *DuckDBErrorTypeToString(duckdb_error_type type); +static void AppendDuckDBErrorDetails(struct AdbcError *error, duckdb_error_type type); +static void AppendDuckDBErrorDetails(struct AdbcError *error, duckdb_error_data res); + +struct DuckDBErrorDetails { + std::vector> entries; +}; #include @@ -67,9 +76,8 @@ AdbcStatusCode duckdb_adbc_init(int version, void *driver, struct AdbcError *err // Initialize 1.1.0 function pointers if version >= 1.1.0 if (version >= ADBC_VERSION_1_1_0) { - // TODO: ADBC 1.1.0 adds support for these functions - adbc_driver->ErrorGetDetailCount = nullptr; - adbc_driver->ErrorGetDetail = nullptr; + adbc_driver->ErrorGetDetailCount = duckdb_adbc::ErrorGetDetailCount; + adbc_driver->ErrorGetDetail = duckdb_adbc::ErrorGetDetail; adbc_driver->ErrorFromArrayStream = duckdb_adbc::ErrorFromArrayStream; adbc_driver->DatabaseGetOption = duckdb_adbc::DatabaseGetOption; @@ -239,9 +247,11 @@ void InitializeADBCError(AdbcError *error) { if (!error) { return; } - // Avoid leaking any DuckDB-owned error message. - // Only call DuckDB's own release callback. - if (error->message && error->release == ::ReleaseError) { + // Only call release for callbacks DuckDB owns. The stream wrapper sets + // adbc_error.message = last_error (strdup) with release = nullptr; calling + // delete[] on that would be a double-free and allocator mismatch. + if (error->release == ::ReleaseError || error->release == ::ReleaseErrorWithDuckDBDetails || + error->release == ::ReleaseStreamErrorDetails) { error->release(error); } error->message = nullptr; @@ -1051,6 +1061,10 @@ static int get_next(struct ArrowArrayStream *stream, struct ArrowArray *out) { if (result_wrapper->materialized) { auto mat = result_wrapper->materialized; if (mat->current >= mat->count) { + // Surface any error that was encountered during materialization + if (result_wrapper->last_error) { + return DuckDBError; + } return DuckDBSuccess; // end of stream } // Transfer ownership of the batch to the caller @@ -1070,10 +1084,22 @@ static int get_next(struct ArrowArrayStream *stream, struct ArrowArray *out) { } result_wrapper->last_error = strdup(err); result_wrapper->status_code = IsInterruptError(err) ? ADBC_STATUS_CANCELLED : ADBC_STATUS_INTERNAL; - // Populate adbc_error for AdbcErrorFromArrayStream + // Populate adbc_error for AdbcErrorFromArrayStream with rich metadata result_wrapper->adbc_error.message = result_wrapper->last_error; - result_wrapper->adbc_error.vendor_code = 0; - result_wrapper->adbc_error.release = nullptr; + result_wrapper->adbc_error.vendor_code = ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA; + if (result_wrapper->adbc_error.private_data) { + delete static_cast(result_wrapper->adbc_error.private_data); + result_wrapper->adbc_error.private_data = nullptr; + } + auto *details = new (std::nothrow) DuckDBErrorDetails(); + if (details) { + details->entries.emplace_back( + "duckdb:error_type", DuckDBErrorTypeToString(duckdb_result_error_type(&result_wrapper->result))); + result_wrapper->adbc_error.private_data = details; + result_wrapper->adbc_error.release = ::ReleaseStreamErrorDetails; + } else { + result_wrapper->adbc_error.release = nullptr; + } return DuckDBError; } return DuckDBSuccess; @@ -1085,6 +1111,29 @@ static int get_next(struct ArrowArrayStream *stream, struct ArrowArray *out) { duckdb_destroy_data_chunk(&duckdb_chunk); if (conversion_success) { + auto conv_err_msg = duckdb_error_data_message(conversion_success); + if (conv_err_msg && conv_err_msg[0] != '\0') { + if (result_wrapper->last_error) { + free(result_wrapper->last_error); + } + result_wrapper->last_error = strdup(conv_err_msg); + result_wrapper->status_code = ADBC_STATUS_INTERNAL; + result_wrapper->adbc_error.message = result_wrapper->last_error; + result_wrapper->adbc_error.vendor_code = ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA; + if (result_wrapper->adbc_error.private_data) { + delete static_cast(result_wrapper->adbc_error.private_data); + result_wrapper->adbc_error.private_data = nullptr; + } + auto *details = new (std::nothrow) DuckDBErrorDetails(); + if (details) { + details->entries.emplace_back( + "duckdb:error_type", DuckDBErrorTypeToString(duckdb_error_data_error_type(conversion_success))); + result_wrapper->adbc_error.private_data = details; + result_wrapper->adbc_error.release = ::ReleaseStreamErrorDetails; + } else { + result_wrapper->adbc_error.release = nullptr; + } + } duckdb_destroy_error_data(&conversion_success); return DuckDBError; } @@ -1290,6 +1339,7 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *catalog, const c auto res = duckdb_schema_from_arrow(connection, &arrow_schema_wrapper.arrow_schema, out_types.GetPtr()); if (res) { SetError(error, duckdb_error_data_message(res)); + AppendDuckDBErrorDetails(error, res); duckdb_destroy_error_data(&res); return ADBC_STATUS_INTERNAL; } @@ -1309,6 +1359,7 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *catalog, const c bool already_exists = error_msg && std::string(error_msg).find("already exists") != std::string::npos; bool interrupted = IsInterruptError(error_msg); SetError(error, error_msg); + AppendDuckDBErrorDetails(error, duckdb_result_error_type(&result)); duckdb_destroy_result(&result); if (interrupted) { return ADBC_STATUS_CANCELLED; @@ -1333,6 +1384,7 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *catalog, const c if (duckdb_query(connection, create_sql.c_str(), &result) == DuckDBError) { auto err = duckdb_result_error(&result); SetError(error, err); + AppendDuckDBErrorDetails(error, duckdb_result_error_type(&result)); bool interrupted = IsInterruptError(err); duckdb_destroy_result(&result); return interrupted ? ADBC_STATUS_CANCELLED : ADBC_STATUS_INTERNAL; @@ -1347,6 +1399,7 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *catalog, const c if (duckdb_query(connection, sql.c_str(), &result) == DuckDBError) { auto err = duckdb_result_error(&result); SetError(error, err); + AppendDuckDBErrorDetails(error, duckdb_result_error_type(&result)); bool interrupted = IsInterruptError(err); duckdb_destroy_result(&result); return interrupted ? ADBC_STATUS_CANCELLED : ADBC_STATUS_INTERNAL; @@ -1359,6 +1412,7 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *catalog, const c if (!appender.Valid()) { if (!appender.CreateError().empty()) { set_ingest_error(appender.CreateError()); + AppendDuckDBErrorDetails(error, appender.CreateErrorType()); } else { SetError(error, missing_table_error); } @@ -1376,6 +1430,7 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *catalog, const c &out_chunk.chunk); if (res) { SetError(error, duckdb_error_data_message(res)); + AppendDuckDBErrorDetails(error, res); duckdb_destroy_error_data(&res); } // Count rows for rows_affected, if a chunk was produced @@ -1392,6 +1447,7 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *catalog, const c SetError(error, missing_table_error); } bool interrupted = IsInterruptError(err); + AppendDuckDBErrorDetails(error, error_data); duckdb_destroy_error_data(&error_data); return interrupted ? ADBC_STATUS_CANCELLED : ADBC_STATUS_INTERNAL; } @@ -1543,6 +1599,7 @@ AdbcStatusCode StatementExecuteSchema(struct AdbcStatement *statement, struct Ar if (res) { SetError(error, duckdb_error_data_message(res)); + AppendDuckDBErrorDetails(error, res); duckdb_destroy_error_data(&res); return ADBC_STATUS_INVALID_ARGUMENT; } @@ -1600,6 +1657,7 @@ AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, stru if (res) { SetError(error, duckdb_error_data_message(res)); + AppendDuckDBErrorDetails(error, res); duckdb_destroy_error_data(&res); return ADBC_STATUS_INVALID_ARGUMENT; } @@ -1697,6 +1755,7 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr duckdb_schema_from_arrow(wrapper->connection, &arrow_schema_wrapper.arrow_schema, out_types.GetPtr()); if (res) { SetError(error, duckdb_error_data_message(res)); + AppendDuckDBErrorDetails(error, res); duckdb_destroy_error_data(&res); return ADBC_STATUS_INVALID_ARGUMENT; } @@ -1715,6 +1774,7 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr out_types.Get(), &out_chunk.chunk); if (res_conv) { SetError(error, duckdb_error_data_message(res_conv)); + AppendDuckDBErrorDetails(error, res_conv); duckdb_destroy_error_data(&res_conv); return ADBC_STATUS_INVALID_ARGUMENT; } @@ -1752,6 +1812,7 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr if (res != DuckDBSuccess) { auto err = duckdb_result_error(&stream_wrapper->result); SetError(error, err); + AppendDuckDBErrorDetails(error, duckdb_result_error_type(&stream_wrapper->result)); bool interrupted = IsInterruptError(err); return interrupted ? ADBC_STATUS_CANCELLED : ADBC_STATUS_INVALID_ARGUMENT; } @@ -1764,6 +1825,7 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr if (res != DuckDBSuccess) { auto err = duckdb_result_error(&stream_wrapper->result); SetError(error, err); + AppendDuckDBErrorDetails(error, duckdb_result_error_type(&stream_wrapper->result)); bool interrupted = IsInterruptError(err); return interrupted ? ADBC_STATUS_CANCELLED : ADBC_STATUS_INVALID_ARGUMENT; } @@ -2386,26 +2448,92 @@ AdbcStatusCode ConnectionGetObjects(struct AdbcConnection *connection, int depth LIST({ column_name: column_name, ordinal_position: ordinal_position, - remarks: '', - xdbc_data_type: NULL::SMALLINT, - xdbc_type_name: NULL::VARCHAR, - xdbc_column_size: NULL::INTEGER, - xdbc_decimal_digits: NULL::SMALLINT, - xdbc_num_prec_radix: NULL::SMALLINT, - xdbc_nullable: NULL::SMALLINT, - xdbc_column_def: NULL::VARCHAR, - xdbc_sql_data_type: NULL::SMALLINT, - xdbc_datetime_sub: NULL::SMALLINT, - xdbc_char_octet_length: NULL::INTEGER, - xdbc_is_nullable: NULL::VARCHAR, - xdbc_scope_catalog: NULL::VARCHAR, + remarks: comment, + xdbc_data_type: NULL::SMALLINT, -- Arrow type ID not derivable from SQL; SQL type codes are in xdbc_sql_data_type + xdbc_type_name: data_type, + xdbc_column_size: CASE + WHEN base_type = 'DATE' THEN 10::INTEGER + WHEN data_type IN ('TIME', 'TIME WITH TIME ZONE', 'TIME_NS') THEN 15::INTEGER + WHEN data_type LIKE 'TIMESTAMP%%' THEN 26::INTEGER + ELSE numeric_precision::INTEGER + END, + xdbc_decimal_digits: numeric_scale::SMALLINT, + xdbc_num_prec_radix: numeric_precision_radix::SMALLINT, + xdbc_nullable: CASE is_nullable + WHEN FALSE THEN 0::SMALLINT + WHEN TRUE THEN 1::SMALLINT + ELSE 2::SMALLINT + END, + xdbc_column_def: column_default, + xdbc_sql_data_type: CASE + WHEN data_type = 'TIMESTAMP WITH TIME ZONE' THEN 2014::SMALLINT + WHEN data_type LIKE 'TIMESTAMP%%' THEN 93::SMALLINT + WHEN data_type = 'TIME WITH TIME ZONE' THEN 2013::SMALLINT + WHEN data_type LIKE '%%]' THEN 2003::SMALLINT + WHEN type_codes[base_type] IS NOT NULL THEN type_codes[base_type]::SMALLINT + ELSE 1111::SMALLINT -- Types.OTHER: aligned with DuckDB JDBC default for unmapped types + END, + xdbc_datetime_sub: CASE + WHEN base_type = 'DATE' THEN 1::SMALLINT + WHEN data_type LIKE 'TIMESTAMP%%' THEN 3::SMALLINT + WHEN data_type IN ('TIME', 'TIME WITH TIME ZONE', 'TIME_NS') THEN 2::SMALLINT + ELSE NULL::SMALLINT + END, + xdbc_char_octet_length: CASE + WHEN base_type IN ('VARCHAR', 'BLOB') THEN character_maximum_length::INTEGER + ELSE NULL::INTEGER + END, + xdbc_is_nullable: CASE is_nullable + WHEN FALSE THEN 'NO' + WHEN TRUE THEN 'YES' + ELSE '' + END, + xdbc_scope_catalog: NULL::VARCHAR, -- REF types not supported in DuckDB xdbc_scope_schema: NULL::VARCHAR, xdbc_scope_table: NULL::VARCHAR, - xdbc_is_autoincrement: NULL::BOOLEAN, - xdbc_is_generatedcolumn: NULL::BOOLEAN, + xdbc_is_autoincrement: NULL::BOOLEAN, -- not exposed via duckdb_columns() + xdbc_is_generatedcolumn: NULL::BOOLEAN, -- not exposed via duckdb_columns() }) table_columns - FROM information_schema.columns - WHERE column_name LIKE %s + FROM ( + SELECT + database_name AS table_catalog, + schema_name AS table_schema, + table_name, + column_name, + column_index AS ordinal_position, + comment, + column_default, + is_nullable, + numeric_scale, + numeric_precision, + numeric_precision_radix, + character_maximum_length, + data_type, + STRING_SPLIT(data_type, '(')[1] AS base_type, -- normalize typemods for type-code lookup + -- JDBC java.sql.Types-compatible codes, matching DuckDB JDBC where possible. + MAP { + 'BOOLEAN': 16, + 'TINYINT': -6, + 'UTINYINT': 5, + 'SMALLINT': 5, + 'USMALLINT': 4, + 'INTEGER': 4, + 'UINTEGER': -5, + 'BIGINT': -5, + 'FLOAT': 6, + 'DOUBLE': 8, + 'DATE': 91, + 'TIME': 92, + 'TIME_NS': 92, + 'VARCHAR': 12, + 'BLOB': 2004, + 'DECIMAL': 3, + 'BIT': -7, + 'STRUCT': 2002, + } AS type_codes + FROM duckdb_columns() + WHERE column_name LIKE %s + ) cols GROUP BY table_catalog, table_schema, table_name ), constraints AS ( @@ -2498,6 +2626,26 @@ AdbcStatusCode ConnectionGetTableTypes(struct AdbcConnection *connection, struct return QueryInternal(connection, out, q, error); } +int ErrorGetDetailCount(const struct AdbcError *error) { + if (!error || error->vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA || !error->private_data) { + return 0; + } + const auto *details = static_cast(error->private_data); + return static_cast(details->entries.size()); +} + +struct AdbcErrorDetail ErrorGetDetail(const struct AdbcError *error, int index) { + if (!error || error->vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA || !error->private_data) { + return {nullptr, nullptr, 0}; + } + const auto *details = static_cast(error->private_data); + if (index < 0 || static_cast(index) >= details->entries.size()) { + return {nullptr, nullptr, 0}; + } + const auto &entry = details->entries[static_cast(index)]; + return {entry.first.c_str(), reinterpret_cast(entry.second.c_str()), entry.second.size()}; +} + } // namespace duckdb_adbc void duckdb::DuckDBAdbcConnectionWrapper::RegisterStream(duckdb_adbc::DuckDBAdbcStreamWrapper *stream) { @@ -2520,7 +2668,9 @@ void duckdb::DuckDBAdbcConnectionWrapper::MaterializeStreams() { continue; } - // Collect remaining batches from the streaming result + // Collect remaining batches from the streaming result. Errors encountered mid-stream + // are stored on result_wrapper so that get_next can return buffered batches first + // and then surface the error once they are exhausted. duckdb::vector batches; auto arrow_options = duckdb_result_get_arrow_options(&result_wrapper->result); while (true) { @@ -2529,12 +2679,62 @@ void duckdb::DuckDBAdbcConnectionWrapper::MaterializeStreams() { auto duckdb_chunk = duckdb_fetch_chunk(result_wrapper->result); if (!duckdb_chunk) { + // End of stream or error; distinguish by checking the result error message. + auto err = duckdb_result_error(&result_wrapper->result); + if (err && err[0] != '\0') { + if (result_wrapper->last_error) { + free(result_wrapper->last_error); + } + result_wrapper->last_error = strdup(err); + result_wrapper->status_code = + duckdb_adbc::IsInterruptError(err) ? ADBC_STATUS_CANCELLED : ADBC_STATUS_INTERNAL; + result_wrapper->adbc_error.message = result_wrapper->last_error; + result_wrapper->adbc_error.vendor_code = ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA; + if (result_wrapper->adbc_error.private_data) { + delete static_cast(result_wrapper->adbc_error.private_data); + result_wrapper->adbc_error.private_data = nullptr; + } + auto *details = new (std::nothrow) DuckDBErrorDetails(); + if (details) { + details->entries.emplace_back( + "duckdb:error_type", + DuckDBErrorTypeToString(duckdb_result_error_type(&result_wrapper->result))); + result_wrapper->adbc_error.private_data = details; + result_wrapper->adbc_error.release = ::ReleaseStreamErrorDetails; + } else { + result_wrapper->adbc_error.release = nullptr; + } + } break; } auto conversion_err = duckdb_data_chunk_to_arrow(arrow_options, duckdb_chunk, &array); duckdb_destroy_data_chunk(&duckdb_chunk); if (conversion_err) { + // Store error before freeing so get_next can surface it after buffered batches + auto conv_err_msg = duckdb_error_data_message(conversion_err); + if (conv_err_msg && conv_err_msg[0] != '\0') { + if (result_wrapper->last_error) { + free(result_wrapper->last_error); + } + result_wrapper->last_error = strdup(conv_err_msg); + result_wrapper->status_code = ADBC_STATUS_INTERNAL; + result_wrapper->adbc_error.message = result_wrapper->last_error; + result_wrapper->adbc_error.vendor_code = ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA; + if (result_wrapper->adbc_error.private_data) { + delete static_cast(result_wrapper->adbc_error.private_data); + result_wrapper->adbc_error.private_data = nullptr; + } + auto *details = new (std::nothrow) DuckDBErrorDetails(); + if (details) { + details->entries.emplace_back( + "duckdb:error_type", DuckDBErrorTypeToString(duckdb_error_data_error_type(conversion_err))); + result_wrapper->adbc_error.private_data = details; + result_wrapper->adbc_error.release = ::ReleaseStreamErrorDetails; + } else { + result_wrapper->adbc_error.release = nullptr; + } + } duckdb_destroy_error_data(&conversion_err); if (array.release) { array.release(&array); @@ -2626,3 +2826,134 @@ void SetError(struct AdbcError *error, const std::string &message) { } error->release = ReleaseError; } + +static void ReleaseStreamErrorDetails(struct AdbcError *error) { + if (!error) { + return; + } + // message is owned by the stream wrapper (last_error), not freed here + delete static_cast(error->private_data); + error->private_data = nullptr; + error->release = nullptr; +} + +static void ReleaseErrorWithDuckDBDetails(struct AdbcError *error) { + if (!error) { + return; + } + delete[] error->message; + error->message = nullptr; + delete static_cast(error->private_data); + error->private_data = nullptr; + error->release = nullptr; +} + +static const char *DuckDBErrorTypeToString(duckdb_error_type type) { + switch (type) { + case DUCKDB_ERROR_INVALID: + return "Invalid"; + case DUCKDB_ERROR_OUT_OF_RANGE: + return "OutOfRange"; + case DUCKDB_ERROR_CONVERSION: + return "Conversion"; + case DUCKDB_ERROR_UNKNOWN_TYPE: + return "UnknownType"; + case DUCKDB_ERROR_DECIMAL: + return "Decimal"; + case DUCKDB_ERROR_MISMATCH_TYPE: + return "MismatchType"; + case DUCKDB_ERROR_DIVIDE_BY_ZERO: + return "DivideByZero"; + case DUCKDB_ERROR_OBJECT_SIZE: + return "ObjectSize"; + case DUCKDB_ERROR_INVALID_TYPE: + return "InvalidType"; + case DUCKDB_ERROR_SERIALIZATION: + return "Serialization"; + case DUCKDB_ERROR_TRANSACTION: + return "Transaction"; + case DUCKDB_ERROR_NOT_IMPLEMENTED: + return "NotImplemented"; + case DUCKDB_ERROR_EXPRESSION: + return "Expression"; + case DUCKDB_ERROR_CATALOG: + return "Catalog"; + case DUCKDB_ERROR_PARSER: + return "Parser"; + case DUCKDB_ERROR_PLANNER: + return "Planner"; + case DUCKDB_ERROR_SCHEDULER: + return "Scheduler"; + case DUCKDB_ERROR_EXECUTOR: + return "Executor"; + case DUCKDB_ERROR_CONSTRAINT: + return "Constraint"; + case DUCKDB_ERROR_INDEX: + return "Index"; + case DUCKDB_ERROR_STAT: + return "Stat"; + case DUCKDB_ERROR_CONNECTION: + return "Connection"; + case DUCKDB_ERROR_SYNTAX: + return "Syntax"; + case DUCKDB_ERROR_SETTINGS: + return "Settings"; + case DUCKDB_ERROR_BINDER: + return "Binder"; + case DUCKDB_ERROR_NETWORK: + return "Network"; + case DUCKDB_ERROR_OPTIMIZER: + return "Optimizer"; + case DUCKDB_ERROR_NULL_POINTER: + return "NullPointer"; + case DUCKDB_ERROR_IO: + return "IO"; + case DUCKDB_ERROR_INTERRUPT: + return "Interrupt"; + case DUCKDB_ERROR_FATAL: + return "Fatal"; + case DUCKDB_ERROR_INTERNAL: + return "Internal"; + case DUCKDB_ERROR_INVALID_INPUT: + return "InvalidInput"; + case DUCKDB_ERROR_OUT_OF_MEMORY: + return "OutOfMemory"; + case DUCKDB_ERROR_PERMISSION: + return "Permission"; + case DUCKDB_ERROR_PARAMETER_NOT_RESOLVED: + return "ParameterNotResolved"; + case DUCKDB_ERROR_PARAMETER_NOT_ALLOWED: + return "ParameterNotAllowed"; + case DUCKDB_ERROR_DEPENDENCY: + return "Dependency"; + case DUCKDB_ERROR_HTTP: + return "HTTP"; + case DUCKDB_ERROR_MISSING_EXTENSION: + return "MissingExtension"; + case DUCKDB_ERROR_AUTOLOAD: + return "Autoload"; + case DUCKDB_ERROR_SEQUENCE: + return "Sequence"; + case DUCKDB_INVALID_CONFIGURATION: + return "InvalidConfiguration"; + default: + return "Unknown"; + } +} + +static void AppendDuckDBErrorDetails(struct AdbcError *error, duckdb_error_data res) { + AppendDuckDBErrorDetails(error, duckdb_error_data_error_type(res)); +} + +static void AppendDuckDBErrorDetails(struct AdbcError *error, duckdb_error_type type) { + if (!error || error->vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { + return; + } + auto *details = new (std::nothrow) DuckDBErrorDetails(); + if (!details) { + return; + } + details->entries.emplace_back("duckdb:error_type", DuckDBErrorTypeToString(type)); + error->private_data = details; + error->release = ::ReleaseErrorWithDuckDBDetails; +} diff --git a/src/duckdb/src/common/allocator/allocator.cpp b/src/duckdb/src/common/allocator/allocator.cpp index 25eae3c17..4044cdb82 100644 --- a/src/duckdb/src/common/allocator/allocator.cpp +++ b/src/duckdb/src/common/allocator/allocator.cpp @@ -10,6 +10,10 @@ #include +#ifdef __GLIBC__ +#include +#endif + #ifdef DUCKDB_DEBUG_ALLOCATION #include "duckdb/common/mutex.hpp" #include "duckdb/common/pair.hpp" @@ -175,6 +179,28 @@ Allocator &Allocator::DefaultAllocator() { return *DefaultAllocatorReference(); } +void Allocator::MallocTrim(idx_t pad) { +#ifdef __GLIBC__ + static constexpr int64_t TRIM_INTERVAL_MS = 100; + static atomic LAST_TRIM_TIMESTAMP_MS {0}; + + int64_t last_trim_timestamp_ms = LAST_TRIM_TIMESTAMP_MS.load(); + auto current_ts = Timestamp::GetCurrentTimestamp(); + auto current_timestamp_ms = Cast::Operation(current_ts).value; + + if (current_timestamp_ms - last_trim_timestamp_ms < TRIM_INTERVAL_MS) { + return; // We trimmed less than TRIM_INTERVAL_MS ago + } + if (!LAST_TRIM_TIMESTAMP_MS.compare_exchange_strong(last_trim_timestamp_ms, current_timestamp_ms, + std::memory_order_acquire, std::memory_order_relaxed)) { + return; // Another thread has updated LAST_TRIM_TIMESTAMP_MS since we loaded it + } + + // We successfully updated LAST_TRIM_TIMESTAMP_MS, we can trim + malloc_trim(pad); +#endif +} + //===--------------------------------------------------------------------===// // Debug Info (extended) //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/common/allocator/allocator_jemalloc.cpp b/src/duckdb/src/common/allocator/allocator_jemalloc.cpp index 84735b16e..74e25516b 100644 --- a/src/duckdb/src/common/allocator/allocator_jemalloc.cpp +++ b/src/duckdb/src/common/allocator/allocator_jemalloc.cpp @@ -80,6 +80,11 @@ bool Allocator::SupportsFlush() { } void Allocator::ThreadFlush(bool allocator_background_threads, idx_t threshold, idx_t thread_count) { + // jemalloc only manages allocation done through the Allocator interface. + // Any allocations done directly through "malloc" or "operator new" still + // go to the system allocator. So we also trim the system heap here. + MallocTrim(thread_count * threshold); + if (!allocator_background_threads) { // We flush after exceeding the threshold if (GetJemallocCTL("thread.peak.read") <= threshold) { @@ -112,6 +117,9 @@ void Allocator::FlushAll() { // Reset the peak after resetting SetJemallocCTL("thread.peak.reset"); + + // Also return the system heap (see ThreadFlush) to the OS + MallocTrim(0); } void Allocator::SetBackgroundThreads(bool enable) { diff --git a/src/duckdb/src/common/allocator/allocator_standard.cpp b/src/duckdb/src/common/allocator/allocator_standard.cpp index 45e03d46b..da6a959c8 100644 --- a/src/duckdb/src/common/allocator/allocator_standard.cpp +++ b/src/duckdb/src/common/allocator/allocator_standard.cpp @@ -3,15 +3,8 @@ #ifndef DUCKDB_ENABLE_JEMALLOC #include "duckdb/common/assert.hpp" -#include "duckdb/common/atomic.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/helper.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/types/timestamp.hpp" - -#ifdef __GLIBC__ -#include -#endif #ifdef DUCKDB_DEBUG_ALLOCATION #include "duckdb/common/mutex.hpp" @@ -52,28 +45,6 @@ bool Allocator::SupportsFlush() { #endif } -static void MallocTrim(idx_t pad) { -#ifdef __GLIBC__ - static constexpr int64_t TRIM_INTERVAL_MS = 100; - static atomic LAST_TRIM_TIMESTAMP_MS {0}; - - int64_t last_trim_timestamp_ms = LAST_TRIM_TIMESTAMP_MS.load(); - auto current_ts = Timestamp::GetCurrentTimestamp(); - auto current_timestamp_ms = Cast::Operation(current_ts).value; - - if (current_timestamp_ms - last_trim_timestamp_ms < TRIM_INTERVAL_MS) { - return; // We trimmed less than TRIM_INTERVAL_MS ago - } - if (!LAST_TRIM_TIMESTAMP_MS.compare_exchange_strong(last_trim_timestamp_ms, current_timestamp_ms, - std::memory_order_acquire, std::memory_order_relaxed)) { - return; // Another thread has updated LAST_TRIM_TIMESTAMP_MS since we loaded it - } - - // We successfully updated LAST_TRIM_TIMESTAMP_MS, we can trim - malloc_trim(pad); -#endif -} - void Allocator::ThreadFlush(bool allocator_background_threads, idx_t threshold, idx_t thread_count) { MallocTrim(thread_count * threshold); } diff --git a/src/duckdb/src/common/arrow/appender/append_data.cpp b/src/duckdb/src/common/arrow/appender/append_data.cpp index 645b1bfcf..73595a88d 100644 --- a/src/duckdb/src/common/arrow/appender/append_data.cpp +++ b/src/duckdb/src/common/arrow/appender/append_data.cpp @@ -26,4 +26,18 @@ void ArrowAppendData::AppendValidity(UnifiedVectorFormat &format, idx_t from, id } } +void ArrowAppendData::AppendChild(const Vector &input, idx_t from, idx_t to, idx_t input_size) { + if (extension_data && extension_data->duckdb_to_arrow) { + // Convert the DuckDB-typed input into the extension's internal Arrow type before + // handing it to the (internal-typed) child appender. Size the internal vector to the + // actual input_size: container children can exceed STANDARD_VECTOR_SIZE (e.g. a 2048-row + // LIST whose elements average two entries), and duckdb_to_arrow writes input_size values. + Vector internal(extension_data->GetInternalType(), MaxValue(input_size, STANDARD_VECTOR_SIZE)); + extension_data->duckdb_to_arrow(*options.client_context, input, internal, input_size); + append_vector(*this, internal, from, to, input_size); + } else { + append_vector(*this, input, from, to, input_size); + } +} + } // namespace duckdb diff --git a/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp b/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp index 9d5ff74b3..37dfdea07 100644 --- a/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp +++ b/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp @@ -24,7 +24,7 @@ void ArrowFixedSizeListData::Append(ArrowAppendData &append_data, const Vector & auto array_size = ArrayType::GetSize(input.GetType()); auto &child_vector = ArrayVector::GetChild(input); auto &child_data = *append_data.child_data[0]; - child_data.append_vector(child_data, child_vector, from * array_size, to * array_size, size * array_size); + child_data.AppendChild(child_vector, from * array_size, to * array_size, size * array_size); append_data.row_count += size; } diff --git a/src/duckdb/src/common/arrow/appender/struct_data.cpp b/src/duckdb/src/common/arrow/appender/struct_data.cpp index 018e82013..f70940336 100644 --- a/src/duckdb/src/common/arrow/appender/struct_data.cpp +++ b/src/duckdb/src/common/arrow/appender/struct_data.cpp @@ -27,7 +27,7 @@ void ArrowStructData::Append(ArrowAppendData &append_data, const Vector &input, for (idx_t child_idx = 0; child_idx < children.size(); child_idx++) { auto &child = children[child_idx]; auto &child_data = *append_data.child_data[child_idx]; - child_data.append_vector(child_data, child, from, to, size); + child_data.AppendChild(child, from, to, input_size); } append_data.row_count += size; } diff --git a/src/duckdb/src/common/arrow/appender/union_data.cpp b/src/duckdb/src/common/arrow/appender/union_data.cpp index 382af4b9b..eff070099 100644 --- a/src/duckdb/src/common/arrow/appender/union_data.cpp +++ b/src/duckdb/src/common/arrow/appender/union_data.cpp @@ -49,7 +49,7 @@ void ArrowUnionData::Append(ArrowAppendData &append_data, const Vector &input, i for (idx_t child_idx = 0; child_idx < child_vectors.size(); child_idx++) { auto &child_buffer = append_data.child_data[child_idx]; auto &child = child_vectors[child_idx]; - child_buffer->append_vector(*child_buffer, child, 0, size, size); + child_buffer->AppendChild(child, 0, size, size); } append_data.row_count += size; } diff --git a/src/duckdb/src/common/arrow/arrow_appender.cpp b/src/duckdb/src/common/arrow/arrow_appender.cpp index 6e9f5c018..0c524849e 100644 --- a/src/duckdb/src/common/arrow/arrow_appender.cpp +++ b/src/duckdb/src/common/arrow/arrow_appender.cpp @@ -8,6 +8,7 @@ #include "duckdb/common/arrow/appender/append_data.hpp" #include "duckdb/common/arrow/appender/list.hpp" #include "duckdb/function/table/arrow/arrow_duck_schema.hpp" +#include "duckdb/main/config.hpp" namespace duckdb { @@ -19,14 +20,13 @@ ArrowAppender::ArrowAppender(vector types_p, const idx_t initial_ca unordered_map> extension_type_cast) : types(std::move(types_p)), options(options) { for (idx_t i = 0; i < types.size(); i++) { - unique_ptr entry; - bool bitshift_boolean = types[i].id() == LogicalTypeId::BOOLEAN && !options.arrow_lossless_conversion; - if (extension_type_cast.find(i) != extension_type_cast.end() && !bitshift_boolean) { - entry = InitializeChild(types[i], initial_capacity, options, extension_type_cast[i]); - } else { - entry = InitializeChild(types[i], initial_capacity, options); - } - root_data.push_back(std::move(entry)); + // Pass any explicit per-column extension override through to InitializeChild; when none + // is supplied it auto-resolves the extension (and applies the bitshift_boolean gate) so + // children of nested types pick up the same extension SetArrowFormat uses for the schema. + auto extension_it = extension_type_cast.find(i); + shared_ptr extension = + extension_it != extension_type_cast.end() ? extension_it->second : nullptr; + root_data.push_back(InitializeChild(types[i], initial_capacity, options, extension)); } } @@ -38,14 +38,7 @@ void ArrowAppender::Append(DataChunk &input, const idx_t from, const idx_t to, c D_ASSERT(types == input.GetTypes()); D_ASSERT(to >= from); for (idx_t i = 0; i < input.ColumnCount(); i++) { - if (root_data[i]->extension_data && root_data[i]->extension_data->duckdb_to_arrow) { - Vector input_data(root_data[i]->extension_data->GetInternalType()); - root_data[i]->extension_data->duckdb_to_arrow(*options.client_context, input.data[i], input_data, - input_size); - root_data[i]->append_vector(*root_data[i], input_data, from, to, input_size); - } else { - root_data[i]->append_vector(*root_data[i], input.data[i], from, to, input_size); - } + root_data[i]->AppendChild(input.data[i], from, to, input_size); } row_count += to - from; } @@ -316,12 +309,28 @@ unique_ptr ArrowAppender::InitializeChild(const LogicalType &ty ClientProperties &options, const shared_ptr &extension_type) { auto result = make_uniq(options); + + // Resolve the effective extension. An explicit override (from the top-level appender) wins. + // Otherwise auto-resolve from DBConfig so nested children use the same extension SetArrowFormat + // declares in the schema. BOOLEAN stays plain bit-packed when arrow_lossless_conversion is off + // (the bitshift_boolean gate), applied here so it holds at every nesting level. + shared_ptr effective_extension = extension_type; + const bool bitshift_boolean = type.id() == LogicalTypeId::BOOLEAN && !options.arrow_lossless_conversion; + if (bitshift_boolean) { + effective_extension = nullptr; + } else if (!effective_extension && options.client_context) { + const auto &db_config = DBConfig::GetConfig(*options.client_context); + if (db_config.HasArrowExtension(type)) { + effective_extension = db_config.GetArrowExtension(type).GetTypeExtension(); + } + } + LogicalType array_type = type; - if (extension_type) { - array_type = extension_type->GetInternalType(); + if (effective_extension) { + array_type = effective_extension->GetInternalType(); } InitializeFunctionPointers(*result, array_type); - result->extension_data = extension_type; + result->extension_data = effective_extension; const auto byte_count = (capacity + 7) / 8; result->GetValidityBuffer().reserve(byte_count); diff --git a/src/duckdb/src/common/arrow/arrow_type_extension.cpp b/src/duckdb/src/common/arrow/arrow_type_extension.cpp index 24cbba223..bb43c8d29 100644 --- a/src/duckdb/src/common/arrow/arrow_type_extension.cpp +++ b/src/duckdb/src/common/arrow/arrow_type_extension.cpp @@ -357,7 +357,7 @@ struct ArrowBool8 { result_data.WriteValue(source_ptr[i]); } } - static void DuckToArrow(ClientContext &context, Vector &source, Vector &result, idx_t count) { + static void DuckToArrow(ClientContext &context, const Vector &source, Vector &result, idx_t count) { auto entries = source.Values(); auto result_data = FlatVector::Writer(result, count); for (idx_t i = 0; i < count; i++) { @@ -541,7 +541,7 @@ struct ArrowGeometry { Geometry::FromBinary(source, result, count, true); } - static void DuckToArrow(ClientContext &context, Vector &source, Vector &result, idx_t count) { + static void DuckToArrow(ClientContext &context, const Vector &source, Vector &result, idx_t count) { Geometry::ToBinary(source, result); } }; diff --git a/src/duckdb/src/common/enum_util.cpp b/src/duckdb/src/common/enum_util.cpp index 683913391..c07a15708 100644 --- a/src/duckdb/src/common/enum_util.cpp +++ b/src/duckdb/src/common/enum_util.cpp @@ -107,7 +107,6 @@ #include "duckdb/common/types/row/tuple_data_states.hpp" #include "duckdb/common/types/timestamp.hpp" #include "duckdb/common/types/variant.hpp" -#include "duckdb/common/types/variant_value.hpp" #include "duckdb/common/types/vector_buffer.hpp" #include "duckdb/common/vector/map_vector.hpp" #include "duckdb/common/vector/union_vector.hpp" @@ -221,20 +220,19 @@ namespace duckdb { const StringUtil::EnumStringLiteral *GetARTConflictTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(ARTConflictType::NO_CONFLICT), "NO_CONFLICT" }, - { static_cast(ARTConflictType::CONSTRAINT), "CONSTRAINT" }, - { static_cast(ARTConflictType::TRANSACTION), "TRANSACTION" } + { static_cast(ARTConflictType::CONSTRAINT), "CONSTRAINT" } }; return values; } template<> const char* EnumUtil::ToChars(ARTConflictType value) { - return StringUtil::EnumToString(GetARTConflictTypeValues(), 3, "ARTConflictType", static_cast(value)); + return StringUtil::EnumToString(GetARTConflictTypeValues(), 2, "ARTConflictType", static_cast(value)); } template<> ARTConflictType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetARTConflictTypeValues(), 3, "ARTConflictType", value)); + return static_cast(StringUtil::StringToEnum(GetARTConflictTypeValues(), 2, "ARTConflictType", value)); } const StringUtil::EnumStringLiteral *GetARTHandlingResultValues() { @@ -829,6 +827,7 @@ BinderType EnumUtil::FromString(const char *value) { const StringUtil::EnumStringLiteral *GetBindingModeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(BindingMode::STANDARD_BINDING), "STANDARD_BINDING" }, + { static_cast(BindingMode::PREPARE), "PREPARE" }, { static_cast(BindingMode::EXTRACT_NAMES), "EXTRACT_NAMES" }, { static_cast(BindingMode::EXTRACT_REPLACEMENT_SCANS), "EXTRACT_REPLACEMENT_SCANS" }, { static_cast(BindingMode::EXTRACT_QUALIFIED_NAMES), "EXTRACT_QUALIFIED_NAMES" } @@ -838,12 +837,12 @@ const StringUtil::EnumStringLiteral *GetBindingModeValues() { template<> const char* EnumUtil::ToChars(BindingMode value) { - return StringUtil::EnumToString(GetBindingModeValues(), 4, "BindingMode", static_cast(value)); + return StringUtil::EnumToString(GetBindingModeValues(), 5, "BindingMode", static_cast(value)); } template<> BindingMode EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetBindingModeValues(), 4, "BindingMode", value)); + return static_cast(StringUtil::StringToEnum(GetBindingModeValues(), 5, "BindingMode", value)); } const StringUtil::EnumStringLiteral *GetBitpackingModeValues() { @@ -6074,26 +6073,6 @@ VariantStatsShreddingState EnumUtil::FromString(cons return static_cast(StringUtil::StringToEnum(GetVariantStatsShreddingStateValues(), 4, "VariantStatsShreddingState", value)); } -const StringUtil::EnumStringLiteral *GetVariantValueTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(VariantValueType::PRIMITIVE), "PRIMITIVE" }, - { static_cast(VariantValueType::OBJECT), "OBJECT" }, - { static_cast(VariantValueType::ARRAY), "ARRAY" }, - { static_cast(VariantValueType::MISSING), "MISSING" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(VariantValueType value) { - return StringUtil::EnumToString(GetVariantValueTypeValues(), 4, "VariantValueType", static_cast(value)); -} - -template<> -VariantValueType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetVariantValueTypeValues(), 4, "VariantValueType", value)); -} - const StringUtil::EnumStringLiteral *GetVectorBufferTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(VectorBufferType::STANDARD_BUFFER), "STANDARD_BUFFER" }, diff --git a/src/duckdb/src/common/gzip_file_system.cpp b/src/duckdb/src/common/gzip_file_system.cpp index 003cf4a6f..5e1f0e15a 100644 --- a/src/duckdb/src/common/gzip_file_system.cpp +++ b/src/duckdb/src/common/gzip_file_system.cpp @@ -240,8 +240,11 @@ void MiniZStreamWrapper::Write(CompressedFile &file, StreamData &sd, data_ptr_t while (remaining > 0) { auto output_remaining = UnsafeNumericCast((sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start); + // miniz's avail_in is a platform-dependent unsigned int, cap ingestion bytes to avoid overflow. + auto avail_in = MinValue(remaining, NumericLimits::Maximum()); + mz_stream_ptr->next_in = reinterpret_cast(uncompressed_data); - mz_stream_ptr->avail_in = NumericCast(remaining); + mz_stream_ptr->avail_in = NumericCast(avail_in); mz_stream_ptr->next_out = sd.out_buff_start; mz_stream_ptr->avail_out = NumericCast(output_remaining); @@ -257,9 +260,9 @@ void MiniZStreamWrapper::Write(CompressedFile &file, StreamData &sd, data_ptr_t UnsafeNumericCast(sd.out_buff_start - sd.out_buff.get())); sd.out_buff_start = sd.out_buff.get(); } - auto written = UnsafeNumericCast(remaining - mz_stream_ptr->avail_in); + auto written = NumericCast(avail_in - mz_stream_ptr->avail_in); uncompressed_data += written; - remaining = mz_stream_ptr->avail_in; + remaining -= NumericCast(written); } } diff --git a/src/duckdb/src/common/local_file_system.cpp b/src/duckdb/src/common/local_file_system.cpp index 3928f6f2c..feb3d4a5d 100644 --- a/src/duckdb/src/common/local_file_system.cpp +++ b/src/duckdb/src/common/local_file_system.cpp @@ -6,10 +6,12 @@ #include "duckdb/common/helper.hpp" #include "duckdb/common/memory_mapped_file.hpp" #include "duckdb/common/string_util.hpp" +#include "duckdb/common/thread.hpp" #include "duckdb/common/windows.hpp" #include "duckdb/function/scalar/string_common.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/database.hpp" +#include "duckdb/main/settings.hpp" #include "duckdb/logging/file_system_logger.hpp" #include "duckdb/logging/log_manager.hpp" #include "duckdb/common/multi_file/multi_file_list.hpp" @@ -201,16 +203,37 @@ bool LocalFileSystem::IsPipe(const string &filename, optional_ptr op #define O_DIRECT 0 #endif +static idx_t GetLocalFileSystemDelay(optional_ptr db) { + if (!db) { + return 0; + } + return Settings::Get(*db); +} + +static void ApplyLocalFileSystemDelay(optional_ptr db) { +#ifndef DUCKDB_NO_THREADS + auto delay_ms = GetLocalFileSystemDelay(db); + if (delay_ms > 0) { + ThreadUtil::SleepMs(delay_ms); + } +#endif +} + +static void ApplyLocalFileSystemDelay(optional_ptr opener) { + ApplyLocalFileSystemDelay(FileOpener::TryGetDatabase(opener)); +} + struct UnixFileHandle : public FileHandle { public: - UnixFileHandle(FileSystem &file_system, string path, int fd, FileOpenFlags flags) - : FileHandle(file_system, std::move(path), flags), fd(fd) { + UnixFileHandle(FileSystem &file_system, string path, int fd, FileOpenFlags flags, optional_ptr db) + : FileHandle(file_system, std::move(path), flags), fd(fd), db(db) { } ~UnixFileHandle() override { UnixFileHandle::Close(); } int fd; + optional_ptr db; // Kept for logging purposes idx_t current_pos = 0; @@ -501,6 +524,7 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF } // Open the file + ApplyLocalFileSystemDelay(opener); int fd = open(path.c_str(), open_flags, filesec); if (fd == -1) { @@ -529,7 +553,7 @@ unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenF TryAcquireFileLock(*this, fd, path, flags); - auto file_handle = make_uniq(*this, path, fd, flags); + auto file_handle = make_uniq(*this, path, fd, flags, FileOpener::TryGetDatabase(opener)); if (opener) { file_handle->TryAddLogger(*opener); DUCKDB_LOG_FILE_SYSTEM_OPEN((*file_handle)); @@ -558,7 +582,9 @@ idx_t LocalFileSystem::GetFilePointer(FileHandle &handle) { void LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { auto bytes_to_read = nr_bytes; - int fd = handle.Cast().fd; + auto &unix_handle = handle.Cast(); + ApplyLocalFileSystemDelay(unix_handle.db); + int fd = unix_handle.fd; auto read_buffer = char_ptr_cast(buffer); while (nr_bytes > 0) { int64_t bytes_read = @@ -582,6 +608,7 @@ void LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, i int64_t LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { auto &unix_handle = handle.Cast(); + ApplyLocalFileSystemDelay(unix_handle.db); int fd = unix_handle.fd; int64_t bytes_read = read(fd, buffer, UnsafeNumericCast(nr_bytes)); if (bytes_read == -1) { @@ -596,7 +623,9 @@ int64_t LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes } void LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - int fd = handle.Cast().fd; + auto &unix_handle = handle.Cast(); + ApplyLocalFileSystemDelay(unix_handle.db); + int fd = unix_handle.fd; auto write_buffer = char_ptr_cast(buffer); auto bytes_to_write = nr_bytes; @@ -624,6 +653,7 @@ void LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, int64_t LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { auto &unix_handle = handle.Cast(); + ApplyLocalFileSystemDelay(unix_handle.db); int fd = unix_handle.fd; auto bytes_to_write = nr_bytes; @@ -696,6 +726,8 @@ void LocalFileSystem::Truncate(FileHandle &handle, int64_t new_size) { } bool LocalFileSystem::DirectoryExists(const string &directory, optional_ptr opener) { + ApplyLocalFileSystemDelay(opener); + if (!directory.empty()) { auto normalized_dir = ExpandPath(directory, opener); if (access(normalized_dir.c_str(), 0) == 0) { @@ -713,6 +745,8 @@ bool LocalFileSystem::DirectoryExists(const string &directory, optional_ptr opener) { struct stat st; + ApplyLocalFileSystemDelay(opener); + auto normalized_dir = ExpandPath(directory, opener); if (stat(normalized_dir.c_str(), &st) != 0) { /* Directory does not exist. EEXIST for race condition */ @@ -1074,7 +1108,7 @@ static timestamp_t FiletimeToTimeStamp(FILETIME file_time) { // Adapted from: https://stackoverflow.com/questions/6161776/convert-windows-filetime-to-second-in-unix-linux const auto WINDOWS_TICK = 10000000; const auto SEC_TO_UNIX_EPOCH = 11644473600LL; - return Timestamp::FromTimeT(fileTime64 / WINDOWS_TICK - SEC_TO_UNIX_EPOCH); + return Timestamp::FromEpochSeconds(fileTime64 / WINDOWS_TICK - SEC_TO_UNIX_EPOCH); } static FileMetadata StatsInternal(HANDLE hFile, const string &path) { @@ -1084,13 +1118,13 @@ static FileMetadata StatsInternal(HANDLE hFile, const string &path) { if (handle_type == FILE_TYPE_CHAR) { file_metadata.file_type = FileType::FILE_TYPE_CHARDEV; file_metadata.file_size = 0; - file_metadata.last_modification_time = Timestamp::FromTimeT(0); + file_metadata.last_modification_time = Timestamp::FromEpochSeconds(0); return file_metadata; } if (handle_type == FILE_TYPE_PIPE) { file_metadata.file_type = FileType::FILE_TYPE_FIFO; file_metadata.file_size = 0; - file_metadata.last_modification_time = Timestamp::FromTimeT(0); + file_metadata.last_modification_time = Timestamp::FromEpochSeconds(0); return file_metadata; } diff --git a/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp b/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp index 67378a196..28387c134 100644 --- a/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp +++ b/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp @@ -344,6 +344,18 @@ MapColumnMapComponent(ClientContext &context, return child_map; } +static bool IsInvalidMapKeyDefault(const ColumnMapResult &mapping) { + if (!mapping.default_value) { + return false; + } + auto &expr = *mapping.default_value; + if (expr.GetExpressionClass() != ExpressionClass::BOUND_CONSTANT) { + return false; + } + auto &constant_expr = expr.Cast(); + return constant_expr.GetValue().IsNull(); +} + static ColumnMapResult MapColumnMap(ClientContext &context, const MultiFileColumnDefinition &global_column, const ColumnIndex &global_index, const MultiFileColumnDefinition &local_column, const MultiFileLocalIndex &local_id, const ColumnMapper &mapper, @@ -384,6 +396,11 @@ static ColumnMapResult MapColumnMap(ClientContext &context, const MultiFileColum auto map_result = MapColumnMapComponent(context, selected_children, global_index, *nested_mapper, i, global_component, local_key_value); + if (name == "key" && IsInvalidMapKeyDefault(map_result)) { + throw InvalidInputException( + "'key' of MAP did not map to a value and the registered DEFAULT is NULL, which is not allowed"); + } + if (map_result.column_index) { child_indexes.push_back(std::move(*map_result.column_index)); mapping->child_mapping.insert(make_pair(i, std::move(map_result.mapping))); diff --git a/src/duckdb/src/common/serializer/async_file_writer.cpp b/src/duckdb/src/common/serializer/async_file_writer.cpp index 23fddacaf..33ceddd2c 100644 --- a/src/duckdb/src/common/serializer/async_file_writer.cpp +++ b/src/duckdb/src/common/serializer/async_file_writer.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/helper.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb/main/settings.hpp" #include "duckdb/storage/buffer_manager.hpp" #include @@ -266,6 +267,9 @@ bool AsyncFileWriter::SupportsPositionalWrites() { } bool AsyncFileWriter::IsLocalFile() { + if (Settings::Get(client_context) > 0) { + return false; + } auto local_file = fs.IsLocalFileSystem(); if (!local_file && handle) { try { diff --git a/src/duckdb/src/common/serializer/async_memory_governor.cpp b/src/duckdb/src/common/serializer/async_memory_governor.cpp new file mode 100644 index 000000000..601cb4faf --- /dev/null +++ b/src/duckdb/src/common/serializer/async_memory_governor.cpp @@ -0,0 +1,93 @@ +#include "duckdb/common/serializer/async_memory_governor.hpp" + +#include "duckdb/common/helper.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/storage/temporary_memory_manager.hpp" + +namespace duckdb { + +ManagedAsyncMemoryGovernor::ManagedAsyncMemoryGovernor(ClientContext &client_context_p) + : client_context(client_context_p) { + auto &scheduler = TaskScheduler::GetScheduler(client_context); + auto regular_threads = MaxValue(NumericCast(scheduler.NumberOfThreads()), 1); + auto async_threads = NumericCast(scheduler.NumberOfAsyncThreads()); + max_pending_bytes = ManagedAsyncMemoryConfig::MAX_PENDING_BYTES_PER_THREAD * regular_threads; + min_pending_bytes = + MinValue(max_pending_bytes, ManagedAsyncMemoryConfig::MIN_PENDING_BYTES_PER_THREAD * regular_threads); + // A reservation is only useful when drain tasks run asynchronously; synchronous draining bounds itself. + if (async_threads > 0 && max_pending_bytes > 0) { + memory_state = TemporaryMemoryManager::Get(client_context).Register(client_context); + memory_state->SetMinimumReservation(min_pending_bytes); + memory_state->SetZero(); + } +} + +ManagedAsyncMemoryGovernor::~ManagedAsyncMemoryGovernor() = default; + +bool ManagedAsyncMemoryGovernor::IsActive() const { + return memory_state != nullptr; +} + +void ManagedAsyncMemoryGovernor::UpdateReservation(idx_t current_pending_bytes) { + if (!memory_state || current_pending_bytes == 0) { + return; + } + + auto current_reservation = memory_state->GetReservation(); + while (current_pending_bytes > MinValue(current_reservation, max_pending_bytes)) { + idx_t next_request; + if (memory_request_bytes > current_reservation) { + // TMM did not fully grant the previous request. Keep retrying it on later growth checks. + next_request = memory_request_bytes; + } else if (memory_request_bytes == 0) { + // Grow coarsely and only release on Release(). + // Repeatedly shrinking here would touch shared TMM state on the registration hot path. + next_request = min_pending_bytes; + } else if (memory_request_bytes >= max_pending_bytes) { + return; + } else if (memory_request_bytes > max_pending_bytes / 2) { + next_request = max_pending_bytes; + } else { + next_request = memory_request_bytes * 2; + } + next_request = MinValue(MaxValue(next_request, min_pending_bytes), max_pending_bytes); + if (next_request <= memory_request_bytes) { + return; + } + + auto previous_reservation = current_reservation; + memory_state->SetRemainingSizeAndUpdateReservation(client_context, next_request); + memory_request_bytes = next_request; + current_reservation = memory_state->GetReservation(); + if (current_reservation <= previous_reservation) { + return; + } + if (current_reservation < next_request) { + return; + } + } +} + +idx_t ManagedAsyncMemoryGovernor::BackpressureBudget() const { + if (!memory_state) { + return NumericLimits::Maximum(); + } + auto reservation = MinValue(memory_state->GetReservation(), max_pending_bytes); + // If TMM only grants a tiny reservation, do not retain an async backlog. This makes low-memory execution + // behave close to synchronous draining, but automatically allows overlap again if the reservation grows later. + if (reservation < ManagedAsyncMemoryConfig::MIN_RESERVATION_FOR_BACKLOG) { + return 0; + } + return reservation; +} + +void ManagedAsyncMemoryGovernor::Release() { + if (!memory_state || memory_request_bytes == 0) { + return; + } + memory_state->SetZero(); + memory_request_bytes = 0; +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/async_task_queue.cpp b/src/duckdb/src/common/serializer/async_task_queue.cpp new file mode 100644 index 000000000..a097de49f --- /dev/null +++ b/src/duckdb/src/common/serializer/async_task_queue.cpp @@ -0,0 +1,634 @@ +#include "duckdb/common/serializer/async_task_queue.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parallel/task_executor.hpp" +#include "duckdb/parallel/task_scheduler.hpp" + +#include + +namespace duckdb { + +static ErrorData TaskErrorDataFromExceptionPtr(const std::exception_ptr &error_ptr) { + try { + std::rethrow_exception(error_ptr); + } catch (const std::exception &ex) { + return ErrorData(ex); + } catch (...) { // LCOV_EXCL_START + return ErrorData("Unknown exception during async task"); + } // LCOV_EXCL_STOP +} + +AsyncTaskRequest::AsyncTaskRequest(unique_ptr task_p, idx_t size_p, AsyncTaskCompletionCallback completion_p) + : task(std::move(task_p)), size(size_p), completion(std::move(completion_p)) { +} + +idx_t AsyncTaskRequest::Size() const { + return size; +} + +//===--------------------------------------------------------------------===// +// AsyncTaskQueue +//===--------------------------------------------------------------------===// +class AsyncTaskQueueTaskGuard { +public: + explicit AsyncTaskQueueTaskGuard(AsyncTaskQueue &queue_p) : queue(queue_p) { + } + + ~AsyncTaskQueueTaskGuard() { + Finish(); + } + + void SetRequestSize(idx_t request_size_p) { + D_ASSERT(!finished); + request_size = request_size_p; + } + + void Finish() { + if (!finished) { + queue.FinishTask(request_size); + finished = true; + } + } + +private: + AsyncTaskQueue &queue; + idx_t request_size = 0; + bool finished = false; +}; + +class AsyncTaskQueueTask : public BaseExecutorTask { +public: + AsyncTaskQueueTask(AsyncTaskQueue &queue_p, TaskExecutor &executor) : BaseExecutorTask(executor), queue(queue_p) { + } + + ~AsyncTaskQueueTask() override { + if (!started) { + queue.CancelScheduledTask(); + } + } + + void ExecuteTask() override { + started = true; + queue.DrainRequest(); + } + +private: + AsyncTaskQueue &queue; + bool started = false; +}; + +AsyncTaskQueue::AsyncTaskQueue(ClientContext &client_context_p, idx_t max_active_tasks_p) + : client_context(client_context_p) { + auto &scheduler = TaskScheduler::GetScheduler(client_context); + auto async_threads = NumericCast(scheduler.NumberOfAsyncThreads()); + max_active_tasks = max_active_tasks_p > 0 ? max_active_tasks_p : MaxValue(async_threads, 1); + if (async_threads == 0) { + return; + } + executor = make_uniq(client_context, TaskSchedulerType::ASYNC); +} + +AsyncTaskQueue::~AsyncTaskQueue() { + lock_guard guard(lock); + auto drained = pending_requests.empty() && pending_bytes == 0 && in_flight_bytes == 0 && active_tasks == 0 && + pending_tasks == 0; + D_ASSERT(closed || drained); + D_ASSERT(!closed || drained); +} + +bool AsyncTaskQueue::IsAsync() const { + return executor != nullptr; +} + +bool AsyncTaskQueue::HasError() { + return executor && executor->HasError(); +} + +void AsyncTaskQueue::Submit(AsyncTaskRequest request) { + if (!request.task) { + return; + } + auto request_size = request.Size(); + if (executor && executor->HasError()) { + ErrorData error; + try { + executor->ThrowError(); + } catch (const std::exception &ex) { + error = ErrorData(ex); + } catch (...) { // LCOV_EXCL_START + error = ErrorData("Unknown exception during async task"); + } // LCOV_EXCL_STOP + request.task.reset(); + CompleteRequest(request, request_size, error); + error.Throw(); + } + if (!executor) { + VerifyOpen(); + ExecuteRequest(std::move(request)); + return; + } + + { + lock_guard guard(lock); + VerifyOpen(); + pending_requests.push_back(std::move(request)); + pending_bytes += request_size; + } + ScheduleTasksInternal(); +} + +idx_t AsyncTaskQueue::PendingBytes() { + lock_guard guard(lock); + return pending_bytes + in_flight_bytes; +} + +void AsyncTaskQueue::ScheduleTasksInternal() { + if (!executor) { + return; + } + idx_t schedule_count = 0; + { + lock_guard guard(lock); + VerifyOpen(); + // One drain task per still-unclaimed pending request, capped at the concurrency limit. + while (pending_tasks + schedule_count < pending_requests.size() && + active_tasks + schedule_count < max_active_tasks) { + schedule_count++; + } + active_tasks += schedule_count; + pending_tasks += schedule_count; + } + for (idx_t task_idx = 0; task_idx < schedule_count; task_idx++) { + unique_ptr task; + try { + task = make_uniq(*this, *executor); + } catch (...) { + CancelScheduledTasks(schedule_count - task_idx); + throw; + } + try { + executor->ScheduleTask(std::move(task)); + } catch (...) { + // The task destructor releases this task's slot. Release the slots for tasks not yet created. + CancelScheduledTasks(schedule_count - task_idx - 1); + throw; + } + } +} + +void AsyncTaskQueue::FinishTask(idx_t task_size) { + lock_guard guard(lock); + D_ASSERT(active_tasks > 0); + active_tasks--; + D_ASSERT(in_flight_bytes >= task_size); + in_flight_bytes -= task_size; +} + +void AsyncTaskQueue::CancelScheduledTask() { + CancelScheduledTasks(1); +} + +void AsyncTaskQueue::CancelScheduledTasks(idx_t task_count) { + if (task_count == 0) { + return; + } + lock_guard guard(lock); + D_ASSERT(active_tasks >= task_count); + D_ASSERT(pending_tasks >= task_count); + active_tasks -= task_count; + pending_tasks -= task_count; +} + +void AsyncTaskQueue::DrainRequest() { + AsyncTaskQueueTaskGuard guard(*this); + AsyncTaskRequest request; + bool has_request = false; + { + lock_guard task_guard(lock); + D_ASSERT(active_tasks > 0); + D_ASSERT(pending_tasks > 0); + pending_tasks--; + if (!pending_requests.empty()) { + request = std::move(pending_requests.front()); + pending_requests.pop_front(); + D_ASSERT(pending_bytes >= request.size); + pending_bytes -= request.size; + in_flight_bytes += request.size; + has_request = true; + } + } + if (!has_request) { + guard.Finish(); + return; + } + guard.SetRequestSize(request.size); + ExecuteRequest(std::move(request)); + guard.Finish(); + ScheduleTasksInternal(); +} + +void AsyncTaskQueue::CompleteRequest(AsyncTaskRequest &request, idx_t size, optional_ptr error) { + if (request.completion) { + request.completion(size, error); + } +} + +void AsyncTaskQueue::ExecuteRequest(AsyncTaskRequest request) { + auto request_size = request.Size(); + ErrorData task_error; + bool has_error = false; + try { + request.task->Execute(); + } catch (const std::exception &ex) { + task_error = ErrorData(ex); + has_error = true; + } catch (...) { // LCOV_EXCL_START + task_error = ErrorData("Unknown exception during async task"); + has_error = true; + } // LCOV_EXCL_STOP + + request.task.reset(); + if (has_error) { + CompleteRequest(request, request_size, task_error); + task_error.Throw(); + } + CompleteRequest(request, request_size, nullptr); +} + +void AsyncTaskQueue::RethrowTaskError() { + if (executor && executor->HasError()) { + executor->ThrowError(); + } +} + +void AsyncTaskQueue::WorkOnPendingTask() { + if (!executor) { + VerifyOpen(); + return; + } + shared_ptr task; + if (!executor->GetTask(task)) { + TaskScheduler::YieldThread(); + return; + } + auto result = task->Execute(TaskExecutionMode::PROCESS_ALL); + D_ASSERT(result != TaskExecutionResult::TASK_BLOCKED); + task.reset(); +} + +void AsyncTaskQueue::Flush() { + bool already_closed; + { + lock_guard guard(lock); + already_closed = closed; + } + if (already_closed) { + RethrowTaskError(); + return; + } + if (!executor) { + RethrowTaskError(); + return; + } + + try { + ScheduleTasksInternal(); + executor->WorkOnTasks(); + } catch (...) { + try { + executor->WorkOnTasks(); + } catch (...) { + } + throw; + } + RethrowTaskError(); +} + +void AsyncTaskQueue::VerifyDrained() const { + if (!pending_requests.empty() || pending_bytes != 0 || in_flight_bytes != 0 || active_tasks != 0 || + pending_tasks != 0) { + throw InternalException("AsyncTaskQueue still owns submitted tasks"); + } +} + +void AsyncTaskQueue::CancelPendingRequestsAfterFailure(const ErrorData &error) noexcept { + deque requests; + { + lock_guard guard(lock); + D_ASSERT(active_tasks == 0); + D_ASSERT(pending_tasks == 0); + D_ASSERT(in_flight_bytes == 0); + if (active_tasks != 0 || pending_tasks != 0 || in_flight_bytes != 0) { + return; + } + + requests = std::move(pending_requests); + pending_bytes = 0; + closed = true; + } + + for (auto &request : requests) { + auto request_size = request.Size(); + request.task.reset(); + try { + CompleteRequest(request, request_size, error); + } catch (...) { + } + } +} + +void AsyncTaskQueue::Close() { + bool already_closed; + { + lock_guard guard(lock); + already_closed = closed; + } + if (already_closed) { + RethrowTaskError(); + return; + } + + try { + Flush(); + } catch (...) { + auto error = std::current_exception(); + auto error_data = TaskErrorDataFromExceptionPtr(error); + CancelPendingRequestsAfterFailure(error_data); + std::rethrow_exception(error); + } + + lock_guard guard(lock); + VerifyDrained(); + closed = true; +} + +void AsyncTaskQueue::VerifyOpen() const { + if (closed) { + throw InternalException("Cannot use closed AsyncTaskQueue"); + } +} + +//===--------------------------------------------------------------------===// +// ManagedAsyncTaskQueue +//===--------------------------------------------------------------------===// +ManagedAsyncTaskQueue::ManagedAsyncTaskQueue(ClientContext &client_context_p, idx_t max_active_tasks) + : client_context(client_context_p), memory_governor(client_context_p) { + auto &scheduler = TaskScheduler::GetScheduler(client_context); + auto async_threads = NumericCast(scheduler.NumberOfAsyncThreads()); + max_active_drain_tasks = max_active_tasks > 0 ? max_active_tasks : MaxValue(async_threads, 1); + task_queue = make_uniq(client_context, max_active_drain_tasks); +} + +ManagedAsyncTaskQueue::~ManagedAsyncTaskQueue() { + lock_guard guard(lock); + auto drained = pending_requests.empty() && pending_bytes == 0 && submitted_bytes == 0 && submitted_requests == 0; + D_ASSERT(closed || drained); + D_ASSERT(!closed || drained); +} + +bool ManagedAsyncTaskQueue::IsAsync() const { + return task_queue->IsAsync(); +} + +bool ManagedAsyncTaskQueue::HasError() { + return task_queue->HasError(); +} + +void ManagedAsyncTaskQueue::Register(unique_ptr task, idx_t byte_size) { + if (!task) { + return; + } + RethrowTaskError(); + + if (!task_queue->IsAsync()) { + VerifyOpen(); + task_queue->Submit(AsyncTaskRequest(std::move(task), byte_size)); + return; + } + + AsyncTaskRequest request(std::move(task), byte_size); + { + lock_guard guard(lock); + VerifyOpen(); + pending_requests.push_back(std::move(request)); + pending_bytes += byte_size; + } + UpdateMemoryState(); + SchedulePendingTasks(); +} + +void ManagedAsyncTaskQueue::SchedulePendingTasks(SchedulePolicy policy) { + if (!task_queue->IsAsync()) { + return; + } + while (true) { + AsyncTaskRequest request; + if (!TakePendingTaskRequest(request, policy)) { + return; + } + task_queue->Submit(std::move(request)); + } +} + +void ManagedAsyncTaskQueue::UpdateMemoryState() { + if (!memory_governor.IsActive()) { + return; + } + idx_t current_pending_bytes; + { + lock_guard guard(lock); + current_pending_bytes = TotalPendingBytes(); + } + memory_governor.UpdateReservation(current_pending_bytes); +} + +idx_t ManagedAsyncTaskQueue::TotalPendingBytes() const { + return pending_bytes + submitted_bytes; +} + +bool ManagedAsyncTaskQueue::TakePendingTaskRequest(AsyncTaskRequest &request, SchedulePolicy policy) { + lock_guard guard(lock); + if (pending_requests.empty()) { + return false; + } + // Keep at most max_active_drain_tasks submitted so the low-level queue's backlog stays bounded. + if (policy == SchedulePolicy::THRESHOLD && submitted_requests >= max_active_drain_tasks) { + return false; + } + + auto request_size = pending_requests.front().Size(); + request = std::move(pending_requests.front()); + pending_requests.pop_front(); + D_ASSERT(pending_bytes >= request_size); + pending_bytes -= request_size; + submitted_bytes += request_size; + submitted_requests++; + // Attach accounting only on submission, so cancelled pending tasks never carry it. + AddCompletionAccounting(request); + return true; +} + +void ManagedAsyncTaskQueue::AddCompletionAccounting(AsyncTaskRequest &request) { + request.completion = [this](idx_t size, optional_ptr error) { + CompleteSubmittedTask(size, error); + }; +} + +void ManagedAsyncTaskQueue::CompleteSubmittedTask(idx_t size, optional_ptr error) { + bool refill = false; + { + lock_guard guard(lock); + D_ASSERT(submitted_requests > 0); + submitted_requests--; + D_ASSERT(submitted_bytes >= size); + submitted_bytes -= size; + refill = !error && !closed && !pending_requests.empty(); + } + if (refill) { + SchedulePendingTasks(); + } +} + +void ManagedAsyncTaskQueue::ApplyBackpressure() { + if (!task_queue->IsAsync()) { + VerifyOpen(); + return; + } + RethrowTaskError(); + UpdateMemoryState(); + SchedulePendingTasks(); + while (true) { + idx_t current_pending_bytes; + { + lock_guard guard(lock); + current_pending_bytes = TotalPendingBytes(); + } + if (current_pending_bytes <= memory_governor.BackpressureBudget()) { + return; + } + SchedulePendingTasks(SchedulePolicy::FORCE); + task_queue->WorkOnPendingTask(); + RethrowTaskError(); + } +} + +void ManagedAsyncTaskQueue::WaitAll() { + bool already_closed; + { + lock_guard guard(lock); + already_closed = closed; + } + if (already_closed) { + RethrowTaskError(); + return; + } + if (!task_queue->IsAsync()) { + RethrowTaskError(); + return; + } + + try { + UpdateMemoryState(); + while (true) { + if (!task_queue->HasError()) { + SchedulePendingTasks(SchedulePolicy::FORCE); + } + task_queue->Flush(); + lock_guard guard(lock); + if (pending_requests.empty() && pending_bytes == 0 && submitted_bytes == 0 && submitted_requests == 0) { + break; + } + } + } catch (...) { + try { + task_queue->Flush(); + } catch (...) { + } + throw; + } + + RethrowTaskError(); +} + +void ManagedAsyncTaskQueue::VerifyDrained() const { + if (!pending_requests.empty() || pending_bytes != 0 || submitted_bytes != 0 || submitted_requests != 0) { + throw InternalException("ManagedAsyncTaskQueue still owns registered tasks"); + } +} + +void ManagedAsyncTaskQueue::CancelPendingTasksAfterFailure(const ErrorData &error) noexcept { + deque tasks; + { + lock_guard guard(lock); + D_ASSERT(submitted_requests == 0); + D_ASSERT(submitted_bytes == 0); + if (submitted_requests != 0 || submitted_bytes != 0) { + return; + } + + tasks = std::move(pending_requests); + pending_bytes = 0; + closed = true; + } + + // Pending tasks were never submitted, so they carry no accounting; only a user completion (if any) fires. + for (auto &request : tasks) { + auto request_size = request.Size(); + request.task.reset(); + if (request.completion) { + try { + request.completion(request_size, error); + } catch (...) { + } + } + } +} + +void ManagedAsyncTaskQueue::Close() { + bool already_closed; + { + lock_guard guard(lock); + already_closed = closed; + } + if (already_closed) { + RethrowTaskError(); + return; + } + + try { + WaitAll(); + task_queue->Close(); + memory_governor.Release(); + } catch (...) { + auto error = std::current_exception(); + try { + task_queue->Close(); + } catch (...) { + } + auto error_data = TaskErrorDataFromExceptionPtr(error); + CancelPendingTasksAfterFailure(error_data); + try { + memory_governor.Release(); + } catch (...) { + } + std::rethrow_exception(error); + } + + lock_guard guard(lock); + VerifyDrained(); + closed = true; +} + +void ManagedAsyncTaskQueue::RethrowTaskError() { + task_queue->RethrowTaskError(); +} + +void ManagedAsyncTaskQueue::VerifyOpen() const { + if (closed) { + throw InternalException("Cannot use closed ManagedAsyncTaskQueue"); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/async_write_queue.cpp b/src/duckdb/src/common/serializer/async_write_queue.cpp index 30713e67a..5f5fa20f6 100644 --- a/src/duckdb/src/common/serializer/async_write_queue.cpp +++ b/src/duckdb/src/common/serializer/async_write_queue.cpp @@ -8,16 +8,15 @@ #include "duckdb/parallel/task_executor.hpp" #include "duckdb/parallel/task_scheduler.hpp" #include "duckdb/storage/buffer_manager.hpp" -#include "duckdb/storage/temporary_memory_manager.hpp" #include #include namespace duckdb { -static ErrorData ErrorDataFromExceptionPtr(std::exception_ptr error_ptr) { +static ErrorData ErrorDataFromExceptionPtr(const std::exception_ptr &error_ptr) { try { - std::rethrow_exception(std::move(error_ptr)); + std::rethrow_exception(error_ptr); } catch (const std::exception &ex) { return ErrorData(ex); } catch (...) { // LCOV_EXCL_START @@ -496,21 +495,13 @@ idx_t ManagedAsyncWriteQueue::PendingWrite::Size() const { } ManagedAsyncWriteQueue::ManagedAsyncWriteQueue(ClientContext &client_context_p, AsyncWriteTarget &target_p) - : client_context(client_context_p), target(target_p) { + : client_context(client_context_p), target(target_p), memory_governor(client_context_p) { auto &scheduler = TaskScheduler::GetScheduler(client_context); - auto regular_threads = MaxValue(NumericCast(scheduler.NumberOfThreads()), 1); auto async_threads = NumericCast(scheduler.NumberOfAsyncThreads()); max_active_drain_tasks = MaxValue(async_threads, 1); - max_pending_bytes = AsyncWriteConfig::MAX_PENDING_BYTES_PER_THREAD * regular_threads; - min_pending_bytes = MinValue(max_pending_bytes, AsyncWriteConfig::MIN_PENDING_BYTES_PER_THREAD * regular_threads); AsyncWriteTarget &async_target = *this; write_queue = make_uniq(client_context, async_target); - if (write_queue->IsAsync() && max_pending_bytes > 0) { - memory_state = TemporaryMemoryManager::Get(client_context).Register(client_context); - memory_state->SetMinimumReservation(min_pending_bytes); - memory_state->SetZero(); - } } ManagedAsyncWriteQueue::~ManagedAsyncWriteQueue() { @@ -568,6 +559,10 @@ void ManagedAsyncWriteQueue::DiscardExternalPendingBytes(idx_t bytes) noexcept { return; } lock_guard guard(lock); + if (closed) { + // a failure already cleared all external pending bytes via CancelPendingWritesAfterFailure + return; + } D_ASSERT(external_pending_bytes >= bytes); if (external_pending_bytes >= bytes) { external_pending_bytes -= bytes; @@ -588,7 +583,6 @@ void ManagedAsyncWriteQueue::RegisterWriteInternal(AsyncWriteRequest request, id return; } - AddCompletionAccounting(request); { lock_guard guard(lock); VerifyOpen(); @@ -629,7 +623,7 @@ void ManagedAsyncWriteQueue::SchedulePendingWritesInternal(SchedulePolicy policy void ManagedAsyncWriteQueue::UpdateMemoryState(MemoryUpdateMode mode) { (void)mode; - if (!memory_state) { + if (!memory_governor.IsActive()) { return; } @@ -638,56 +632,11 @@ void ManagedAsyncWriteQueue::UpdateMemoryState(MemoryUpdateMode mode) { lock_guard guard(lock); current_pending_bytes = TotalPendingBytes(); } - if (current_pending_bytes == 0) { - return; - } - - auto current_reservation = memory_state->GetReservation(); - while (current_pending_bytes > MinValue(current_reservation, max_pending_bytes)) { - idx_t next_request; - if (memory_request_bytes > current_reservation) { - // TMM did not fully grant the previous request. Keep retrying it on later growth checks. - next_request = memory_request_bytes; - } else if (memory_request_bytes == 0) { - // Grow coarsely and only release on Close(). - // Repeatedly shrinking here would touch shared TMM state on the write-registration hot path. - next_request = min_pending_bytes; - } else if (memory_request_bytes >= max_pending_bytes) { - return; - } else if (memory_request_bytes > max_pending_bytes / 2) { - next_request = max_pending_bytes; - } else { - next_request = memory_request_bytes * 2; - } - next_request = MinValue(MaxValue(next_request, min_pending_bytes), max_pending_bytes); - if (next_request <= memory_request_bytes) { - return; - } - - auto previous_reservation = current_reservation; - memory_state->SetRemainingSizeAndUpdateReservation(client_context, next_request); - memory_request_bytes = next_request; - current_reservation = memory_state->GetReservation(); - if (current_reservation <= previous_reservation) { - return; - } - if (current_reservation < next_request) { - return; - } - } + memory_governor.UpdateReservation(current_pending_bytes); } idx_t ManagedAsyncWriteQueue::BackpressureBudget() { - if (!memory_state) { - return NumericLimits::Maximum(); - } - auto reservation = MinValue(memory_state->GetReservation(), max_pending_bytes); - // If TMM only grants a tiny reservation, do not retain an async backlog. This makes low-memory execution - // behave close to synchronous writes, but automatically allows overlap again if the reservation grows later. - if (reservation < AsyncWriteConfig::REMOTE_COALESCE_THRESHOLD) { - return 0; - } - return reservation; + return memory_governor.BackpressureBudget(); } idx_t ManagedAsyncWriteQueue::DrainTaskByteBudget() const { @@ -731,6 +680,8 @@ bool ManagedAsyncWriteQueue::TakePendingWriteRequest(AsyncWriteRequest &request, pending_bytes -= request_size; submitted_bytes += request_size; submitted_requests++; + // Attach accounting only on submission, so cancelled pending writes never carry it. + AddCompletionAccounting(request); return true; } @@ -853,6 +804,7 @@ void ManagedAsyncWriteQueue::CancelPendingWritesAfterFailure(const ErrorData &er auto request_size = pending.Size(); auto &request = pending.request; request.payload.reset(); + // Pending writes were never submitted, so request.completion is still the raw user callback (if any). if (request.completion) { try { request.completion(request.offset, request_size, error); @@ -898,11 +850,7 @@ void ManagedAsyncWriteQueue::Close() { } void ManagedAsyncWriteQueue::ReleaseMemoryReservation() { - if (!memory_state || memory_request_bytes == 0) { - return; - } - memory_state->SetZero(); - memory_request_bytes = 0; + memory_governor.Release(); } void ManagedAsyncWriteQueue::RethrowTaskError() { @@ -1259,6 +1207,7 @@ void ManagedAsyncWriteStreamQueue::CompleteSubmittedWrite(idx_t offset, idx_t si optional_ptr error) { (void)offset; bool refill = false; + auto refill_policy = SchedulePolicy::THRESHOLD; { lock_guard guard(lock); D_ASSERT(submitted_requests > 0); @@ -1266,9 +1215,12 @@ void ManagedAsyncWriteStreamQueue::CompleteSubmittedWrite(idx_t offset, idx_t si D_ASSERT(submitted_bytes >= size); submitted_bytes -= size; refill = !error && !closed && batch_depth == 0 && !pending_writes.empty(); + if (force_completion_refill) { + refill_policy = SchedulePolicy::FORCE; + } } if (refill) { - SchedulePendingWritesInternal(); + SchedulePendingWritesInternal(refill_policy); } } @@ -1338,9 +1290,14 @@ void ManagedAsyncWriteStreamQueue::WaitAll(BatchDrainMode batch_drain_mode) { lock_guard guard(lock); batch_depth = previous_batch_depth; }; + auto set_force_completion_refill = [&](bool enabled) { + lock_guard guard(lock); + force_completion_refill = enabled; + }; try { open_batch_for_drain(); + set_force_completion_refill(true); write_queue->UpdateMemoryState(ManagedAsyncWriteQueue::MemoryUpdateMode::FORCE); while (true) { if (!write_queue->HasError()) { @@ -1355,13 +1312,16 @@ void ManagedAsyncWriteStreamQueue::WaitAll(BatchDrainMode batch_drain_mode) { } catch (...) { try { open_batch_for_drain(); + set_force_completion_refill(true); write_queue->WaitAll(); } catch (...) { } + set_force_completion_refill(false); restore_batch(); throw; } + set_force_completion_refill(false); restore_batch(); RethrowTaskError(); } diff --git a/src/duckdb/src/common/thread_util.cpp b/src/duckdb/src/common/thread_util.cpp index bc96a2022..a6e73754d 100644 --- a/src/duckdb/src/common/thread_util.cpp +++ b/src/duckdb/src/common/thread_util.cpp @@ -2,6 +2,8 @@ #include "duckdb/common/chrono.hpp" #include "duckdb/original/std/sstream.hpp" #include "duckdb/common/helper.hpp" +#include "duckdb/common/checked_integer.hpp" +#include "duckdb/common/exception.hpp" #include "duckdb/common/types/timestamp.hpp" #include "duckdb/common/types/interval.hpp" #include "duckdb/main/client_context.hpp" @@ -10,8 +12,12 @@ namespace duckdb { #ifndef DUCKDB_NO_THREADS void ThreadUtil::SleepMs(idx_t sleep_ms, optional_ptr context) { + using checked_int64_t = CheckedInteger; auto target_time = Timestamp::GetCurrentTimestamp(); - target_time.value += static_cast(sleep_ms) * Interval::MICROS_PER_MSEC; + checked_int64_t sleep_duration(sleep_ms); + auto sleep_micros = sleep_duration * Interval::MICROS_PER_MSEC; + checked_int64_t target_value(target_time.value); + target_time.value = (target_value + sleep_micros).GetValue(); static constexpr idx_t DEFAULT_SLEEP_INTERVAL_MS = 100; while (true) { diff --git a/src/duckdb/src/common/types/list_segment.cpp b/src/duckdb/src/common/types/list_segment.cpp index e80feebb9..30ec2e865 100644 --- a/src/duckdb/src/common/types/list_segment.cpp +++ b/src/duckdb/src/common/types/list_segment.cpp @@ -248,102 +248,117 @@ static ListSegment *GetSegment(const ListSegmentFunctions &functions, ArenaAlloc //===--------------------------------------------------------------------===// template static void WriteDataToPrimitiveSegment(const ListSegmentFunctions &, ArenaAllocator &, ListSegment *segment, - RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); - - // write null validity + RecursiveUnifiedVectorFormat &input_data, idx_t offset, idx_t count) { auto null_mask = GetNullMask(segment); - auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); - null_mask[segment->count] = !valid; + auto segment_data = GetPrimitiveData(segment); + auto input_data_ptr = UnifiedVectorFormat::GetData(input_data.unified); - // write value - if (valid) { - auto segment_data = GetPrimitiveData(segment); - auto input_data_ptr = UnifiedVectorFormat::GetData(input_data.unified); - Store(input_data_ptr[sel_entry_idx], data_ptr_cast(segment_data + segment->count)); + for (idx_t i = 0; i < count; i++) { + auto sel_entry_idx = input_data.unified.sel->get_index(offset + i); + auto target_idx = segment->count + i; + + // write null validity + auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); + null_mask[target_idx] = !valid; + + // write value + if (valid) { + Store(input_data_ptr[sel_entry_idx], data_ptr_cast(segment_data + target_idx)); + } } } static void WriteDataToVarcharSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, - idx_t &entry_idx) { - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); - - // write null validity + ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t offset, + idx_t count) { auto null_mask = GetNullMask(segment); - auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); - null_mask[segment->count] = !valid; - - // set the length of this string auto str_length_data = GetListLengthData(segment); + auto input_strings = UnifiedVectorFormat::GetData(input_data.unified); - // we can reconstruct the offset from the length - if (!valid) { - Store(0, data_ptr_cast(str_length_data + segment->count)); - return; - } - auto &str_entry = UnifiedVectorFormat::GetData(input_data.unified)[sel_entry_idx]; - auto str_data = str_entry.GetData(); - idx_t str_size = str_entry.GetSize(); - Store(str_size, data_ptr_cast(str_length_data + segment->count)); - - // write the characters to the linked list of child segments + // load the linked list of child segments once, and store it back after writing all strings auto child_segments = Load(data_ptr_cast(GetListChildData(segment))); - idx_t current_offset = 0; - while (current_offset < str_size) { - auto child_segment = GetSegment(functions.child_functions.back(), allocator, child_segments); - auto data = GetStringData(child_segment); - idx_t copy_count = MinValue(str_size - current_offset, child_segment->capacity - child_segment->count); - memcpy(data + child_segment->count, str_data + current_offset, copy_count); - current_offset += copy_count; - child_segment->count += copy_count; + for (idx_t i = 0; i < count; i++) { + auto sel_entry_idx = input_data.unified.sel->get_index(offset + i); + auto target_idx = segment->count + i; + + // write null validity + auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); + null_mask[target_idx] = !valid; + + // we can reconstruct the offset from the length + if (!valid) { + Store(0, data_ptr_cast(str_length_data + target_idx)); + continue; + } + + // set the length of this string + auto &str_entry = input_strings[sel_entry_idx]; + auto str_data = str_entry.GetData(); + idx_t str_size = str_entry.GetSize(); + Store(str_size, data_ptr_cast(str_length_data + target_idx)); + + // write the characters to the linked list of child segments + idx_t current_offset = 0; + while (current_offset < str_size) { + auto child_segment = GetSegment(functions.child_functions.back(), allocator, child_segments); + auto data = GetStringData(child_segment); + idx_t copy_count = + MinValue(str_size - current_offset, child_segment->capacity - child_segment->count); + memcpy(data + child_segment->count, str_data + current_offset, copy_count); + current_offset += copy_count; + child_segment->count += copy_count; + } + child_segments.total_capacity += str_size; } - child_segments.total_capacity += str_size; // store the updated linked list Store(child_segments, data_ptr_cast(GetListChildData(segment))); } static void WriteDataToListSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); - - // write null validity + ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t offset, + idx_t count) { auto null_mask = GetNullMask(segment); - auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); - null_mask[segment->count] = !valid; - - // set the length of this list auto list_length_data = GetListLengthData(segment); - uint64_t list_length = 0; - - if (valid) { - // get list entry information - const auto &list_entry = UnifiedVectorFormat::GetData(input_data.unified)[sel_entry_idx]; - list_length = list_entry.length; - - // loop over the child vector entries and recurse on them - auto child_segments = Load(data_ptr_cast(GetListChildData(segment))); - D_ASSERT(functions.child_functions.size() == 1); - for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { - auto source_idx_child = list_entry.offset + child_idx; - functions.child_functions[0].AppendRow(allocator, child_segments, input_data.children.back(), - source_idx_child); + auto input_lists = UnifiedVectorFormat::GetData(input_data.unified); + + D_ASSERT(functions.child_functions.size() == 1); + // load the linked list of child segments once, and store it back after writing all lists + auto child_segments = Load(data_ptr_cast(GetListChildData(segment))); + for (idx_t i = 0; i < count; i++) { + auto sel_entry_idx = input_data.unified.sel->get_index(offset + i); + auto target_idx = segment->count + i; + + // write null validity + auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); + null_mask[target_idx] = !valid; + + // set the length of this list + uint64_t list_length = 0; + if (valid) { + // get list entry information + const auto &list_entry = input_lists[sel_entry_idx]; + list_length = list_entry.length; + + // recurse on the child vector entries of this list + functions.child_functions[0].AppendRows(allocator, child_segments, input_data.children.back(), + list_entry.offset, list_entry.length); } - // store the updated linked list - Store(child_segments, data_ptr_cast(GetListChildData(segment))); + Store(list_length, data_ptr_cast(list_length_data + target_idx)); } - - Store(list_length, data_ptr_cast(list_length_data + segment->count)); + // store the updated linked list + Store(child_segments, data_ptr_cast(GetListChildData(segment))); } static void WriteDataToStructSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); - + ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t offset, + idx_t count) { // write null validity auto null_mask = GetNullMask(segment); - auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); - null_mask[segment->count] = !valid; + for (idx_t i = 0; i < count; i++) { + auto sel_entry_idx = input_data.unified.sel->get_index(offset + i); + auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); + null_mask[segment->count + i] = !valid; + } // write value D_ASSERT(input_data.children.size() == functions.child_functions.size()); @@ -353,50 +368,58 @@ static void WriteDataToStructSegment(const ListSegmentFunctions &functions, Aren for (idx_t i = 0; i < input_data.children.size(); i++) { auto child_list_segment = Load(data_ptr_cast(child_list + i)); auto &child_function = functions.child_functions[i]; - child_function.write_data(child_function, allocator, child_list_segment, input_data.children[i], entry_idx); - child_list_segment->count++; + child_function.write_data(child_function, allocator, child_list_segment, input_data.children[i], offset, count); + child_list_segment->count += count; } } static void WriteDataToArraySegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); - - // write null validity + ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t offset, + idx_t count) { auto null_mask = GetNullMask(segment); - auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); - null_mask[segment->count] = !valid; - - // Arrays require there to be values in the child even when the entry is NULL. auto array_size = ArrayType::GetSize(input_data.logical_type); - auto array_offset = sel_entry_idx * array_size; - auto child_segments = Load(data_ptr_cast(GetArrayChildData(segment))); D_ASSERT(functions.child_functions.size() == 1); - for (idx_t elem_idx = array_offset; elem_idx < array_offset + array_size; elem_idx++) { - functions.child_functions[0].AppendRow(allocator, child_segments, input_data.children.back(), elem_idx); + // load the linked list of child segments once, and store it back after writing all arrays + auto child_segments = Load(data_ptr_cast(GetArrayChildData(segment))); + for (idx_t i = 0; i < count; i++) { + auto sel_entry_idx = input_data.unified.sel->get_index(offset + i); + + // write null validity + auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); + null_mask[segment->count + i] = !valid; + + // Arrays require there to be values in the child even when the entry is NULL. + auto array_offset = sel_entry_idx * array_size; + functions.child_functions[0].AppendRows(allocator, child_segments, input_data.children.back(), array_offset, + array_size); } // store the updated linked list Store(child_segments, data_ptr_cast(GetArrayChildData(segment))); } -void ListSegmentFunctions::AppendRow(ArenaAllocator &allocator, LinkedList &linked_list, - RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) const { +void ListSegmentFunctions::AppendRows(ArenaAllocator &allocator, LinkedList &linked_list, + RecursiveUnifiedVectorFormat &input_data, idx_t offset, idx_t count) const { auto &write_data_to_segment = *this; - auto segment = GetSegment(write_data_to_segment, allocator, linked_list); - write_data_to_segment.write_data(write_data_to_segment, allocator, segment, input_data, entry_idx); - - linked_list.total_capacity++; - segment->count++; + idx_t appended = 0; + while (appended < count) { + // write into the last segment, filling it up to its capacity before moving to a new one + auto segment = GetSegment(write_data_to_segment, allocator, linked_list); + auto append_count = MinValue(count - appended, segment->capacity - segment->count); + write_data_to_segment.write_data(write_data_to_segment, allocator, segment, input_data, offset + appended, + append_count); + + segment->count = NumericCast(segment->count + append_count); + linked_list.total_capacity += append_count; + appended += append_count; + } } void ListSegmentFunctions::AppendListEntry(ArenaAllocator &allocator, LinkedList &linked_list, RecursiveUnifiedVectorFormat &child_data, const list_entry_t &list_entry) const { - for (idx_t child_idx = list_entry.offset; child_idx < list_entry.offset + list_entry.length; child_idx++) { - allocator.AlignNext(); - AppendRow(allocator, linked_list, child_data, child_idx); - } + allocator.AlignNext(); + AppendRows(allocator, linked_list, child_data, list_entry.offset, list_entry.length); } //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/common/types/selection_vector.cpp b/src/duckdb/src/common/types/selection_vector.cpp index 7c6507e5e..d3ded7f12 100644 --- a/src/duckdb/src/common/types/selection_vector.cpp +++ b/src/duckdb/src/common/types/selection_vector.cpp @@ -16,6 +16,8 @@ SelectionData::SelectionData(idx_t count) { #endif } +SelectionData::~SelectionData() = default; + // LCOV_EXCL_START string SelectionVector::ToString(idx_t count) const { string result = "Selection Vector (" + to_string(count) + ") ["; diff --git a/src/duckdb/src/common/types/value.cpp b/src/duckdb/src/common/types/value.cpp index c9d83abb6..0d3d20894 100644 --- a/src/duckdb/src/common/types/value.cpp +++ b/src/duckdb/src/common/types/value.cpp @@ -2001,6 +2001,15 @@ const LogicalType &UnionValue::GetType(const Value &value) { return UnionType::GetMemberType(value.type(), UnionValue::GetTag(value)); } +Value VariantValue::GetValue(const Value &variant_val) { + D_ASSERT(variant_val.type().id() == LogicalTypeId::VARIANT && !variant_val.IsNull()); + Vector tmp(variant_val, count_t(1)); + RecursiveUnifiedVectorFormat format; + Vector::RecursiveToUnifiedFormat(tmp, format); + UnifiedVariantVectorData vector_data(format); + return VariantUtils::ConvertVariantToValue(vector_data, 0, 0); +} + hugeint_t IntegralValue::Get(const Value &value) { switch (value.type().InternalType()) { case PhysicalType::INT8: diff --git a/src/duckdb/src/common/types/variant/variant_iterator.cpp b/src/duckdb/src/common/types/variant/variant_iterator.cpp index 5214142b2..667af838a 100644 --- a/src/duckdb/src/common/types/variant/variant_iterator.cpp +++ b/src/duckdb/src/common/types/variant/variant_iterator.cpp @@ -76,6 +76,11 @@ VariantIterator::VariantIterator(const Vector &variant) shredded_format.Build(shredded_vec); } +VariantIterator::VariantIterator(const Vector &unshredded_vec, const Vector &shredded) : unshredded(unshredded_vec) { + is_shredded = true; + shredded_format.Build(shredded); +} + //===--------------------------------------------------------------------===// // ShreddedVariantIterator //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/common/types/variant/variant_value.cpp b/src/duckdb/src/common/types/variant/variant_value.cpp deleted file mode 100644 index 47e60f7a5..000000000 --- a/src/duckdb/src/common/types/variant/variant_value.cpp +++ /dev/null @@ -1,849 +0,0 @@ -#include "duckdb/common/vector/flat_vector.hpp" -#include "duckdb/common/vector/list_vector.hpp" -#include "duckdb/common/vector/string_vector.hpp" -#include "duckdb/common/vector/variant_vector.hpp" -#include "duckdb/common/types/variant_value.hpp" -#include "yyjson.hpp" - -#include "duckdb/common/serializer/varint.hpp" -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/datetime.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/decimal.hpp" -#include "duckdb/common/types/variant.hpp" -#include "duckdb/common/hugeint.hpp" -#include "duckdb/function/scalar/variant_utils.hpp" -#include "duckdb/common/string_map_set.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/function/cast/variant/to_variant_fwd.hpp" - -using namespace duckdb_yyjson; // NOLINT - -namespace duckdb { - -void VariantValue::AddChild(const string &key, VariantValue &&val) { - D_ASSERT(value_type == VariantValueType::OBJECT); - if (val.IsMissing()) { - return; - } - object_children.emplace(key, std::move(val)); -} - -void VariantValue::AddItem(VariantValue &&val) { - D_ASSERT(value_type == VariantValueType::ARRAY); - if (val.IsMissing()) { - //! SPEC: If a Variant is missing in a context where a value is required, readers must return a Variant null - val = VariantValue::NullValue(); - } - array_items.push_back(std::move(val)); -} - -void VariantValue::SetItems(vector &&values) { - D_ASSERT(value_type == VariantValueType::ARRAY); - for (auto &value : values) { - if (value.IsMissing()) { - //! SPEC: If a Variant is missing in a context where a value is required, readers must return a Variant null - value = VariantValue::NullValue(); - } - } - array_items = std::move(values); -} - -void VariantValue::ReserveItems(idx_t count) { - array_items.reserve(count); -} - -void VariantValue::AddItems(vector::iterator begin, vector::iterator end) { - D_ASSERT(value_type == VariantValueType::ARRAY); - for (; begin != end; begin++) { - auto &value = *begin; - if (value.IsMissing()) { - //! SPEC: If a Variant is missing in a context where a value is required, readers must return a Variant null - value = VariantValue::NullValue(); - } - array_items.push_back(std::move(value)); - } -} - -map VariantValue::TakeObjectChildren() { - D_ASSERT(value_type == VariantValueType::OBJECT); - return std::move(object_children); -} - -const map &VariantValue::ObjectChildren() const { - return object_children; -} - -const vector &VariantValue::ArrayItems() const { - return array_items; -} - -Value VariantValue::GetValue(const Value &variant_val) { - D_ASSERT(variant_val.type().id() == LogicalTypeId::VARIANT && !variant_val.IsNull()); - Vector tmp(variant_val, count_t(1)); - RecursiveUnifiedVectorFormat format; - Vector::RecursiveToUnifiedFormat(tmp, format); - UnifiedVariantVectorData vector_data(format); - return VariantUtils::ConvertVariantToValue(vector_data, 0, 0); -} - -static void AnalyzeValue(const VariantValue &value, idx_t row, DataChunk &offsets) { - auto &keys_offset = variant::OffsetData::GetKeys(offsets)[row]; - auto &children_offset = variant::OffsetData::GetChildren(offsets)[row]; - auto &values_offset = variant::OffsetData::GetValues(offsets)[row]; - auto &data_offset = variant::OffsetData::GetBlob(offsets)[row]; - - values_offset++; - switch (value.value_type) { - case VariantValueType::OBJECT: { - //! Write the count of the children - auto &children = value.ObjectChildren(); - data_offset += GetVarintSize(children.size()); - if (!children.empty()) { - //! Write the children offset - data_offset += GetVarintSize(children_offset); - children_offset += children.size(); - keys_offset += children.size(); - for (auto &child : children) { - auto &child_value = child.second; - AnalyzeValue(child_value, row, offsets); - } - } - break; - } - case VariantValueType::ARRAY: { - //! Write the count of the children - auto &children = value.ArrayItems(); - data_offset += GetVarintSize(children.size()); - if (!children.empty()) { - //! Write the children offset - data_offset += GetVarintSize(children_offset); - children_offset += children.size(); - for (auto &child : children) { - AnalyzeValue(child, row, offsets); - } - } - break; - } - case VariantValueType::PRIMITIVE: { - auto &primitive = value.primitive_value; - auto type_id = primitive.type().id(); - switch (type_id) { - case LogicalTypeId::BOOLEAN: - case LogicalTypeId::SQLNULL: { - break; - } - case LogicalTypeId::TINYINT: { - data_offset += sizeof(int8_t); - break; - } - case LogicalTypeId::SMALLINT: { - data_offset += sizeof(int16_t); - break; - } - case LogicalTypeId::INTEGER: { - data_offset += sizeof(int32_t); - break; - } - case LogicalTypeId::BIGINT: { - data_offset += sizeof(int64_t); - break; - } - case LogicalTypeId::HUGEINT: { - data_offset += sizeof(hugeint_t); - break; - } - case LogicalTypeId::UTINYINT: { - data_offset += sizeof(uint8_t); - break; - } - case LogicalTypeId::USMALLINT: { - data_offset += sizeof(uint16_t); - break; - } - case LogicalTypeId::UINTEGER: { - data_offset += sizeof(uint32_t); - break; - } - case LogicalTypeId::UBIGINT: { - data_offset += sizeof(uint64_t); - break; - } - case LogicalTypeId::UHUGEINT: { - data_offset += sizeof(uhugeint_t); - break; - } - case LogicalTypeId::DOUBLE: { - data_offset += sizeof(double); - break; - } - case LogicalTypeId::FLOAT: { - data_offset += sizeof(float); - break; - } - case LogicalTypeId::DATE: { - data_offset += sizeof(date_t); - break; - } - case LogicalTypeId::TIMESTAMP_TZ: { - data_offset += sizeof(timestamp_tz_t); - break; - } - case LogicalTypeId::TIMESTAMP_TZ_NS: { - data_offset += sizeof(timestamp_tz_ns_t); - break; - } - case LogicalTypeId::TIMESTAMP: { - data_offset += sizeof(timestamp_t); - break; - } - case LogicalTypeId::TIMESTAMP_SEC: { - data_offset += sizeof(timestamp_sec_t); - break; - } - case LogicalTypeId::TIMESTAMP_MS: { - data_offset += sizeof(timestamp_ms_t); - break; - } - case LogicalTypeId::TIME: { - data_offset += sizeof(dtime_t); - break; - } - case LogicalTypeId::TIME_NS: { - data_offset += sizeof(dtime_ns_t); - break; - } - case LogicalTypeId::TIME_TZ: { - data_offset += sizeof(dtime_tz_t); - break; - } - case LogicalTypeId::TIMESTAMP_NS: { - data_offset += sizeof(timestamp_ns_t); - break; - } - case LogicalTypeId::INTERVAL: { - data_offset += sizeof(interval_t); - break; - } - case LogicalTypeId::UUID: { - data_offset += sizeof(hugeint_t); - break; - } - case LogicalTypeId::DECIMAL: { - auto &type = primitive.type(); - uint8_t width; - uint8_t scale; - type.GetDecimalProperties(width, scale); - - auto physical_type = type.InternalType(); - data_offset += GetVarintSize(width); - data_offset += GetVarintSize(scale); - switch (physical_type) { - case PhysicalType::INT16: { - data_offset += sizeof(int16_t); - break; - } - case PhysicalType::INT32: { - data_offset += sizeof(int32_t); - break; - } - case PhysicalType::INT64: { - data_offset += sizeof(int64_t); - break; - } - case PhysicalType::INT128: { - data_offset += sizeof(hugeint_t); - break; - } - default: - throw InternalException("Unexpected physical type for Decimal value: %s", - EnumUtil::ToString(physical_type)); - } - break; - } - case LogicalTypeId::BLOB: - case LogicalTypeId::BIGNUM: - case LogicalTypeId::BIT: - case LogicalTypeId::GEOMETRY: - case LogicalTypeId::VARCHAR: { - auto string_data = primitive.GetValueUnsafe(); - data_offset += GetVarintSize(string_data.GetSize()); - data_offset += string_data.GetSize(); - break; - } - default: - throw InternalException("Encountered unrecognized LogicalType in VariantValue::AnalyzeValue: %s", - primitive.type().ToString()); - } - break; - } - default: - throw InternalException("VariantValueType not handled"); - } -} - -uint32_t GetOrCreateIndex(OrderedOwningStringMap &dictionary, const string_t &key) { - auto unsorted_idx = dictionary.size(); - //! This will later be remapped to the sorted idx (see FinalizeVariantKeys in 'to_variant.cpp') - return dictionary.emplace(std::make_pair(key, unsorted_idx)).first->second; -} - -static void ConvertValue(const VariantValue &value, VariantVectorData &result, idx_t row, DataChunk &offsets, - SelectionVector &keys_selvec, OrderedOwningStringMap &dictionary) { - auto blob_data = data_ptr_cast(result.blob_data[row].GetDataWriteable()); - auto keys_list_offset = result.keys_data[row].offset; - auto children_list_offset = result.children_data[row].offset; - auto values_list_offset = result.values_data[row].offset; - - auto &keys_offset = variant::OffsetData::GetKeys(offsets)[row]; - auto &children_offset = variant::OffsetData::GetChildren(offsets)[row]; - auto &values_offset = variant::OffsetData::GetValues(offsets)[row]; - auto &data_offset = variant::OffsetData::GetBlob(offsets)[row]; - - switch (value.value_type) { - case VariantValueType::OBJECT: { - //! Write the count of the children - auto &children = value.ObjectChildren(); - - //! values - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::OBJECT); - result.byte_offset_data[values_list_offset + values_offset] = data_offset; - values_offset++; - - //! data - VarintEncode(static_cast(children.size()), blob_data + data_offset); - data_offset += GetVarintSize(children.size()); - - if (!children.empty()) { - //! Write the children offset - VarintEncode(children_offset, blob_data + data_offset); - data_offset += GetVarintSize(children_offset); - - auto start_of_children = children_offset; - children_offset += children.size(); - - auto it = children.begin(); - for (idx_t i = 0; i < children.size(); i++) { - //! children - result.keys_index_data[children_list_offset + start_of_children + i] = keys_offset; - result.values_index_data[children_list_offset + start_of_children + i] = values_offset; - - auto &child = *it; - //! keys - auto &child_key = child.first; - auto dictionary_index = GetOrCreateIndex(dictionary, child_key); - keys_selvec.set_index(keys_list_offset + keys_offset, dictionary_index); - keys_offset++; - - auto &child_value = child.second; - ConvertValue(child_value, result, row, offsets, keys_selvec, dictionary); - it++; - } - } - break; - } - case VariantValueType::ARRAY: { - //! Write the count of the children - auto &children = value.ArrayItems(); - - //! values - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::ARRAY); - result.byte_offset_data[values_list_offset + values_offset] = data_offset; - values_offset++; - - //! data - VarintEncode(static_cast(children.size()), blob_data + data_offset); - data_offset += GetVarintSize(children.size()); - - if (!children.empty()) { - //! Write the children offset - VarintEncode(children_offset, blob_data + data_offset); - data_offset += GetVarintSize(children_offset); - - auto start_of_children = children_offset; - children_offset += children.size(); - - for (idx_t i = 0; i < children.size(); i++) { - //! children - result.keys_index_validity.SetInvalid(children_list_offset + start_of_children + i); - result.values_index_data[children_list_offset + start_of_children + i] = values_offset; - - auto &child_value = children[i]; - ConvertValue(child_value, result, row, offsets, keys_selvec, dictionary); - } - } - break; - } - case VariantValueType::PRIMITIVE: { - auto &primitive = value.primitive_value; - auto type_id = primitive.type().id(); - result.byte_offset_data[values_list_offset + values_offset] = data_offset; - switch (type_id) { - case LogicalTypeId::BOOLEAN: { - if (primitive.GetValue()) { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::BOOL_TRUE); - } else { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::BOOL_FALSE); - } - break; - } - case LogicalTypeId::SQLNULL: { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::VARIANT_NULL); - break; - } - case LogicalTypeId::TINYINT: { - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::INT8); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(int8_t); - break; - } - case LogicalTypeId::SMALLINT: { - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::INT16); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(int16_t); - break; - } - case LogicalTypeId::INTEGER: { - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::INT32); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(int32_t); - break; - } - case LogicalTypeId::BIGINT: { - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::INT64); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(int64_t); - break; - } - case LogicalTypeId::HUGEINT: { - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::INT128); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(hugeint_t); - break; - } - case LogicalTypeId::UTINYINT: { - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::UINT8); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(uint8_t); - break; - } - case LogicalTypeId::USMALLINT: { - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::UINT16); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(uint16_t); - break; - } - case LogicalTypeId::UINTEGER: { - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::UINT32); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(uint32_t); - break; - } - case LogicalTypeId::UBIGINT: { - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::UINT64); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(uint64_t); - break; - } - case LogicalTypeId::UHUGEINT: { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::UINT128); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(uhugeint_t); - break; - } - case LogicalTypeId::DOUBLE: { - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::DOUBLE); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(double); - break; - } - case LogicalTypeId::FLOAT: { - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::FLOAT); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(float); - break; - } - case LogicalTypeId::DATE: { - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::DATE); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(date_t); - break; - } - case LogicalTypeId::TIMESTAMP_TZ: { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::TIMESTAMP_MICROS_TZ); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(timestamp_tz_t); - break; - } - case LogicalTypeId::TIMESTAMP_TZ_NS: { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::TIMESTAMP_NANOS_TZ); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(timestamp_tz_t); - break; - } - case LogicalTypeId::TIMESTAMP: { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::TIMESTAMP_MICROS); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(timestamp_t); - break; - } - case LogicalTypeId::TIMESTAMP_SEC: { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::TIMESTAMP_SEC); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(timestamp_sec_t); - break; - } - case LogicalTypeId::TIMESTAMP_MS: { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::TIMESTAMP_MILIS); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(timestamp_ms_t); - break; - } - case LogicalTypeId::TIME: { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::TIME_MICROS); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(dtime_t); - break; - } - case LogicalTypeId::TIME_NS: { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::TIME_NANOS); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(dtime_ns_t); - break; - } - case LogicalTypeId::TIME_TZ: { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::TIME_MICROS_TZ); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(dtime_tz_t); - break; - } - case LogicalTypeId::TIMESTAMP_NS: { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::TIMESTAMP_NANOS); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(timestamp_ns_t); - break; - } - case LogicalTypeId::INTERVAL: { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::INTERVAL); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(interval_t); - break; - } - case LogicalTypeId::UUID: { - result.type_ids_data[values_list_offset + values_offset] = static_cast(VariantLogicalType::UUID); - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(hugeint_t); - break; - } - case LogicalTypeId::DECIMAL: { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::DECIMAL); - auto &type = primitive.type(); - uint8_t width; - uint8_t scale; - type.GetDecimalProperties(width, scale); - - auto physical_type = type.InternalType(); - VarintEncode(width, blob_data + data_offset); - data_offset += GetVarintSize(width); - VarintEncode(scale, blob_data + data_offset); - data_offset += GetVarintSize(scale); - switch (physical_type) { - case PhysicalType::INT16: { - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(int16_t); - break; - } - case PhysicalType::INT32: { - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(int32_t); - break; - } - case PhysicalType::INT64: { - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(int64_t); - break; - } - case PhysicalType::INT128: { - Store(primitive.GetValueUnsafe(), blob_data + data_offset); - data_offset += sizeof(hugeint_t); - break; - } - default: - throw InternalException("Unexpected physical type for Decimal value: %s", - EnumUtil::ToString(physical_type)); - } - break; - } - case LogicalTypeId::BLOB: - case LogicalTypeId::BIGNUM: - case LogicalTypeId::BIT: - case LogicalTypeId::GEOMETRY: - case LogicalTypeId::VARCHAR: { - if (type_id == LogicalTypeId::BLOB) { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::BLOB); - } else if (type_id == LogicalTypeId::BIGNUM) { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::BIGNUM); - } else if (type_id == LogicalTypeId::BIT) { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::BITSTRING); - } else if (type_id == LogicalTypeId::GEOMETRY) { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::GEOMETRY); - } else { - result.type_ids_data[values_list_offset + values_offset] = - static_cast(VariantLogicalType::VARCHAR); - } - auto string_data = primitive.GetValueUnsafe(); - auto string_size = string_data.GetSize(); - VarintEncode(static_cast(string_size), blob_data + data_offset); - data_offset += GetVarintSize(string_size); - memcpy(blob_data + data_offset, string_data.GetData(), string_size); - data_offset += string_size; - break; - } - default: - throw InternalException("Encountered unrecognized LogicalType in VariantValue::ConvertValue: %s", - primitive.type().ToString()); - } - values_offset++; - break; - } - default: - throw InternalException("VariantValueType not handled"); - } -} - -//! Copied and modified from 'to_variant.cpp' -static void InitializeVariants(DataChunk &offsets, Vector &result, SelectionVector &keys_selvec, idx_t &selvec_size) { - auto count = offsets.size(); - auto &keys = VariantVector::GetKeys(result); - auto keys_data = FlatVector::Writer(keys, count); - - auto &children = VariantVector::GetChildren(result); - auto children_data = FlatVector::Writer(children, count); - - auto &values = VariantVector::GetValues(result); - auto values_data = FlatVector::Writer(values, count); - - auto &blob = VariantVector::GetData(result); - auto blob_data = FlatVector::Writer(blob, count); - - idx_t children_offset = 0; - idx_t values_offset = 0; - idx_t keys_offset = 0; - - auto keys_sizes = variant::OffsetData::GetKeys(offsets); - auto children_sizes = variant::OffsetData::GetChildren(offsets); - auto values_sizes = variant::OffsetData::GetValues(offsets); - auto blob_sizes = variant::OffsetData::GetBlob(offsets); - - for (idx_t i = 0; i < count; i++) { - //! keys - keys_data.WriteValue(list_entry_t(keys_offset, keys_sizes[i])); - keys_offset += keys_sizes[i]; - - //! children - children_data.WriteValue(list_entry_t(children_offset, children_sizes[i])); - children_offset += children_sizes[i]; - - //! values - values_data.WriteValue(list_entry_t(values_offset, values_sizes[i])); - values_offset += values_sizes[i]; - - //! value - blob_data.WriteEmptyString(blob_sizes[i]); - } - - //! Reserve for the children of the lists - ListVector::Reserve(keys, keys_offset); - ListVector::Reserve(children, children_offset); - ListVector::Reserve(values, values_offset); - - //! Set list sizes - ListVector::SetListSize(keys, keys_offset); - ListVector::SetListSize(children, children_offset); - ListVector::SetListSize(values, values_offset); - - keys_selvec.Initialize(keys_offset); - selvec_size = keys_offset; -} - -void VariantValue::ToVARIANT(vector &input, Vector &result) { - auto count = input.size(); - if (input.empty()) { - return; - } - - //! Keep track of all the offsets for each row. - DataChunk analyze_offsets; - analyze_offsets.Initialize( - Allocator::DefaultAllocator(), - {LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER}, count); - analyze_offsets.SetChildCardinality(count); - variant::InitializeOffsets(analyze_offsets, count); - - for (idx_t i = 0; i < count; i++) { - auto &value = input[i]; - if (value.IsNull() || value.IsMissing()) { - continue; - } - AnalyzeValue(value, i, analyze_offsets); - } - - SelectionVector keys_selvec; - idx_t keys_selvec_size; - InitializeVariants(analyze_offsets, result, keys_selvec, keys_selvec_size); - - auto &keys = VariantVector::GetKeys(result); - auto &keys_entry = ListVector::GetChildMutable(keys); - OrderedOwningStringMap dictionary(StringVector::GetStringAllocator(keys_entry)); - - DataChunk conversion_offsets; - conversion_offsets.Initialize( - Allocator::DefaultAllocator(), - {LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER, LogicalType::UINTEGER}, count); - conversion_offsets.SetChildCardinality(count); - variant::InitializeOffsets(conversion_offsets, count); - - VariantVectorData variant_data(result); - for (idx_t i = 0; i < count; i++) { - auto &value = input[i]; - if (value.IsNull() || value.IsMissing()) { - //! SPEC: If a Variant is missing in a context where a value is required, readers must return a Variant null - FlatVector::SetNull(result, i, true); - continue; - } - ConvertValue(value, variant_data, i, conversion_offsets, keys_selvec, dictionary); - } - -#ifdef DEBUG - { - auto conversion_keys_offset = variant::OffsetData::GetKeys(conversion_offsets); - auto conversion_children_offset = variant::OffsetData::GetChildren(conversion_offsets); - auto conversion_values_offset = variant::OffsetData::GetValues(conversion_offsets); - auto conversion_data_offset = variant::OffsetData::GetBlob(conversion_offsets); - - auto analyze_keys_offset = variant::OffsetData::GetKeys(analyze_offsets); - auto analyze_children_offset = variant::OffsetData::GetChildren(analyze_offsets); - auto analyze_values_offset = variant::OffsetData::GetValues(analyze_offsets); - auto analyze_data_offset = variant::OffsetData::GetBlob(analyze_offsets); - - for (idx_t i = 0; i < count; i++) { - D_ASSERT(conversion_keys_offset[i] == analyze_keys_offset[i]); - D_ASSERT(conversion_children_offset[i] == analyze_children_offset[i]); - D_ASSERT(conversion_values_offset[i] == analyze_values_offset[i]); - D_ASSERT(conversion_data_offset[i] == analyze_data_offset[i]); - } - } - -#endif - - //! Finalize the 'data' column of the VARIANT - auto conversion_data_offsets = variant::OffsetData::GetBlob(conversion_offsets); - for (idx_t i = 0; i < count; i++) { - auto &data = variant_data.blob_data[i]; - data.SetSizeAndFinalize(conversion_data_offsets[i], conversion_data_offsets[i]); - } - - VariantUtils::FinalizeVariantKeys(result, dictionary, keys_selvec, keys_selvec_size); - - keys_entry.Slice(keys_selvec, keys_selvec_size); - FlatVector::SetSize(result, count); - result.Verify(); -} - -yyjson_mut_val *VariantValue::ToJSON(ClientContext &context, yyjson_mut_doc *doc) const { - switch (value_type) { - case VariantValueType::PRIMITIVE: { - if (primitive_value.IsNull()) { - return yyjson_mut_null(doc); - } - switch (primitive_value.type().id()) { - case LogicalTypeId::BOOLEAN: { - if (primitive_value.GetValue()) { - return yyjson_mut_true(doc); - } else { - return yyjson_mut_false(doc); - } - } - case LogicalTypeId::TINYINT: - return yyjson_mut_int(doc, primitive_value.GetValue()); - case LogicalTypeId::SMALLINT: - return yyjson_mut_int(doc, primitive_value.GetValue()); - case LogicalTypeId::INTEGER: - return yyjson_mut_int(doc, primitive_value.GetValue()); - case LogicalTypeId::BIGINT: - return yyjson_mut_int(doc, primitive_value.GetValue()); - case LogicalTypeId::FLOAT: - return yyjson_mut_real(doc, primitive_value.GetValue()); - case LogicalTypeId::DOUBLE: - return yyjson_mut_real(doc, primitive_value.GetValue()); - case LogicalTypeId::DATE: - case LogicalTypeId::TIME: - case LogicalTypeId::VARCHAR: { - auto value_str = primitive_value.ToString(); - return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size()); - } - case LogicalTypeId::TIMESTAMP: { - auto value_str = primitive_value.ToString(); - return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size()); - } - case LogicalTypeId::TIMESTAMP_TZ: { - auto value_str = primitive_value.CastAs(context, LogicalType::VARCHAR).GetValue(); - return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size()); - } - case LogicalTypeId::TIMESTAMP_TZ_NS: { - auto value_str = primitive_value.CastAs(context, LogicalType::VARCHAR).GetValue(); - return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size()); - } - case LogicalTypeId::TIMESTAMP_NS: { - auto value_str = primitive_value.CastAs(context, LogicalType::VARCHAR).GetValue(); - return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size()); - } - default: - throw InternalException("Unexpected primitive type: %s", primitive_value.type().ToString()); - } - } - case VariantValueType::OBJECT: { - auto obj = yyjson_mut_obj(doc); - for (const auto &it : object_children) { - auto &key = it.first; - auto value = it.second.ToJSON(context, doc); - yyjson_mut_obj_add_val(doc, obj, key.c_str(), value); - } - return obj; - } - case VariantValueType::ARRAY: { - auto arr = yyjson_mut_arr(doc); - for (auto &item : array_items) { - auto value = item.ToJSON(context, doc); - yyjson_mut_arr_add_val(arr, value); - } - return arr; - } - default: - throw InternalException("Can't serialize this VariantValue type to JSON"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/vector.cpp b/src/duckdb/src/common/types/vector.cpp index 8f8c7ec78..ce924d2e6 100644 --- a/src/duckdb/src/common/types/vector.cpp +++ b/src/duckdb/src/common/types/vector.cpp @@ -188,6 +188,9 @@ void Vector::Reinterpret(const Vector &other) { new_vector.Reinterpret(DictionaryVector::Child(other)); auto &old_dict = buffer->Cast(); auto new_entry = make_shared_ptr(std::move(new_vector)); + // reinterpret re-mints the entry; the id and global flag are one contract and must survive together + new_entry->id = old_dict.GetEntry().id; + new_entry->global_dictionary = old_dict.GetEntry().global_dictionary; buffer = make_buffer(old_dict.GetSelVector(), old_dict.Capacity(), std::move(new_entry)); } } diff --git a/src/duckdb/src/common/vector/dictionary_vector.cpp b/src/duckdb/src/common/vector/dictionary_vector.cpp index 70d7b5740..9b39ced40 100644 --- a/src/duckdb/src/common/vector/dictionary_vector.cpp +++ b/src/duckdb/src/common/vector/dictionary_vector.cpp @@ -167,6 +167,13 @@ buffer_ptr DictionaryVector::CreateReusableDictionary(const Log return entry; } +buffer_ptr DictionaryVector::CreateReusableGlobalDictionary(const LogicalType &type, + const idx_t &size) { + auto entry = CreateReusableDictionary(type, size); + entry->global_dictionary = true; + return entry; +} + const Vector &DictionaryVector::GetCachedHashes(const Vector &input) { D_ASSERT(CanCacheHashes(input)); diff --git a/src/duckdb/src/execution/index/art/art.cpp b/src/duckdb/src/execution/index/art/art.cpp index bde663f37..553e81e00 100644 --- a/src/duckdb/src/execution/index/art/art.cpp +++ b/src/duckdb/src/execution/index/art/art.cpp @@ -563,12 +563,6 @@ ErrorData ART::InsertKeys(ArenaAllocator &arena, unsafe_vector &keys, un VerifyAllocationsInternal(); } - if (conflict_type == ARTConflictType::TRANSACTION) { - // chunk is only null when called from MergeCheckpointDeltas. - auto msg = chunk ? AppendRowError(*chunk, conflict_idx.GetIndex()) : string("???"); - return ErrorData(TransactionException("write-write conflict on key: \"%s\"", msg)); - } - if (conflict_type == ARTConflictType::CONSTRAINT) { // chunk is only null when called from MergeCheckpointDeltas. auto msg = chunk ? AppendRowError(*chunk, conflict_idx.GetIndex()) : string("???"); diff --git a/src/duckdb/src/execution/index/art/art_index.cpp b/src/duckdb/src/execution/index/art/art_index.cpp index 7a295ee2f..f8c032e55 100644 --- a/src/duckdb/src/execution/index/art/art_index.cpp +++ b/src/duckdb/src/execution/index/art/art_index.cpp @@ -102,7 +102,6 @@ void ARTBuildSinkUnsorted(IndexBuildSinkInput &input, DataChunk &key_chunk, Data auto conflict_type = ARTOperator::Insert(l_state.arena_allocator, art, art.tree, l_state.keys[i], 0, l_state.row_ids[i], status, DeleteIndexInfo(), IndexAppendMode::DEFAULT); - D_ASSERT(conflict_type != ARTConflictType::TRANSACTION); if (conflict_type == ARTConflictType::CONSTRAINT) { throw ConstraintException("Data contains duplicates on indexed column(s)"); } diff --git a/src/duckdb/src/execution/join_hashtable.cpp b/src/duckdb/src/execution/join_hashtable.cpp index a6abe48f8..b0217e24f 100644 --- a/src/duckdb/src/execution/join_hashtable.cpp +++ b/src/duckdb/src/execution/join_hashtable.cpp @@ -83,18 +83,29 @@ JoinHashTable::JoinHashTable(ClientContext &context_p, const PhysicalOperator &o // at least one equality is necessary D_ASSERT(!equality_types.empty()); - // Types for the layout - auto layout = make_shared_ptr(); - vector layout_types(condition_types); - layout_types.insert(layout_types.end(), build_types.begin(), build_types.end()); - if (PropagatesBuildSide(join_type)) { - // full/right outer joins need an extra bool to keep track of whether or not a tuple has found a matching entry - // we place the bool before the NEXT pointer - layout_types.emplace_back(LogicalType::BOOLEAN); + if (join_type == JoinType::SINGLE) { + single_join_error_on_multiple_rows = Settings::Get(context); + } + + if (non_equality_predicates.empty() && !residual_predicate && + (join_type == JoinType::SEMI || join_type == JoinType::ANTI || join_type == JoinType::MARK)) { + insert_duplicate_keys = false; + } + + InitializePartitionMasks(); + // Layout-dependent state is published lazily on the first Sink chunk (see FinishInitWithLayout). +} + +void JoinHashTable::FinishInitWithLayout(shared_ptr published_layout, + vector dict_index_width_p) { + D_ASSERT(!layout_ptr); + layout_ptr = std::move(published_layout); + if (dict_index_width_p.empty()) { + dict_index_width.assign(build_types.size(), 0); + } else { + D_ASSERT(dict_index_width_p.size() == build_types.size()); + dict_index_width = std::move(dict_index_width_p); } - layout_types.emplace_back(LogicalType::HASH); - layout->Initialize(layout_types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); - layout_ptr = std::move(layout); // Initialize the row matcher that are used for filtering during the probing only if there are non-equality preds if (!non_equality_predicates.empty()) { @@ -125,16 +136,8 @@ JoinHashTable::JoinHashTable(ClientContext &context_p, const PhysicalOperator &o dead_end = make_unsafe_uniq_array_uninitialized(layout_ptr->GetRowWidth()); memset(dead_end.get(), 0, layout_ptr->GetRowWidth()); - if (join_type == JoinType::SINGLE) { - single_join_error_on_multiple_rows = Settings::Get(context); - } - - if (non_equality_predicates.empty() && !residual_predicate && - (join_type == JoinType::SEMI || join_type == JoinType::ANTI || join_type == JoinType::MARK)) { - insert_duplicate_keys = false; - } - - InitializePartitionMasks(); + // indexed parallel to build_types / payload columns (not to output_columns) + dict_registry.assign(build_types.size(), nullptr); } JoinHashTable::~JoinHashTable() { @@ -161,6 +164,23 @@ void JoinHashTable::Merge(JoinHashTable &other) { } sink_collection->Combine(*other.sink_collection); + + // Reconcile per-column pinned dictionary entries. For global dictionary producers every thread pins the + // same buffer_ptr, so this is normally a no-op or a "take theirs" adoption. + if (dict_registry.size() < other.dict_registry.size()) { + dict_registry.resize(other.dict_registry.size(), nullptr); + } + for (idx_t col = 0; col < other.dict_registry.size(); col++) { + if (!other.dict_registry[col]) { + continue; + } + if (!dict_registry[col]) { + dict_registry[col] = std::move(other.dict_registry[col]); + } else { + // Both threads pinned the same upstream entry, so ids match by construction; a mismatch is a producer bug. + D_ASSERT(dict_registry[col]->id == other.dict_registry[col]->id); + } + } } static void ApplyBitmaskAndGetSaltBuild(Vector &hashes_v, Vector &salt_v, const idx_t &bitmask) { @@ -391,6 +411,153 @@ static idx_t FilterNullValues(UnifiedVectorFormat &vdata, const SelectionVector return result_count; } +// A dictionary of D slots uses indices 0..D-1, so D slots fit a W-byte index iff D <= 2^(8*W). +static constexpr idx_t DICT_INDEX_UINT8_CAPACITY = 256; +static constexpr idx_t DICT_INDEX_UINT16_CAPACITY = 65536; + +//! Narrowest unsigned-integer byte width that can hold any index into a dictionary of dict_size slots +static uint8_t DictIndexWidth(idx_t dict_size) { + if (dict_size <= DICT_INDEX_UINT8_CAPACITY) { + return sizeof(uint8_t); + } + if (dict_size <= DICT_INDEX_UINT16_CAPACITY) { + return sizeof(uint16_t); + } + return sizeof(uint32_t); +} + +//! Materialise a flat width-INDEX_T index vector (allocation owned by `buffers`, view in `vectors`) from the +//! dict sel. UnsafeNumericCast is safe: the publisher sized the width to the dictionary, so every index fits. +template +static void ScatterDictIndices(const SelectionVector &dict_sel, idx_t count, const LogicalType &index_type, + Allocator &allocator, vector &buffers, vector &vectors) { + buffers.emplace_back(allocator.Allocate(count * sizeof(INDEX_T))); + vectors.emplace_back(index_type, buffers.back().get(), count); + auto idx_data = FlatVector::GetDataMutable(vectors.back()); + for (idx_t r = 0; r < count; r++) { + idx_data[r] = UnsafeNumericCast(dict_sel.get_index(r)); + } +} + +//! Dispatch on the chosen index width (1/2/4 B) to the matching templated ScatterDictIndices +static void ScatterDictIndices(uint8_t index_width, const SelectionVector &dict_sel, idx_t count, Allocator &allocator, + vector &buffers, vector &vectors) { + switch (index_width) { + case sizeof(uint8_t): + ScatterDictIndices(dict_sel, count, LogicalType::UTINYINT, allocator, buffers, vectors); + break; + case sizeof(uint16_t): + ScatterDictIndices(dict_sel, count, LogicalType::USMALLINT, allocator, buffers, vectors); + break; + default: + ScatterDictIndices(dict_sel, count, LogicalType::UINTEGER, allocator, buffers, vectors); + break; + } +} + +//! Load the per-match dict index (width INDEX_T) from each matched row's slot at col_offset into build_sel_vec +template +static void LoadDictIndices(const data_ptr_t *ptrs, const SelectionVector &ptr_sel, idx_t count, idx_t col_offset, + SelectionVector &build_sel_vec) { + for (idx_t i = 0; i < count; i++) { + build_sel_vec.set_index(i, Load(ptrs[ptr_sel.get_index(i)] + col_offset)); + } +} + +//! Dispatch on the chosen index width (1/2/4 B) to the matching templated LoadDictIndices +static void LoadDictIndices(uint8_t index_width, const data_ptr_t *ptrs, const SelectionVector &ptr_sel, idx_t count, + idx_t col_offset, SelectionVector &build_sel_vec) { + switch (index_width) { + case sizeof(uint8_t): + LoadDictIndices(ptrs, ptr_sel, count, col_offset, build_sel_vec); + break; + case sizeof(uint16_t): + LoadDictIndices(ptrs, ptr_sel, count, col_offset, build_sel_vec); + break; + default: + LoadDictIndices(ptrs, ptr_sel, count, col_offset, build_sel_vec); + break; + } +} + +bool JoinHashTable::ColumnReferencedByResidual(idx_t build_col_idx) const { + if (!residual_info) { + return false; + } + const auto layout_col = condition_types.size() + build_col_idx; + for (const auto &kv : residual_info->build_input_to_layout_map) { + if (kv.second == layout_col) { + return true; + } + } + return false; +} + +uint8_t JoinHashTable::GetDictSurvivingIndexWidth(idx_t build_col_idx, const Vector &incoming) const { + const auto &type = build_types[build_col_idx]; + // Vector::Dictionary rejects nested types + switch (type.InternalType()) { + case PhysicalType::STRUCT: + case PhysicalType::LIST: + case PhysicalType::ARRAY: + return 0; + default: + break; + } + // residual predicates read from the row slot directly; a narrowed slot would corrupt them + if (ColumnReferencedByResidual(build_col_idx)) { + return 0; + } + if (incoming.GetVectorType() != VectorType::DICTIONARY_VECTOR) { + return 0; + } + if (!DictionaryVector::IsGlobalDictionary(incoming)) { + return 0; + } + // an empty/invalid dictionary keeps native width + const auto dict_size = DictionaryVector::DictionarySize(incoming); + if (!dict_size.IsValid()) { + return 0; + } + // only narrow when strictly smaller than the native slot (never regress; e.g. INTEGER over a D<=256 dict) + const uint8_t index_width = DictIndexWidth(dict_size.GetIndex()); + const auto native_bytes = GetTypeIdSize(type.InternalType()); + if (index_width >= native_bytes) { + return 0; + } + return index_width; +} + +void JoinHashTable::PinDictSurvivingColumn(idx_t build_col_idx, const Vector &incoming, uint8_t index_width) { + // Slot was narrowed on the first chunk; there is no fallback for a non-dict later chunk. Throw, not D_ASSERT: + // the Cast below is UB in release on a non-dict vector, scattering foreign bytes as indices. + if (incoming.GetVectorType() != VectorType::DICTIONARY_VECTOR || DictionaryVector::DictionaryId(incoming).empty() || + !DictionaryVector::IsGlobalDictionary(incoming)) { + throw InternalException("dict-surviving join: narrowed column %llu received a " + "non-global-dictionary chunk; build pipeline is not single-source", + static_cast(build_col_idx)); + } + const auto &entry = incoming.Buffer().Cast().GetEntry(); + if (dict_registry[build_col_idx]) { + // Subsequent chunks wrap the same entry, so ids match by construction; a mismatch is a producer bug. + D_ASSERT(dict_registry[build_col_idx]->id == entry.id); + return; + } + // The upstream child is a zero-copy gather: its long strings point into the producer's row-store heap, recycled + // before we gather on probe. Deep-copy into a self-owned entry so it outlives them. + const auto &upstream_child = entry.data; + const auto child_count = upstream_child.size(); + // child_count fits index_width by construction (publisher sized the width to this dict). Assert so a producer + // that grows the child past it fails loudly instead of truncating in the UnsafeNumericCast. + D_ASSERT(child_count <= (idx_t(1) << (8 * index_width))); + auto owned_entry = DictionaryVector::CreateReusableGlobalDictionary(upstream_child.GetType(), child_count); + if (child_count > 0) { + VectorOperations::Copy(upstream_child, owned_entry->data, child_count, 0, 0); + } + owned_entry->id = entry.id; + dict_registry[build_col_idx] = std::move(owned_entry); +} + void JoinHashTable::Build(PartitionedTupleDataAppendState &append_state, DataChunk &keys, DataChunk &payload) { D_ASSERT(!finalized); D_ASSERT(keys.size() == payload.size()); @@ -427,8 +594,22 @@ void JoinHashTable::Build(PartitionedTupleDataAppendState &append_state, DataChu } idx_t col_offset = keys.ColumnCount(); D_ASSERT(build_types.size() == payload.ColumnCount()); + // Scratch index vectors for narrowed columns, allocated via the buffer manager so the bytes are accounted. + // buffers own the bytes, vectors are flat views; both must outlive the ToUnifiedFormat / AppendUnified below. + vector dict_idx_buffers; + vector dict_idx_vectors; for (idx_t i = 0; i < payload.ColumnCount(); i++) { - source_chunk.data[col_offset + i].Reference(payload.data[i]); + auto &incoming = payload.data[i]; + const uint8_t index_width = i < dict_index_width.size() ? dict_index_width[i] : 0; + if (index_width == 0) { + source_chunk.data[col_offset + i].Reference(incoming); + continue; + } + // Narrowed column: pin the dictionary, then scatter its sel indices at the chosen width in place of the value. + PinDictSurvivingColumn(i, incoming, index_width); + ScatterDictIndices(index_width, DictionaryVector::SelVector(incoming), payload.size(), + buffer_manager.GetBufferAllocator(), dict_idx_buffers, dict_idx_vectors); + source_chunk.data[col_offset + i].Reference(dict_idx_vectors.back()); } col_offset += payload.ColumnCount(); if (PropagatesBuildSide(join_type)) { @@ -1438,10 +1619,35 @@ void JoinHashTable::GatherRHS(Vector &row_ptrs, const SelectionVector &ptr_sel, return; } const auto &result_sel = *FlatVector::IncrementalSelectionVector(); + const auto ptrs = FlatVector::GetData(row_ptrs); + const auto &offsets = layout_ptr->GetOffsets(); + const auto cond_count = condition_types.size(); + for (idx_t col_idx = 0; col_idx < output_columns.size(); col_idx++) { auto &vector = result.data[rhs_col_offset + col_idx]; const auto output_col_idx = output_columns[col_idx]; + + // Pinned column: re-emit the upstream dictionary. Layout is [conditions, build, (found), hash], so payload + // columns sit at layout index >= cond_count; guard the subtraction so it cannot underflow into a spurious idx. + if (output_col_idx >= cond_count) { + const idx_t payload_idx = output_col_idx - cond_count; + if (payload_idx < dict_registry.size() && dict_registry[payload_idx]) { + SelectionVector build_sel_vec(count); + const auto col_offset = offsets[output_col_idx]; + // Read the dict index at the width chosen at layout-publication time + const uint8_t index_width = payload_idx < dict_index_width.size() ? dict_index_width[payload_idx] : 0; + LoadDictIndices(index_width, ptrs, ptr_sel, count, col_offset, build_sel_vec); + vector.Dictionary(dict_registry[payload_idx], build_sel_vec, count); + continue; + } + } + D_ASSERT(vector.GetType() == layout_ptr->GetTypes()[output_col_idx]); + // The native gather writes straight into a flat buffer. A reused output chunk (e.g. across recursive-CTE + // iterations) may still hold a DICTIONARY_VECTOR left by a prior dict-emitting iteration; reset it to flat. + if (vector.GetVectorType() != VectorType::FLAT_VECTOR) { + vector.Initialize(); + } data_collection->Gather(row_ptrs, ptr_sel, count, output_col_idx, vector, result_sel, nullptr); FlatVector::SetSize(vector, count_t(count)); } @@ -2178,6 +2384,10 @@ static void ResetCorrelatedMarkJoinInfo(JoinHashTable &ht) { } void JoinHashTable::ResetForNewIterationSinglePartition() { + if (!layout_ptr) { + // layout was never published (no Sink chunk last iteration); the next first chunk will publish + return; + } data_collection->Reset(); // Always use a single partition (radix_bits=0) to avoid per-iteration overhead of resetting // and re-creating many radix partitions when only one thread builds the hash table. @@ -2208,6 +2418,10 @@ void JoinHashTable::ResetForNewIterationSinglePartition() { aux_next_ptrs.Reset(); aux_next_ptrs_data = nullptr; use_dict_emission = false; + // Drop pinned upstream dictionary entries; the next iteration's first chunk re-pins + for (auto &entry : dict_registry) { + entry.reset(); + } } bool JoinHashTable::PrepareExternalFinalize(const idx_t max_ht_size) { @@ -2401,6 +2615,13 @@ bool JoinHashTable::CanUseDictionaryEmission(const PhysicalHashJoin &op, bool ex if (external) { return false; } + // mutually exclusive with the dict-surviving path: that path narrows the payload slot to a 1/2/4-byte + // index, while NEXT_PTR embedding here overwrites a different field and assumes payload bytes are intact. + for (const auto &entry : dict_registry) { + if (entry) { + return false; + } + } // SINGLE joins need FlatVector::SetNull for unmatched rows; dictionary vectors cannot supply it if (join_type == JoinType::SINGLE) { return false; @@ -2468,11 +2689,16 @@ void JoinHashTable::BuildDictionaryArrays(const PhysicalHashJoin &op) { (void)collected; const auto row_ptrs = FlatVector::GetData(row_pointer_vector); + // LEFT / OUTER joins fill unmatched probe rows with a constant-NULL vector (NextLeftJoin), so they do not + // wrap this entry on every chunk; it is only a global dictionary when every chunk goes through EmitDictVectors. + const bool dict_on_every_chunk = join_type != JoinType::LEFT && join_type != JoinType::OUTER; + // gather RHS output columns into columnar dictionary arrays const auto &sel = *FlatVector::IncrementalSelectionVector(); for (idx_t col_idx = 0; col_idx < op.rhs_output_columns.col_types.size(); col_idx++) { const auto &type = op.rhs_output_columns.col_types[col_idx]; - auto dict_entry = DictionaryVector::CreateReusableDictionary(type, build_count); + auto dict_entry = dict_on_every_chunk ? DictionaryVector::CreateReusableGlobalDictionary(type, build_count) + : DictionaryVector::CreateReusableDictionary(type, build_count); const auto output_col_idx = output_columns[col_idx]; collection.Gather(row_pointer_vector, sel, build_count, output_col_idx, dict_entry->data, sel, nullptr); dict_arrays.emplace_back(std::move(dict_entry)); diff --git a/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp index 665f91679..547f36a40 100644 --- a/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp @@ -15,7 +15,8 @@ class MaterializedCollectorGlobalState : public GlobalSinkState { public: mutex glock; unique_ptr collection; - shared_ptr context; + //! This is weak to avoid creating a cyclical reference + weak_ptr context; }; class MaterializedCollectorLocalState : public LocalSinkState { @@ -64,12 +65,12 @@ unique_ptr PhysicalMaterializedCollector::GetLocalSinkState(Exec unique_ptr PhysicalMaterializedCollector::GetResult(GlobalSinkState &state) const { auto &gstate = state.Cast(); + auto cc = gstate.context.lock(); if (!gstate.collection) { - gstate.collection = CreateCollection(*gstate.context); + gstate.collection = CreateCollection(*cc); } - auto result = - make_uniq(statement_type, properties, IdentifiersToStrings(names), - std::move(gstate.collection), gstate.context->GetClientProperties()); + auto result = make_uniq(statement_type, properties, IdentifiersToStrings(names), + std::move(gstate.collection), cc->GetClientProperties()); return std::move(result); } diff --git a/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp b/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp index 68e95e145..7055c0d3a 100644 --- a/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp +++ b/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp @@ -137,7 +137,8 @@ bool PerfectHashJoinExecutor::BuildPerfectHashTable() { // First, allocate memory for each build column const auto build_size = perfect_join_statistics.build_range + 1; for (const auto &type : join.rhs_output_columns.col_types) { - perfect_hash_table.emplace_back(DictionaryVector::CreateReusableDictionary(type, build_size)); + // PHJ keeps each entry alive for the operator's lifetime and wraps it in every emitted chunk + perfect_hash_table.emplace_back(DictionaryVector::CreateReusableGlobalDictionary(type, build_size)); } // and for duplicate_checking diff --git a/src/duckdb/src/execution/operator/join/physical_cross_product.cpp b/src/duckdb/src/execution/operator/join/physical_cross_product.cpp index 822f6342e..46ca0bd8e 100644 --- a/src/duckdb/src/execution/operator/join/physical_cross_product.cpp +++ b/src/duckdb/src/execution/operator/join/physical_cross_product.cpp @@ -135,6 +135,14 @@ OperatorResultType CrossProductExecutor::Execute(const DataChunk &input, DataChu auto &constant_chunk = scan_input_chunk ? scan_chunk : input; auto col_count = constant_chunk.ColumnCount(); auto col_offset = scan_input_chunk ? input.ColumnCount() : 0; + // SetChildCardinality cannot resize a non-flat vector, and a reused output chunk may still hold a stale + // DICTIONARY_VECTOR (e.g. a recursive-CTE dict flush Reset cannot flatten), so reset to flat before resizing. + for (auto &v : output.data) { + auto vtype = v.GetVectorType(); + if (vtype != VectorType::FLAT_VECTOR && vtype != VectorType::CONSTANT_VECTOR) { + v.Initialize(); + } + } output.SetChildCardinality(constant_chunk.size()); for (idx_t i = 0; i < col_count; i++) { output.data[col_offset + i].Reference(constant_chunk.data[i]); diff --git a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp index c44696f51..9c3a019bb 100644 --- a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp @@ -7,6 +7,7 @@ #include "duckdb/common/types/value_map.hpp" #include "duckdb/common/uhugeint.hpp" #include "duckdb/common/types/uhugeint.hpp" +#include "duckdb/common/vector/dictionary_vector.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/execution/join_hashtable.hpp" #include "duckdb/execution/operator/aggregate/ungrouped_aggregate_state.hpp" @@ -293,6 +294,36 @@ unique_ptr JoinFilterPushdownInfo::GetGlobalState(ClientC return result; } +//! True iff the build subtree funnels multiple producer pipelines into one sink (UNION ALL, recursive CTE), +//! breaking the "decide layout once on the first chunk" contract. Conservative: may over-exclude, never misses one. +static bool BuildSideHasMultipleSources(const PhysicalOperator &op) { + if (op.type == PhysicalOperatorType::UNION || op.type == PhysicalOperatorType::RECURSIVE_CTE || + op.type == PhysicalOperatorType::RECURSIVE_KEY_CTE) { + return true; + } + for (const auto &child : op.children) { + if (BuildSideHasMultipleSources(child.get())) { + return true; + } + } + return false; +} + +//! Synchronisation state for the first-chunk publication of the TupleDataLayout: the canonical layout +//! shared by all per-thread JHTs, plus the per-column decision on whether to narrow the row-store slot. +struct LayoutGate { + mutex publish_mutex; + atomic published {false}; + shared_ptr layout_ptr; + vector dict_index_width; + + void Reset() { + published.store(false, std::memory_order_release); + layout_ptr.reset(); + dict_index_width.clear(); + } +}; + class HashJoinGlobalSinkState : public GlobalSinkState { public: HashJoinGlobalSinkState(const PhysicalHashJoin &op_p, ClientContext &context_p) @@ -306,6 +337,10 @@ class HashJoinGlobalSinkState : public GlobalSinkState { // For perfect hash join perfect_join_executor = make_uniq(op, *hash_table); auto use_perfect_hash = CanUsePerfectHashJoin(op, *perfect_join_executor); + can_use_perfect_hash = use_perfect_hash; + // A multi-source build side (UNION ALL / recursive CTE) feeds the sink from several producers, + // disqualifying dict-surviving. Computed once from the static plan; cannot change at runtime. + build_side_multi_source = BuildSideHasMultipleSources(op.children[1].get()); // For external hash join external = Settings::Get(context); // Set probe types @@ -328,6 +363,11 @@ class HashJoinGlobalSinkState : public GlobalSinkState { void ScheduleFinalize(Pipeline &pipeline, Event &event); void InitializeProbeSpill(); + //! First-chunk election: build the canonical layout (with per-column slot narrowing) and publish it. + //! Idempotent and safe to call concurrently; only the first thread runs the slow path. + void PublishLayoutIfFirst(class HashJoinLocalSinkState &lstate, DataChunk &payload_chunk); + //! True iff at least one column on the global JHT carries a pinned upstream dictionary entry. + bool DictSurvivingActive() const; bool SupportsReuse() const override { return true; @@ -338,6 +378,7 @@ class HashJoinGlobalSinkState : public GlobalSinkState { hash_table->ResetForNewIterationSinglePartition(); perfect_join_executor = make_uniq(op, *hash_table); auto use_perfect_hash = CanUsePerfectHashJoin(op, *perfect_join_executor); + can_use_perfect_hash = use_perfect_hash; finalized = false; active_local_states = 0; external = Settings::Get(context); @@ -360,6 +401,8 @@ class HashJoinGlobalSinkState : public GlobalSinkState { } global_filter_state = op.filter_pushdown->GetGlobalState(context, op); } + // Keep the published layout across CTE iterations (same upstream operator, same arrival types). + // ResetForNewIterationSinglePartition already cleared the row data and dict_registry. GlobalSinkState::Reset(context); } @@ -406,6 +449,15 @@ class HashJoinGlobalSinkState : public GlobalSinkState { bool skip_filter_pushdown = false; unique_ptr global_filter_state; bool keep_local_hash_tables = false; + + //! Coordinates first-chunk publication of the TupleDataLayout across parallel sinks. + LayoutGate layout_gate; + //! True iff this join may use perfect-hash-join at Finalize. PHJ's FullScanHashTable reads payload at native + //! width, so it disables dict-surviving slot narrowing. + bool can_use_perfect_hash = false; + //! True iff the build subtree funnels multiple producer pipelines into this sink (UNION ALL / + //! recursive CTE); disables dict-surviving because the first-chunk layout election is unsound there. + bool build_side_multi_source = false; }; unique_ptr JoinFilterPushdownInfo::GetLocalState(JoinFilterGlobalState &gstate) const { @@ -430,7 +482,8 @@ class HashJoinLocalSinkState : public LocalSinkState { } hash_table = op.InitializeHashTable(context, gstate.hash_table->GetRadixBits()); - hash_table->GetSinkCollection().InitializeAppendState(append_state); + // sink_collection exists only after the layout is published on the first build chunk, so + // InitializeAppendState runs lazily inside Sink. keep_hash_table = gstate.keep_local_hash_tables; gstate.active_local_states++; @@ -443,6 +496,8 @@ class HashJoinLocalSinkState : public LocalSinkState { public: const PhysicalHashJoin &op; PartitionedTupleDataAppendState append_state; + //! True once InitializeAppendState has been called against the published sink_collection + bool append_state_initialised = false; ExpressionExecutor join_key_executor; DataChunk join_keys; @@ -463,12 +518,16 @@ class HashJoinLocalSinkState : public LocalSinkState { auto &gstate = gstate_p.Cast(); join_keys.Reset(); payload_chunk.Reset(); - if (hash_table) { + if (hash_table && append_state_initialised) { + // the layout survives the iteration; only the row data is dropped hash_table->ResetForNewIterationSinglePartition(); + hash_table->GetSinkCollection().ResetAppendState(append_state); } else { + // HT was moved into gstate during Combine, or never had a layout published. Rebuild against the global + // radix_bits so partition counts stay consistent in PrepareFinalize. hash_table = op.InitializeHashTable(context.client, gstate.hash_table->GetRadixBits()); + append_state_initialised = false; } - hash_table->GetSinkCollection().ResetAppendState(append_state); keep_hash_table = gstate.keep_local_hash_tables; gstate.active_local_states++; if (op.filter_pushdown) { @@ -479,6 +538,122 @@ class HashJoinLocalSinkState : public LocalSinkState { } }; +//! Map a dict-index byte width to its row-store slot type (the width is decided by GetDictSurvivingIndexWidth) +static LogicalType DictIndexType(uint8_t index_width) { + switch (index_width) { + case sizeof(uint8_t): + return LogicalType::UTINYINT; + case sizeof(uint16_t): + return LogicalType::USMALLINT; + default: + return LogicalType::UINTEGER; + } +} + +//! Build the row layout [conditions, build payload, (found flag), hash], narrowing a payload column to its +//! dict-index slot when dict_index_width[col] != 0. Shared by publisher and empty-input fallback to avoid drift. +static shared_ptr BuildJoinLayout(const vector &cond_types, + const vector &build_types, JoinType join_type, + const vector &dict_index_width) { + vector layout_types(cond_types); + for (idx_t col = 0; col < build_types.size(); col++) { + if (col < dict_index_width.size() && dict_index_width[col] != 0) { + layout_types.emplace_back(DictIndexType(dict_index_width[col])); + } else { + layout_types.emplace_back(build_types[col]); + } + } + if (PropagatesBuildSide(join_type)) { + layout_types.emplace_back(LogicalType::BOOLEAN); + } + layout_types.emplace_back(LogicalType::HASH); + + auto layout = make_shared_ptr(); + layout->Initialize(layout_types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); + return layout; +} + +//! Join-level gate: shape-only eligibility checks, mirroring the dict-emission path plus a PHJ exclusion +static bool CanUseDictSurvivingJoin(const PhysicalHashJoin &op, const JoinHashTable &ht, bool can_use_perfect_hash, + bool build_side_multi_source) { + // external is safe here: the dictionary is an in-memory self-owned copy and the index is a plain row-store + // column, so a spill/repartition preserves both (unlike the pointer-embedding dict-emission/compressed-probe paths) + // a multi-source build can deliver a later chunk flat or as a different dictionary under the + // already-narrowed slot, so disqualify the whole join (see BuildSideHasMultipleSources) + if (build_side_multi_source) { + return false; + } + // SINGLE joins need FlatVector::SetNull on unmatched rows; dictionary vectors cannot supply it + if (ht.join_type == JoinType::SINGLE) { + return false; + } + // LEFT may dispatch into NextUniqueLeftJoin, which gathers payload via ScanStructure::GatherResult and + // bypasses GatherRHS' dict branch; the narrowed slot would be read as native type and trip the gather type check. + if (ht.join_type == JoinType::LEFT) { + return false; + } + // OUTER fills unmatched-probe rows with CONSTANT_NULL (NextLeftJoin), mixing dict chunks with flat fill chunks; + // admitting it would re-emit a falsely global dictionary a downstream consumer cannot trust. + if (ht.join_type == JoinType::OUTER) { + return false; + } + if (op.rhs_output_columns.col_types.empty()) { + return false; + } + // PHJ's FullScanHashTable reads payload from the row store at native width; a narrowed slot would corrupt it. + if (can_use_perfect_hash) { + return false; + } + return true; +} + +bool HashJoinGlobalSinkState::DictSurvivingActive() const { + if (!hash_table) { + return false; + } + for (const auto &entry : hash_table->dict_registry) { + if (entry) { + return true; + } + } + return false; +} + +void HashJoinGlobalSinkState::PublishLayoutIfFirst(HashJoinLocalSinkState &lstate, DataChunk &payload_chunk) { + if (layout_gate.published.load(std::memory_order_acquire)) { + return; + } + unique_lock guard(layout_gate.publish_mutex); + if (layout_gate.published.load(std::memory_order_relaxed)) { + return; + } + + const auto &cond_types = lstate.hash_table->condition_types; + const auto &build_types = lstate.hash_table->build_types; + layout_gate.dict_index_width.assign(build_types.size(), 0); + + if (CanUseDictSurvivingJoin(op, *lstate.hash_table, can_use_perfect_hash, build_side_multi_source)) { + // Per-column width decision lives on the JHT (GetDictSurvivingIndexWidth); feed it each arriving vector. + for (idx_t col = 0; col < build_types.size(); col++) { + if (col >= payload_chunk.ColumnCount()) { + continue; + } + layout_gate.dict_index_width[col] = + lstate.hash_table->GetDictSurvivingIndexWidth(col, payload_chunk.data[col]); + } + } + + auto layout = BuildJoinLayout(cond_types, build_types, lstate.hash_table->join_type, layout_gate.dict_index_width); + layout_gate.layout_ptr = layout; + + // global HT receives the same layout so Merge/Combine and finalize-time scans operate against it + if (hash_table && !hash_table->IsLayoutFinalized()) { + hash_table->FinishInitWithLayout(layout, layout_gate.dict_index_width); + } + + layout_gate.published.store(true, std::memory_order_release); +} + static bool ShouldPrepareBloomFilterBuild(const PhysicalHashJoin &op) { if (!op.filter_pushdown || op.filter_pushdown->probe_info.empty()) { return false; @@ -628,6 +803,16 @@ SinkResultType PhysicalHashJoin::Sink(ExecutionContext &context, DataChunk &chun lstate.payload_chunk.ReferenceColumns(chunk, payload_columns.col_idxs); } + // first-chunk: publish the canonical layout against the actually-arriving vector types + gstate.PublishLayoutIfFirst(lstate, lstate.payload_chunk); + + // lazy per-thread setup against the published layout + if (!lstate.append_state_initialised) { + lstate.hash_table->FinishInitWithLayout(gstate.layout_gate.layout_ptr, gstate.layout_gate.dict_index_width); + lstate.hash_table->GetSinkCollection().InitializeAppendState(lstate.append_state); + lstate.append_state_initialised = true; + } + // build the HT lstate.hash_table->Build(lstate.append_state, lstate.join_keys, lstate.payload_chunk); @@ -642,9 +827,16 @@ SinkCombineResultType PhysicalHashJoin::Combine(ExecutionContext &context, Opera auto &gstate = input.global_state.Cast(); auto &lstate = input.local_state.Cast(); - lstate.hash_table->GetSinkCollection().FlushAppendState(lstate.append_state); + // Under deferred layout, a thread that never received a Sink chunk has no sink_collection to flush + const bool has_layout = lstate.append_state_initialised; + if (has_layout) { + lstate.hash_table->GetSinkCollection().FlushAppendState(lstate.append_state); + } annotated_lock_guard guard(gstate.lock); - if (lstate.keep_hash_table) { + if (!has_layout) { + // nothing to merge — drop the empty thread-local hash table + gstate.active_local_states--; + } else if (lstate.keep_hash_table) { gstate.local_hash_tables.push_back(*lstate.hash_table); } else { gstate.owned_local_hash_tables.push_back(std::move(lstate.hash_table)); @@ -736,6 +928,24 @@ static idx_t GetPartitioningSpaceRequirement(ClientContext &context, const vecto void PhysicalHashJoin::PrepareFinalize(ClientContext &context, GlobalSinkState &global_state) const { auto &gstate = global_state.Cast(); + // If no Sink chunk ever arrived, the layout was never published. Fall back to a default layout + // (all columns at their native width) so finalize-time scans can dereference data_collection. + if (!gstate.layout_gate.published.load(std::memory_order_acquire)) { + unique_lock guard(gstate.layout_gate.publish_mutex); + if (!gstate.layout_gate.published.load(std::memory_order_relaxed)) { + const auto &cond_types = gstate.hash_table->condition_types; + const auto &build_types = gstate.hash_table->build_types; + gstate.layout_gate.dict_index_width.assign(build_types.size(), 0); + // all-zero dict_index_width => BuildJoinLayout keeps every build column at its native width + auto layout = BuildJoinLayout(cond_types, build_types, gstate.hash_table->join_type, + gstate.layout_gate.dict_index_width); + gstate.layout_gate.layout_ptr = layout; + if (!gstate.hash_table->IsLayoutFinalized()) { + gstate.hash_table->FinishInitWithLayout(layout); + } + gstate.layout_gate.published.store(true, std::memory_order_release); + } + } const auto &ht = *gstate.hash_table; gstate.total_size = @@ -1291,29 +1501,49 @@ static void CreateDynamicMinMaxFilter(const PhysicalComparisonJoin &op, const Jo } } -static unique_ptr CreateComparisonExpressionFilter(ExpressionType comparison_type, const Value &constant, - const LogicalType &column_type) { +static unique_ptr CreateComparisonExpressionFilter(ExpressionType comparison_type, + unique_ptr input, const Value &constant, + const LogicalType &comparison_logical_type) { auto constant_value = constant; if (!constant_value.IsNull()) { - constant_value.DefaultTryCastAs(column_type); + constant_value.DefaultTryCastAs(comparison_logical_type); } - auto column = make_uniq(column_type, 0ULL); - return BoundComparisonExpression::Create(comparison_type, std::move(column), + return BoundComparisonExpression::Create(comparison_type, std::move(input), make_uniq(std::move(constant_value))); } +static unique_ptr CreateComparisonExpressionFilter(ExpressionType comparison_type, const Value &constant, + const LogicalType &column_type) { + auto column = make_uniq(column_type, 0ULL); + return CreateComparisonExpressionFilter(comparison_type, std::move(column), constant, column_type); +} + +static unique_ptr +CreateJoinFilterComparisonExpression(ClientContext &context, const JoinFilterPushdownColumn &column, + ExpressionType comparison_type, const Value &constant, + const LogicalType &comparison_logical_type, bool reconstruct_expression) { + if (!reconstruct_expression) { + return CreateComparisonExpressionFilter(comparison_type, constant, comparison_logical_type); + } + auto input = CreateRuntimeFilterInputExpression(context, column, comparison_logical_type); + return CreateComparisonExpressionFilter(comparison_type, std::move(input), constant, comparison_logical_type); +} + static void CreateDynamicMinMaxFilters(const PhysicalComparisonJoin &op, const JoinFilterPushdownFilter &info, + ClientContext &context, const JoinFilterPushdownColumn &column, ProjectionIndex filter_col_idx, ExpressionType cmp, const Value &min_val, const Value &max_val, const LogicalType &condition_type, - bool selectivity_optional) { + bool reconstruct_expression, bool selectivity_optional) { + auto filter_column_type = reconstruct_expression ? column.storage_type : condition_type; switch (cmp) { case ExpressionType::COMPARE_EQUAL: case ExpressionType::COMPARE_GREATERTHAN: case ExpressionType::COMPARE_GREATERTHANOREQUALTO: { - CreateDynamicMinMaxFilter( - op, info, filter_col_idx, - CreateComparisonExpressionFilter(ExpressionType::COMPARE_GREATERTHANOREQUALTO, min_val, condition_type), - condition_type, selectivity_optional); + CreateDynamicMinMaxFilter(op, info, filter_col_idx, + CreateJoinFilterComparisonExpression(context, column, + ExpressionType::COMPARE_GREATERTHANOREQUALTO, + min_val, condition_type, reconstruct_expression), + filter_column_type, selectivity_optional); break; } default: @@ -1323,10 +1553,11 @@ static void CreateDynamicMinMaxFilters(const PhysicalComparisonJoin &op, const J case ExpressionType::COMPARE_EQUAL: case ExpressionType::COMPARE_LESSTHAN: case ExpressionType::COMPARE_LESSTHANOREQUALTO: { - CreateDynamicMinMaxFilter( - op, info, filter_col_idx, - CreateComparisonExpressionFilter(ExpressionType::COMPARE_LESSTHANOREQUALTO, max_val, condition_type), - condition_type, selectivity_optional); + CreateDynamicMinMaxFilter(op, info, filter_col_idx, + CreateJoinFilterComparisonExpression(context, column, + ExpressionType::COMPARE_LESSTHANOREQUALTO, + max_val, condition_type, reconstruct_expression), + filter_column_type, selectivity_optional); break; } default: @@ -1360,9 +1591,14 @@ unique_ptr JoinFilterPushdownInfo::FinalizeFilters(ClientContext &con auto min_val = min_val_before_cast; auto max_val = max_val_before_cast; + auto runtime_filter_input_type = GetRuntimeFilterInputType(pushdown_column, min_val_before_cast.type()); + const bool reconstruct_filter_expression = + pushdown_column.mode == JoinFilterPushdownMode::RECONSTRUCT_EXPRESSION && + pushdown_column.storage_type.id() == LogicalTypeId::VARIANT && + runtime_filter_input_type != pushdown_column.storage_type; // Cast to storage type, skip if fails - if (pushdown_column.storage_type.IsValid()) { + if (pushdown_column.storage_type.IsValid() && !reconstruct_filter_expression) { if (!min_val.DefaultTryCastAs(pushdown_column.storage_type)) { continue; } @@ -1379,7 +1615,7 @@ unique_ptr JoinFilterPushdownInfo::FinalizeFilters(ClientContext &con } auto condition_type = min_val.type(); - auto runtime_filter_input_type = GetRuntimeFilterInputType(pushdown_column, condition_type); + runtime_filter_input_type = GetRuntimeFilterInputType(pushdown_column, condition_type); bool can_emit_runtime_filters = pushdown_column.mode == JoinFilterPushdownMode::RECONSTRUCT_EXPRESSION; if (can_emit_runtime_filters && ht) { can_emit_runtime_filters = runtime_filter_input_type == ht->conditions[0].GetLHS().GetReturnType(); @@ -1391,12 +1627,14 @@ unique_ptr JoinFilterPushdownInfo::FinalizeFilters(ClientContext &con // Note that this also works for equalities. info.dynamic_filters->PushFilter( op, filter_col_idx, - make_uniq(CreateComparisonExpressionFilter(cmp, min_val, condition_type))); + make_uniq(CreateJoinFilterComparisonExpression( + context, pushdown_column, cmp, min_val, condition_type, reconstruct_filter_expression))); } else { if (cmp != ExpressionType::COMPARE_EQUAL) { // min != max - generate range filters for non-equality comparisons. // For non-equalities, the range must be half-open. - CreateDynamicMinMaxFilters(op, info, filter_col_idx, cmp, min_val, max_val, condition_type, true); + CreateDynamicMinMaxFilters(op, info, context, pushdown_column, filter_col_idx, cmp, min_val, + max_val, condition_type, reconstruct_filter_expression, true); continue; } @@ -1432,7 +1670,8 @@ unique_ptr JoinFilterPushdownInfo::FinalizeFilters(ClientContext &con } if (!pushed_in_filter) { - CreateDynamicMinMaxFilters(op, info, filter_col_idx, cmp, min_val, max_val, condition_type, false); + CreateDynamicMinMaxFilters(op, info, context, pushdown_column, filter_col_idx, cmp, min_val, + max_val, condition_type, reconstruct_filter_expression, false); } if (allow_bloom_filters && can_emit_runtime_filters && ht && CanUseBloomFilter(context, op, cmp, ht)) { PushBloomFilter(context, op, *ht, info, filter_idx, filter_col_idx); @@ -1546,6 +1785,11 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, Cl // check for possible perfect hash table auto use_perfect_hash = sink.perfect_join_executor->CanDoPerfectHashJoin(*this, min, max); + // PHJ's FullScanHashTable reads payload at native width; if any slot was narrowed it would crash. Runtime + // min/max from filter pushdown can re-enable PHJ here, so re-check after dict-surviving may have narrowed. + if (use_perfect_hash && sink.DictSurvivingActive()) { + use_perfect_hash = false; + } if (use_perfect_hash) { use_perfect_hash = sink.perfect_join_executor->BuildPerfectHashTable(); } diff --git a/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp b/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp index bf09e17ad..e2bfdca1f 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp @@ -1,9 +1,9 @@ #include "duckdb/execution/operator/persistent/physical_copy_to_file.hpp" -#include "duckdb/common/optional.hpp" #include "duckdb/common/file_opener.hpp" #include "duckdb/common/file_system.hpp" #include "duckdb/common/hive_partitioning.hpp" +#include "duckdb/common/optional.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/common/sorting/sort_strategy.hpp" #include "duckdb/common/types/column/column_data_collection_segment.hpp" @@ -12,16 +12,24 @@ #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/operator/logical_copy_to_file.hpp" #include "duckdb/parallel/base_pipeline_event.hpp" +#include "duckdb/parallel/task_executor.hpp" #include "duckdb/main/settings.hpp" #include "fmt/format.h" #include +#include +#include #include +#include namespace duckdb { //===--------------------------------------------------------------------===// -// Util +// Declarations +//===--------------------------------------------------------------------===// + +//===--------------------------------------------------------------------===// +// Utility Declarations //===--------------------------------------------------------------------===// enum class PhysicalCopyToFilePhase : uint8_t { SINK, COMBINE, FINALIZE }; @@ -67,81 +75,347 @@ struct VectorOfValuesLess { template using vector_of_value_map_t = unordered_map, T, VectorOfValuesHashFunction, VectorOfValuesEquality>; -void CheckDirectory(FileSystem &fs, const string &file_path, CopyOverwriteMode overwrite_mode) { - if (overwrite_mode == CopyOverwriteMode::COPY_OVERWRITE_OR_IGNORE || - overwrite_mode == CopyOverwriteMode::COPY_APPEND) { - // with overwrite or ignore we fully ignore the presence of any files instead of erasing them - return; +//===--------------------------------------------------------------------===// +// Copy File State Types +//===--------------------------------------------------------------------===// +struct GlobalFileState { +public: + explicit GlobalFileState(unique_ptr data_p, const string &path_p) + : data(std::move(data_p)), path(path_p), num_batches(0) { } - vector file_list; - vector directory_list; - directory_list.push_back(file_path); - for (idx_t dir_idx = 0; dir_idx < directory_list.size(); dir_idx++) { - auto directory = directory_list[dir_idx]; - fs.ListFiles(directory, [&](const string &path, bool is_directory) { - auto full_path = fs.JoinPath(directory, path); - if (is_directory) { - directory_list.emplace_back(std::move(full_path)); - } else { - file_list.emplace_back(std::move(full_path)); - } - }); + +public: + annotated_mutex lock; + unique_ptr data; + const string path; + idx_t num_batches DUCKDB_GUARDED_BY(lock); +}; + +//===--------------------------------------------------------------------===// +// Copy File State Declarations +//===--------------------------------------------------------------------===// +struct PendingFileState { + string output_path; + optional_ptr written_file_info; +}; + +struct PartitionDirectory { + string path; + vector directories; +}; + +enum class CopyDirectoryState : uint8_t { PENDING, COMPLETE, FAILED }; + +class CopyDirectoryManager { +public: + void EnsureDirectory(FileSystem &fs, const string &dir_path); + +private: + struct DirectoryEntry { + CopyDirectoryState state = CopyDirectoryState::PENDING; + std::exception_ptr error; + }; + +private: + mutex lock; + std::condition_variable condition; + unordered_map directories; +}; + +class CopyOutputFileRegistry { +public: + explicit CopyOutputFileRegistry(const PhysicalCopyToFile &op_p) : op(op_p) { } - if (file_list.empty()) { - return; + +public: + //! The registry is guarded by CopyToFileGlobalState::lock. + PendingFileState ReserveFile(string output_path, optional_ptr> partition_values); + void PublishCreatedPath(PendingFileState &pending_file_state, string output_path); + + idx_t WrittenFileCount() const { + return written_files.size(); } - if (overwrite_mode == CopyOverwriteMode::COPY_OVERWRITE) { - fs.RemoveFiles(file_list); - } else { - throw IOException("Directory \"%s\" is not empty! Enable OVERWRITE option to overwrite files", file_path); + + CopyToFileInfo &GetWrittenFile(idx_t file_idx) { + return *written_files[file_idx]; } -} -struct PhysicalCopyToFileColumnStatsMapData { - vector keys; - vector values; + const vector> &GetWrittenFiles() const { + return written_files; + } + + bool HasCreatedFiles() const { + return !created_files.empty(); + } + + const vector &GetCreatedFiles() const { + return created_files; + } + +private: + optional_ptr AddFile(const string &file_name); + +private: + const PhysicalCopyToFile &op; + vector created_files; + vector> written_files; }; -static PhysicalCopyToFileColumnStatsMapData -CreateColumnStatistics(const case_insensitive_map_t> &column_statistics) { - PhysicalCopyToFileColumnStatsMapData result; +//===--------------------------------------------------------------------===// +// Copy File Lifecycle Declarations +//===--------------------------------------------------------------------===// +enum class CopyFileLifecycleWaitMode : uint8_t { INTERRUPTIBLE, DRAIN }; - //! Use a map to make sure the result has a consistent ordering - map stats; - for (auto &entry : column_statistics) { - map per_column_stats; - for (auto &stats_entry : entry.second) { - per_column_stats.emplace(stats_entry.first, stats_entry.second); +class CopyFileLifecycleJob { +public: + bool IsFinished() const { + return finished.load(std::memory_order_acquire); + } + + void Complete() { + finished.store(true, std::memory_order_release); + } + + void CompleteException(const std::exception_ptr &error_p) { + error = error_p; + Complete(); + } + + void Rethrow() const { + if (error) { + std::rethrow_exception(error); } - vector stats_keys; - vector stats_values; - for (auto &stats_entry : per_column_stats) { - stats_keys.emplace_back(stats_entry.first); - stats_values.emplace_back(std::move(stats_entry.second)); + } + +private: + atomic finished {false}; + std::exception_ptr error; +}; + +class FileStateOpenJob : public CopyFileLifecycleJob { +public: + void Complete(unique_ptr file_state_p) { + file_state = std::move(file_state_p); + CopyFileLifecycleJob::Complete(); + } + + GlobalFileState &GetFileState() const { + D_ASSERT(IsFinished()); + Rethrow(); + D_ASSERT(file_state); + return *file_state; + } + + unique_ptr TakeFileState() { + D_ASSERT(IsFinished()); + Rethrow(); + return std::move(file_state); + } + +private: + unique_ptr file_state; +}; + +struct PendingFileStateOpen { + PendingFileState pending_file_state; + shared_ptr open_job; + + explicit operator bool() const { + return open_job.get(); + } +}; + +class CopyToFileGlobalState; + +struct PartitionFileOpenRequest { + PartitionFileOpenRequest(PendingFileStateOpen pending_file_state_open, PartitionDirectory directory_p, + idx_t offset_p) + : pending_file_state(std::move(pending_file_state_open.pending_file_state)), + open_job(std::move(pending_file_state_open.open_job)), directory(std::move(directory_p)), offset(offset_p) { + } + + void Run(CopyToFileGlobalState ©_gstate); + + PendingFileState pending_file_state; + shared_ptr open_job; + PartitionDirectory directory; + idx_t offset; +}; + +struct FileStateHandle { +public: + FileStateHandle() = default; + FileStateHandle(FileStateHandle &&) = default; + FileStateHandle &operator=(FileStateHandle &&) = default; + FileStateHandle(const FileStateHandle &) = delete; + FileStateHandle &operator=(const FileStateHandle &) = delete; + +public: + bool HasFileState() const { + return open_job.get(); + } + + bool IsReady() const { + return open_job && open_job->IsFinished(); + } + + GlobalFileState &GetFileState() const { + D_ASSERT(open_job); + return open_job->GetFileState(); + } + + optional_ptr GetFileStatePtr() const { + if (!IsReady()) { + return nullptr; } - auto map_value = - Value::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR, std::move(stats_keys), std::move(stats_values)); - stats.emplace(entry.first, std::move(map_value)); + return open_job->GetFileState(); } - for (auto &entry : stats) { - result.keys.emplace_back(entry.first); - result.values.emplace_back(std::move(entry.second)); + + unique_ptr TakeFileState() { + if (!open_job) { + return nullptr; + } + auto result = open_job->TakeFileState(); + open_job.reset(); + return result; } - return result; -} -struct GlobalFileState { - explicit GlobalFileState(unique_ptr data_p, const string &path_p) - : data(std::move(data_p)), path(path_p), num_batches(0) { + explicit operator bool() const { + return HasFileState(); } - annotated_mutex lock; - unique_ptr data; - const string path; - idx_t num_batches DUCKDB_GUARDED_BY(lock); + +public: + shared_ptr open_job; }; -static bool PhysicalCopyRotateNow(const PhysicalCopyToFile &op, GlobalFileState &global_state) - DUCKDB_REQUIRES(global_state.lock); +struct PartitionFileRequest { + PartitionFileRequest(PartitionFileOpenRequest open_request_p, vector files_to_finalize_p) + : open_request(std::move(open_request_p)), files_to_finalize(std::move(files_to_finalize_p)) { + } + + shared_ptr OpenJob() const { + return open_request.open_job; + } + + PartitionFileOpenRequest open_request; + vector files_to_finalize; +}; + +template +class CopyFileLifecycleTask; + +class CopyFileLifecycleExecutor { + static constexpr idx_t MIN_PENDING_TASKS = 4096; + +public: + explicit CopyFileLifecycleExecutor(ClientContext &context_p) + : context(context_p), executor(context_p, TaskSchedulerType::ASYNC) { + auto &scheduler = TaskScheduler::GetScheduler(context); + async_threads = NumericCast(scheduler.NumberOfAsyncThreads()); + auto regular_threads = NumericCast(scheduler.NumberOfThreads()); + max_pending_tasks = MaxValue(MIN_PENDING_TASKS, (async_threads + regular_threads) * 4); + } + +public: + template + void Schedule(shared_ptr job, CopyFileLifecycleWaitMode mode, FUNC &&task); + void WaitForJob(CopyFileLifecycleJob &job, CopyFileLifecycleWaitMode mode); + void WaitAll(CopyFileLifecycleWaitMode mode); + void WorkOnTaskOrYield(); + void FinishTask(); + void PushError(const std::exception_ptr &error); + +private: + bool WorkOnTask(bool throw_error = true); + void WaitForTaskSlot(CopyFileLifecycleWaitMode mode); + void ThrowError(); + +private: + ClientContext &context; + TaskExecutor executor; + idx_t async_threads; + idx_t max_pending_tasks; + atomic pending_tasks {0}; + mutex error_lock; + std::exception_ptr error; +}; + +class CopyFileLifecycleTaskFinishGuard { +public: + CopyFileLifecycleTaskFinishGuard(TaskExecutor &executor_p, CopyFileLifecycleExecutor &lifecycle_p) + : executor(executor_p), lifecycle(lifecycle_p) { + } + + ~CopyFileLifecycleTaskFinishGuard() { + Finish(); + } + + void Finish() { + if (!finished) { + lifecycle.FinishTask(); + executor.FinishTask(); + finished = true; + } + } + +private: + TaskExecutor &executor; + CopyFileLifecycleExecutor &lifecycle; + bool finished = false; +}; + +template +class CopyFileLifecycleTask : public Task { +public: + CopyFileLifecycleTask(TaskExecutor &executor_p, CopyFileLifecycleExecutor &lifecycle_p, + shared_ptr job_p, FUNC task_p) + : executor(executor_p), lifecycle(lifecycle_p), job(std::move(job_p)), task(std::move(task_p)) { + } + +public: + TaskExecutionResult Execute(TaskExecutionMode mode) override { + CopyFileLifecycleTaskFinishGuard finish_guard(executor, lifecycle); + try { + task(); + if (!job->IsFinished()) { + job->Complete(); + } + } catch (...) { + auto error = std::current_exception(); + job->CompleteException(error); + lifecycle.PushError(error); + } + return TaskExecutionResult::TASK_FINISHED; + } + + string TaskType() const override { + return "CopyFileLifecycleTask"; + } + +private: + TaskExecutor &executor; + CopyFileLifecycleExecutor &lifecycle; + shared_ptr job; + FUNC task; +}; + +template +void CopyFileLifecycleExecutor::Schedule(shared_ptr job, CopyFileLifecycleWaitMode mode, + FUNC &&task) { + WaitForTaskSlot(mode); + auto job_ref = job; + ++pending_tasks; + try { + using TaskType = CopyFileLifecycleTask::type>; + executor.ScheduleTask(make_uniq(executor, *this, std::move(job), std::forward(task))); + } catch (...) { + --pending_tasks; + throw; + } + if (async_threads == 0) { + WaitForJob(*job_ref, mode); + } +} //===--------------------------------------------------------------------===// // Copy State Declarations @@ -154,23 +428,42 @@ class CopyToFileGlobalState : public GlobalSinkState { ~CopyToFileGlobalState() override; public: - void Initialize(); + void Initialize() DUCKDB_EXCLUDES(lock); - void CreateDir(const string &dir_path) DUCKDB_REQUIRES(lock); - unique_ptr CreateFileStateLocked(string output_path = string(), - optional_ptr> partition_values = nullptr) + PendingFileState PrepareFileStateLocked(string output_path = string(), + optional_ptr> partition_values = nullptr) + DUCKDB_REQUIRES(lock); + PendingFileStateOpen CreateFileStateOpenLocked(FileStateHandle &file_state, string output_path = string(), + optional_ptr> partition_values = nullptr) DUCKDB_REQUIRES(lock); - unique_ptr CreateFileState(string output_path = string(), - optional_ptr> partition_values = nullptr) + PendingFileStateOpen CreatePartitionFileStateOpenLocked(FileStateHandle &file_state, string output_path, + optional_ptr> partition_values) + DUCKDB_REQUIRES(lock); + unique_ptr InitializeFileState(PendingFileState pending_file_state) DUCKDB_EXCLUDES(lock); + void RegisterPrepareGlobalStateLocked(GlobalFileState &file_state) DUCKDB_REQUIRES(lock); + void ScheduleOutputDirectorySetup() DUCKDB_EXCLUDES(lock); + void EnsureOutputDirectoryReady() DUCKDB_EXCLUDES(lock); + void ScheduleFileStateOpen(PendingFileStateOpen pending_file_state_open) DUCKDB_EXCLUDES(lock); + void SchedulePartitionFileStateOpen(PartitionFileOpenRequest request) DUCKDB_EXCLUDES(lock); + void RequestFileState(FileStateHandle &file_state, string output_path = string(), + optional_ptr> partition_values = nullptr) DUCKDB_EXCLUDES(lock); + GlobalFileState &EnsureFileStateReady(FileStateHandle &file_state, + const std::function &create_file_state_fun) DUCKDB_EXCLUDES(lock); - unique_ptr FinalizeFileStateLocked(unique_ptr file_state) DUCKDB_REQUIRES(lock); - void FinalizeFileState(unique_ptr file_state) DUCKDB_EXCLUDES(lock); + FileStateHandle FinalizeFileStateLocked(FileStateHandle file_state) DUCKDB_REQUIRES(lock); + void FinalizeFileState(FileStateHandle file_state) DUCKDB_EXCLUDES(lock); - unique_ptr TryFinalizeOwnedFileStateLocked() DUCKDB_REQUIRES(lock); + FileStateHandle TryFinalizeOwnedFileStateLocked() DUCKDB_REQUIRES(lock); void TryFinalizeOwnedFileState() DUCKDB_EXCLUDES(lock); + void WaitForLifecycleTasks() DUCKDB_EXCLUDES(lock); private: - optional_ptr AddFile(const string &file_name) DUCKDB_REQUIRES(lock); + void PrepareOutputDirectory() DUCKDB_EXCLUDES(lock); + void EnsureDirectory(const string &dir_path) DUCKDB_EXCLUDES(lock); + void RegisterPendingFileStatePathLocked(PendingFileState &pending_file_state, string output_path) + DUCKDB_REQUIRES(lock); + + friend struct PartitionFileOpenRequest; public: const PhysicalCopyToFile &op; @@ -188,13 +481,17 @@ class CopyToFileGlobalState : public GlobalSinkState { //! Therefore, we must delay deciding which file to flush; otherwise, parallel writes overshoot //! All Prepare are done against this state atomic> prepare_global_state; - unique_ptr prepare_global_state_owned DUCKDB_GUARDED_BY(lock); + FileStateHandle prepare_global_state_owned DUCKDB_GUARDED_BY(lock); //! The (current) global state - unique_ptr global_state; + FileStateHandle global_state; //! Lambda to create a new global file state - const std::function()> create_file_state_fun; - unordered_set *> creating_file_states DUCKDB_GUARDED_BY(lock); + const std::function create_file_state_fun; + //! Asynchronously prepares the root output directory for directory-style COPY outputs. + shared_ptr output_directory_job; + CopyFileLifecycleExecutor lifecycle_executor; + CopyDirectoryManager directory_manager; + CopyOutputFileRegistry output_files; //! The final batch mutable annotated_mutex last_batch_lock; @@ -203,20 +500,13 @@ class CopyToFileGlobalState : public GlobalSinkState { //! Partitioning state unique_ptr partitioned_copy; - //! Created directories - unordered_set created_directories DUCKDB_GUARDED_BY(lock); - //! The list of files created by this operator - vector created_files DUCKDB_GUARDED_BY(lock); - //! Written file info and stats - vector> written_files DUCKDB_GUARDED_BY(lock); - //! Counters atomic rows_copied; atomic last_file_offset; }; //===--------------------------------------------------------------------===// -// Copy Local State Declaration +// Copy Local State Declarations //===--------------------------------------------------------------------===// class PartitionedCopyLocalState; @@ -230,7 +520,7 @@ class CopyToFileLocalState : public LocalSinkState { CopyToFileGlobalState &gstate; //! Global/local file state (unpartitioned write) - unique_ptr global_file_state; + FileStateHandle global_file_state; unique_ptr local_file_state; //! Current append batch (unpartitioned write) @@ -245,7 +535,7 @@ class CopyToFileLocalState : public LocalSinkState { }; //===--------------------------------------------------------------------===// -// Partitioned Copy Declarations +// Partitioned Copy Type Declarations //===--------------------------------------------------------------------===// enum class PartitionedCopyStage : uint8_t { SORT, MATERIALIZE, MASK, BATCH, PREPARE, FLUSH, DONE }; enum class FileCreationReason : uint8_t { NORMAL, SORTED_RUN_BOUNDARY, ROTATION }; @@ -273,16 +563,111 @@ struct PartitionedCopyBatch { struct PartitionWriteInfo { //! Serializes operations that need a complete partition writer run boundary. annotated_mutex lock; - unique_ptr file_state; + FileStateHandle file_state; idx_t active_writes = 0; }; struct PartitionFileStateReservation { - vector> files_to_finalize; + vector files_to_finalize; idx_t offset = 0; }; +//===--------------------------------------------------------------------===// +// Partition Write Manager Declarations +//===--------------------------------------------------------------------===// +class PartitionWriteManager; + +class PartitionWriteLease { +public: + PartitionWriteLease() = default; + PartitionWriteLease(PartitionWriteManager &manager_p, PartitionWriteInfo &write_info_p); + PartitionWriteLease(PartitionWriteLease &&other) noexcept; + PartitionWriteLease &operator=(PartitionWriteLease &&other) noexcept; + PartitionWriteLease(const PartitionWriteLease &) = delete; + PartitionWriteLease &operator=(const PartitionWriteLease &) = delete; + ~PartitionWriteLease(); + +public: + explicit operator bool() const { + return write_info.get(); + } + + PartitionWriteInfo &Get() const { + D_ASSERT(write_info); + return *write_info.get_mutable(); + } + + PartitionWriteInfo &operator*() const { + return Get(); + } + + PartitionWriteInfo *operator->() const { + return &Get(); + } + + void Reset(); + +private: + optional_ptr manager; + optional_ptr write_info; +}; + +class PartitionWriteManager { +public: + class ReservationLock { + public: + ReservationLock(ReservationLock &&) noexcept = default; + ReservationLock &operator=(ReservationLock &&) noexcept = default; + ReservationLock(const ReservationLock &) = delete; + ReservationLock &operator=(const ReservationLock &) = delete; + + private: + explicit ReservationLock(annotated_mutex &lock_p) : guard(lock_p) { + } + + private: + annotated_unique_lock guard; + friend class PartitionWriteManager; + }; + +public: + PartitionWriteManager(const PhysicalCopyToFile &op_p, ClientContext &context_p) : op(op_p), context(context_p) { + } + +public: + PartitionWriteLease Acquire(const vector &values) DUCKDB_EXCLUDES(lock); + ReservationLock LockForReservation() DUCKDB_EXCLUDES(lock); + PartitionFileStateReservation ReserveFileState(ReservationLock &reservation_lock, const vector &values, + FileCreationReason reason) DUCKDB_NO_THREAD_SAFETY_ANALYSIS; + vector TakeOpenFileStates() DUCKDB_EXCLUDES(lock); + +private: + void Release(PartitionWriteInfo &write_info) DUCKDB_EXCLUDES(lock); + +private: + const PhysicalCopyToFile &op; + ClientContext &context; + mutable annotated_mutex lock; + vector_of_value_map_t> active_writes DUCKDB_GUARDED_BY(lock); + vector_of_value_map_t previous_partitions DUCKDB_GUARDED_BY(lock); + idx_t global_offset DUCKDB_GUARDED_BY(lock) = 0; + + friend class PartitionWriteLease; +}; + +//===--------------------------------------------------------------------===// +// Partitioned Copy Batch State Declarations +//===--------------------------------------------------------------------===// enum class PartitionedCopyBatchMode : uint8_t { BUFFERING, PREPARING, DELAYED, PREPARED }; +enum class PartitionedCopyBatchActionType : uint8_t { STORE_COLLECTION, ACQUIRE_AND_PREPARE, PREPARE_WITH_WRITE_INFO }; +enum class PartitionedCopyPrepareTaskActionType : uint8_t { + SKIP_PARTITION, + WAIT_FOR_WRITE_INFO, + ACQUIRE_WRITE_INFO, + PREPARE_BATCH +}; +enum class PartitionedCopyPrepareActionType : uint8_t { ACQUIRE_WRITE_INFO, PREPARE_BATCH }; +enum class PartitionedCopyFlushActionType : uint8_t { DELAYED_COLLECTIONS, PREPARED_BATCHES }; enum class PartitionedCopyCollectionSchema : uint8_t { RAW_SCHEMA, WRITE_SCHEMA }; @@ -302,32 +687,39 @@ struct PartitionedCopyCollection { unique_ptr collection; }; -struct PartitionedCopyBatchState { - void SetValues(vector values_p) { - D_ASSERT(values.empty() || values == values_p); - if (values.empty()) { - values = std::move(values_p); - } - } +struct PartitionedCopyBatchAction { + PartitionedCopyBatchActionType type = PartitionedCopyBatchActionType::STORE_COLLECTION; + vector values; + optional_ptr write_info; +}; - idx_t AddCollectionSlot(PartitionedCopyCollectionSchema schema, idx_t row_count) { - D_ASSERT(mode == PartitionedCopyBatchMode::BUFFERING || mode == PartitionedCopyBatchMode::PREPARING); - collections.emplace_back(schema); - if (mode == PartitionedCopyBatchMode::PREPARING && batches.size() < collections.size()) { - batches.emplace_back(); - } - count += row_count; - return collections.size() - 1; - } +struct PartitionedCopyPrepareTaskAction { + PartitionedCopyPrepareTaskActionType type = PartitionedCopyPrepareTaskActionType::SKIP_PARTITION; + idx_t batch_idx = DConstants::INVALID_INDEX; +}; - void StoreCollection(idx_t batch_idx, unique_ptr collection) { - D_ASSERT(mode == PartitionedCopyBatchMode::BUFFERING || mode == PartitionedCopyBatchMode::PREPARING); - D_ASSERT(batch_idx < collections.size()); - collections[batch_idx].collection = std::move(collection); - } +struct PartitionedCopyPrepareAction { + PartitionedCopyPrepareActionType type = PartitionedCopyPrepareActionType::ACQUIRE_WRITE_INFO; + vector values; + optional_ptr write_info; + PartitionedCopyCollection data; +}; + +struct PartitionedCopyFlushAction { + PartitionedCopyFlushActionType type = PartitionedCopyFlushActionType::DELAYED_COLLECTIONS; + vector values; + vector collections; + vector> batches; + PartitionWriteLease write_lease; +}; - bool CanStartPreparing(idx_t flush_threshold, bool has_delayed_partition) const { - return mode == PartitionedCopyBatchMode::BUFFERING && count >= flush_threshold && !has_delayed_partition; +struct PartitionedCopyBatchState { +private: + void SetValues(vector values_p) { + D_ASSERT(values.empty() || values == values_p); + if (values.empty()) { + values = std::move(values_p); + } } void StartPreparing() { @@ -338,24 +730,26 @@ struct PartitionedCopyBatchState { bool TryReserveWriteInfo() { D_ASSERT(mode == PartitionedCopyBatchMode::PREPARING); - if (write_info || write_info_requested) { + if (write_lease || write_info_requested) { return false; } write_info_requested = true; return true; } - void SetWriteInfo(PartitionWriteInfo &write_info_p) { + PartitionWriteInfo &SetWriteInfo(PartitionWriteLease write_lease_p) { D_ASSERT(mode == PartitionedCopyBatchMode::PREPARING); - D_ASSERT(!write_info || write_info == optional_ptr(write_info_p)); - write_info = write_info_p; + D_ASSERT(!write_lease); + D_ASSERT(write_lease_p); + write_lease = std::move(write_lease_p); write_info_requested = false; batches.resize(collections.size()); + return write_lease.Get(); } void EnsurePreparingBatchSlots() { D_ASSERT(mode == PartitionedCopyBatchMode::PREPARING); - D_ASSERT(write_info); + D_ASSERT(write_lease); batches.resize(collections.size()); } @@ -364,43 +758,188 @@ struct PartitionedCopyBatchState { mode = PartitionedCopyBatchMode::DELAYED; } - void MarkPrepared() { - if (mode == PartitionedCopyBatchMode::PREPARING) { - mode = PartitionedCopyBatchMode::PREPARED; + bool NeedsWriteInfo() const { + return mode == PartitionedCopyBatchMode::PREPARING && !write_lease; + } + + bool NeedsPrepare(idx_t batch_idx) const { + D_ASSERT(mode == PartitionedCopyBatchMode::PREPARING); + D_ASSERT(write_lease); + D_ASSERT(batch_idx < collections.size()); + return collections[batch_idx].collection && (batch_idx >= batches.size() || !batches[batch_idx]); + } + + bool HasWriteInfo() const { + return write_lease.operator bool(); + } + + PartitionWriteInfo &GetWriteInfo() const { + return write_lease.Get(); + } + + PartitionWriteLease TakeWriteLease() { + D_ASSERT(write_lease); + return std::move(write_lease); + } + +public: + idx_t NextCollectionIndex() const { + return collections.size(); + } + + const vector &Values() const { + return values; + } + + idx_t AddCollectionSlot(PartitionedCopyCollectionSchema schema, idx_t row_count) { + D_ASSERT(mode == PartitionedCopyBatchMode::BUFFERING || mode == PartitionedCopyBatchMode::PREPARING); + collections.emplace_back(schema); + if (mode == PartitionedCopyBatchMode::PREPARING && batches.size() < collections.size()) { + batches.emplace_back(); + } + count += row_count; + return collections.size() - 1; + } + + PartitionedCopyBatchAction RegisterBatch(vector values_p, idx_t flush_threshold, + bool has_delayed_partition) { + SetValues(std::move(values_p)); + auto result = PartitionedCopyBatchAction {PartitionedCopyBatchActionType::STORE_COLLECTION, values, nullptr}; + if (mode == PartitionedCopyBatchMode::BUFFERING && count >= flush_threshold && !has_delayed_partition) { + StartPreparing(); + if (TryReserveWriteInfo()) { + result.type = PartitionedCopyBatchActionType::ACQUIRE_AND_PREPARE; + return result; + } + } + + if (mode == PartitionedCopyBatchMode::PREPARING && write_lease) { + result.type = PartitionedCopyBatchActionType::PREPARE_WITH_WRITE_INFO; + result.write_info = GetWriteInfo(); } + return result; } - bool SkipsPrepare() const { - return mode == PartitionedCopyBatchMode::DELAYED || mode == PartitionedCopyBatchMode::PREPARED; + PartitionWriteInfo &CompleteWriteInfoAcquisition(PartitionWriteLease write_lease_p) { + D_ASSERT(write_info_requested); + return SetWriteInfo(std::move(write_lease_p)); } - bool NeedsWriteInfo() const { - return mode == PartitionedCopyBatchMode::PREPARING && !write_info; + void StoreCollection(idx_t batch_idx, unique_ptr collection) { + D_ASSERT(mode == PartitionedCopyBatchMode::BUFFERING || mode == PartitionedCopyBatchMode::PREPARING); + D_ASSERT(batch_idx < collections.size()); + collections[batch_idx].collection = std::move(collection); } - bool IsWriteInfoReserved() const { - return write_info_requested; + idx_t FinalizeBatching(idx_t flush_threshold, bool has_delayed_partition) { + D_ASSERT(!values.empty()); + D_ASSERT(count > 0); + if (mode == PartitionedCopyBatchMode::PREPARING) { + D_ASSERT(HasWriteInfo()); + EnsurePreparingBatchSlots(); + return 0; + } + D_ASSERT(mode == PartitionedCopyBatchMode::BUFFERING); + if (count < flush_threshold || has_delayed_partition) { + MarkDelayed(); + return count; + } + StartPreparing(); + return 0; } - bool NeedsPrepare(idx_t batch_idx) const { + PartitionedCopyPrepareTaskAction SelectPrepareTask(idx_t &prepare_batch_idx) { + if (mode == PartitionedCopyBatchMode::DELAYED || mode == PartitionedCopyBatchMode::PREPARED) { + return {PartitionedCopyPrepareTaskActionType::SKIP_PARTITION}; + } D_ASSERT(mode == PartitionedCopyBatchMode::PREPARING); - D_ASSERT(write_info); + if (NeedsWriteInfo()) { + if (!write_info_requested && TryReserveWriteInfo()) { + return {PartitionedCopyPrepareTaskActionType::ACQUIRE_WRITE_INFO}; + } + return {PartitionedCopyPrepareTaskActionType::WAIT_FOR_WRITE_INFO}; + } + while (prepare_batch_idx < collections.size()) { + if (NeedsPrepare(prepare_batch_idx)) { + return {PartitionedCopyPrepareTaskActionType::PREPARE_BATCH, prepare_batch_idx++}; + } + ++prepare_batch_idx; + } + return {PartitionedCopyPrepareTaskActionType::SKIP_PARTITION}; + } + + PartitionedCopyPrepareAction BeginPrepareTask(idx_t batch_idx) { + D_ASSERT(mode == PartitionedCopyBatchMode::PREPARING); + PartitionedCopyPrepareAction result; + result.values = values; + if (batch_idx == DConstants::INVALID_INDEX) { + D_ASSERT(NeedsWriteInfo()); + D_ASSERT(write_info_requested); + result.type = PartitionedCopyPrepareActionType::ACQUIRE_WRITE_INFO; + return result; + } + D_ASSERT(HasWriteInfo()); D_ASSERT(batch_idx < collections.size()); - return collections[batch_idx].collection && (batch_idx >= batches.size() || !batches[batch_idx]); + result.type = PartitionedCopyPrepareActionType::PREPARE_BATCH; + result.write_info = GetWriteInfo(); + result.data = TakeCollection(batch_idx); + return result; } +private: PartitionedCopyCollection TakeCollection(idx_t batch_idx) { D_ASSERT(mode == PartitionedCopyBatchMode::PREPARING); D_ASSERT(batch_idx < collections.size()); return std::move(collections[batch_idx]); } +public: void StorePreparedBatch(idx_t batch_idx, unique_ptr batch) { D_ASSERT(mode == PartitionedCopyBatchMode::PREPARING || mode == PartitionedCopyBatchMode::PREPARED); D_ASSERT(batch_idx < batches.size()); batches[batch_idx] = std::move(batch); } + void MarkPrepared() { + if (mode == PartitionedCopyBatchMode::PREPARING) { + mode = PartitionedCopyBatchMode::PREPARED; + } + } + + bool ReadyForFlush() const { + if (mode == PartitionedCopyBatchMode::DELAYED) { + return true; + } + if (mode != PartitionedCopyBatchMode::PREPARING && mode != PartitionedCopyBatchMode::PREPARED) { + return false; + } + if (!HasWriteInfo() || batches.size() < collections.size()) { + return false; + } + for (idx_t batch_idx = 0; batch_idx < collections.size(); batch_idx++) { + if (!batches[batch_idx]) { + return false; + } + } + return true; + } + + PartitionedCopyFlushAction TakeFlushAction() { + PartitionedCopyFlushAction result; + result.values = values; + if (mode == PartitionedCopyBatchMode::DELAYED) { + result.type = PartitionedCopyFlushActionType::DELAYED_COLLECTIONS; + result.collections = TakeDelayedCollections(); + return result; + } + D_ASSERT(mode == PartitionedCopyBatchMode::PREPARED); + result.type = PartitionedCopyFlushActionType::PREPARED_BATCHES; + result.write_lease = TakeWriteLease(); + result.batches = TakePreparedBatches(); + return result; + } + +private: vector TakeDelayedCollections() { D_ASSERT(mode == PartitionedCopyBatchMode::DELAYED); return std::move(collections); @@ -411,8 +950,9 @@ struct PartitionedCopyBatchState { return std::move(batches); } +private: vector values; - optional_ptr write_info; + PartitionWriteLease write_lease; vector collections; vector> batches; idx_t count = 0; @@ -420,6 +960,9 @@ struct PartitionedCopyBatchState { PartitionedCopyBatchMode mode = PartitionedCopyBatchMode::BUFFERING; }; +//===--------------------------------------------------------------------===// +// Delayed Partition Buffering Declarations +//===--------------------------------------------------------------------===// struct DelayedPartitionFlush { vector values; PartitionedCopyCollection data; @@ -566,6 +1109,9 @@ class DelayedPartitionBuffers { vector_of_value_map_t partitions DUCKDB_GUARDED_BY(lock); }; +//===--------------------------------------------------------------------===// +// Partitioned Copy Declarations +//===--------------------------------------------------------------------===// //! Manages a single partitioned COPY hash bin class PartitionedCopyHashGroup { public: @@ -592,6 +1138,7 @@ class PartitionedCopyHashGroup { optional TryNextBatchTask() DUCKDB_REQUIRES(lock); optional TryNextPrepareTask() DUCKDB_REQUIRES(lock); optional TryNextFlushTask() DUCKDB_REQUIRES(lock); + bool PreparedBatchesReady() const DUCKDB_REQUIRES(lock); void Sort(ExecutionContext &context, GlobalSinkState &sink, InterruptState &interrupt, const PartitionedCopyTask &task); @@ -727,11 +1274,11 @@ class PartitionedCopy { void Finalize(Pipeline &pipeline, Event &event, InterruptState &interrupt_state); void Flush(ExecutionContext &execution_context, InterruptState &interrupt_state); +public: + //! Partitioning-specific functions void InitializeFlush() DUCKDB_REQUIRES(lock); void FinalizeState(PartitionedCopyState &state, InterruptState &interrupt_state) DUCKDB_REQUIRES(state.lock); - PartitionWriteInfo &GetPartitionWriteInfo(const vector &values) DUCKDB_EXCLUDES(active_writes_lock); - void ReleasePartitionWriteInfo(PartitionWriteInfo &write_info) DUCKDB_EXCLUDES(active_writes_lock); bool HasDelayedPartition(const vector &values); bool HasDelayedPartitions() const; optional BufferOrTakeReadyPartition(const vector &values, @@ -754,118 +1301,378 @@ class PartitionedCopy { unique_ptr batch); void FlushPartitionCollection(ExecutionContext &execution_context, InterruptState &interrupt_state, DelayedPartitionFlush flush); - unique_ptr CreatePartitionFileState(const vector &values, - FileCreationReason reason = FileCreationReason::NORMAL) - DUCKDB_EXCLUDES(active_writes_lock); - PartitionFileStateReservation - ReservePartitionFileStateLocked(const vector &values, FileCreationReason reason = FileCreationReason::NORMAL) - DUCKDB_REQUIRES(active_writes_lock); - unique_ptr CreatePartitionFileStateFromReservation(const vector &values, idx_t offset) - DUCKDB_EXCLUDES(active_writes_lock); - void FinalizeActiveWrites() DUCKDB_EXCLUDES(active_writes_lock); - void FinalizeFileStates(vector> files_to_finalize) DUCKDB_EXCLUDES(active_writes_lock); - string GetOrCreateDirectory(string path, const vector &values) DUCKDB_REQUIRES(copy_gstate.lock); + void RequestPartitionFileState(FileStateHandle &file_state, const vector &values, + FileCreationReason reason = FileCreationReason::NORMAL) + DUCKDB_EXCLUDES(copy_gstate.lock); + void FinalizeActiveWrites() DUCKDB_EXCLUDES(copy_gstate.lock); + void FinalizeFileStates(vector files_to_finalize) DUCKDB_EXCLUDES(copy_gstate.lock); + +private: + unique_ptr ConstructSortStrategy() const; + void CreateNextState(); + bool ShouldStopFlushing() const; + bool RequiresSerializedPartitionWrites() const; + void EnsureFreshPartitionFileForSortedRun(PartitionWriteInfo &write_info, const vector &values) + DUCKDB_EXCLUDES(copy_gstate.lock); + void EnsureFreshPartitionFileForRotation(PartitionWriteInfo &write_info, const vector &values) + DUCKDB_EXCLUDES(copy_gstate.lock); + //! Swaps write_info.file_state after temporarily dropping copy_gstate.lock to request the replacement file. + //! Callers that can reach the swap path must serialize the full partition writer run for this write_info. + void EnsureFreshPartitionFile(PartitionWriteInfo &write_info, const vector &values, + FileCreationReason reason) DUCKDB_EXCLUDES(copy_gstate.lock); + template + void WithSerializedPartitionWriteRun(PartitionWriteInfo &write_info, FUNC &&func) { + annotated_unique_lock run_guard(write_info.lock, std::defer_lock); + if (RequiresSerializedPartitionWrites()) { + run_guard.lock(); + } + func(); + } + void FlushDelayedPartitionRun(const vector &values, PartitionWriteInfo &write_info, + ColumnDataCollection &collection); + +public: + const PhysicalCopyToFile &op; + ClientContext &context; + CopyToFileGlobalState ©_gstate; + PartitionWriteManager partition_writes; + + //! Which columns/types to write to the file + vector write_columns; + vector write_types; + vector raw_columns; + + //! Partition/sort strategy with PhysicalOperator-like interface + const unique_ptr sort_strategy; + + //! Lock for managing states (sinking_state, flushing_state, flushing flag) + mutable annotated_mutex lock; + //! Whether a flushing state currently exists + atomic flushing; + //! How many threads are active + atomic locals; + //! How many threads did a combine + atomic combined; + //! Whether Finalize has been called + atomic finalized; + + //! Current sink and combine states + shared_ptr sinking_state DUCKDB_GUARDED_BY(lock); + shared_ptr flushing_state DUCKDB_GUARDED_BY(lock); + + //! Delayed below-threshold partitions + DelayedPartitionBuffers delayed_partition_buffers; +}; + +//===--------------------------------------------------------------------===// +// Partition File Request Builder Declarations +//===--------------------------------------------------------------------===// +class PartitionFileRequestBuilder { +public: + PartitionFileRequestBuilder(PartitionedCopy &partitioned_copy_p, FileStateHandle &file_state_p, + const vector &values_p, FileCreationReason reason_p) + : partitioned_copy(partitioned_copy_p), file_state(file_state_p), values(values_p), reason(reason_p) { + } + +public: + optional Build(); + vector TakeFilesToFinalize(); + +private: + PartitionDirectory BuildDirectory(string path) const; + +private: + PartitionedCopy &partitioned_copy; + FileStateHandle &file_state; + const vector &values; + FileCreationReason reason; + PartitionFileStateReservation reservation; +}; + +//===--------------------------------------------------------------------===// +// Partitioned Copy Local State Declarations +//===--------------------------------------------------------------------===// +class PartitionedCopyLocalState : public LocalSinkState { +public: + shared_ptr current_state; + unique_ptr sort_strategy_local_state; + idx_t append_count = 0; +}; + +//===--------------------------------------------------------------------===// +// Partitioned Copy Scoped Guard Declarations +//===--------------------------------------------------------------------===// +class DelayedPartitionFlushGuard { +public: + DelayedPartitionFlushGuard(PartitionedCopy &partitioned_copy_p, const vector &values_p) + : partitioned_copy(partitioned_copy_p), values(values_p) { + } + + ~DelayedPartitionFlushGuard(); + optional Complete(); + +private: + PartitionedCopy &partitioned_copy; + const vector &values; + bool active = true; +}; + +//===--------------------------------------------------------------------===// +// Implementations +//===--------------------------------------------------------------------===// + +//===--------------------------------------------------------------------===// +// Utility Helpers +//===--------------------------------------------------------------------===// +void CheckDirectory(FileSystem &fs, const string &file_path, CopyOverwriteMode overwrite_mode) { + if (overwrite_mode == CopyOverwriteMode::COPY_OVERWRITE_OR_IGNORE || + overwrite_mode == CopyOverwriteMode::COPY_APPEND) { + // with overwrite or ignore we fully ignore the presence of any files instead of erasing them + return; + } + vector file_list; + vector directory_list; + directory_list.push_back(file_path); + for (idx_t dir_idx = 0; dir_idx < directory_list.size(); dir_idx++) { + auto directory = directory_list[dir_idx]; + fs.ListFiles(directory, [&](const string &path, bool is_directory) { + auto full_path = fs.JoinPath(directory, path); + if (is_directory) { + directory_list.emplace_back(std::move(full_path)); + } else { + file_list.emplace_back(std::move(full_path)); + } + }); + } + if (file_list.empty()) { + return; + } + if (overwrite_mode == CopyOverwriteMode::COPY_OVERWRITE) { + fs.RemoveFiles(file_list); + } else { + throw IOException("Directory \"%s\" is not empty! Enable OVERWRITE option to overwrite files", file_path); + } +} + +struct PhysicalCopyToFileColumnStatsMapData { + vector keys; + vector values; +}; + +static PhysicalCopyToFileColumnStatsMapData +CreateColumnStatistics(const case_insensitive_map_t> &column_statistics) { + PhysicalCopyToFileColumnStatsMapData result; + + //! Use a map to make sure the result has a consistent ordering + map stats; + for (auto &entry : column_statistics) { + map per_column_stats; + for (auto &stats_entry : entry.second) { + per_column_stats.emplace(stats_entry.first, stats_entry.second); + } + vector stats_keys; + vector stats_values; + for (auto &stats_entry : per_column_stats) { + stats_keys.emplace_back(stats_entry.first); + stats_values.emplace_back(std::move(stats_entry.second)); + } + auto map_value = + Value::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR, std::move(stats_keys), std::move(stats_values)); + stats.emplace(entry.first, std::move(map_value)); + } + for (auto &entry : stats) { + result.keys.emplace_back(entry.first); + result.values.emplace_back(std::move(entry.second)); + } + return result; +} -private: - unique_ptr ConstructSortStrategy() const; - void CreateNextState(); - bool ShouldStopFlushing() const; - bool RequiresSerializedPartitionWrites() const; - void EnsureFreshPartitionFileForSortedRun(PartitionWriteInfo &write_info, const vector &values) - DUCKDB_EXCLUDES(active_writes_lock); - void EnsureFreshPartitionFileForRotation(PartitionWriteInfo &write_info, const vector &values) - DUCKDB_EXCLUDES(active_writes_lock); - //! Swaps write_info.file_state after temporarily dropping copy_gstate.lock to initialize the replacement file. - //! Callers that can reach the swap path must serialize the full partition writer run for this write_info. - void EnsureFreshPartitionFile(PartitionWriteInfo &write_info, const vector &values, - FileCreationReason reason) DUCKDB_EXCLUDES(active_writes_lock); - template - void WithSerializedPartitionWriteRun(PartitionWriteInfo &write_info, FUNC &&func) { - annotated_unique_lock run_guard(write_info.lock, std::defer_lock); - if (RequiresSerializedPartitionWrites()) { - run_guard.lock(); +//===--------------------------------------------------------------------===// +// Copy File Lifecycle +//===--------------------------------------------------------------------===// +static void FinalizeLifecycleFileState(ClientContext &context, copy_to_finalize_t finalize, FunctionData &bind_data, + unique_ptr state) { + if (!finalize) { + throw InternalException("COPY file lifecycle finalize requires a finalize callback"); + } + if (!state || !state->data) { + throw InternalException("COPY file lifecycle finalize reached an empty file state"); + } + finalize(context, bind_data, *state->data); +} +void CopyFileLifecycleExecutor::WaitForJob(CopyFileLifecycleJob &job, CopyFileLifecycleWaitMode mode) { + while (!job.IsFinished()) { + if (mode == CopyFileLifecycleWaitMode::INTERRUPTIBLE) { + context.InterruptCheck(); } - func(); + WorkOnTaskOrYield(); } - void FlushDelayedPartitionRun(const vector &values, PartitionWriteInfo &write_info, - ColumnDataCollection &collection); + job.Rethrow(); +} -public: - const PhysicalCopyToFile &op; - ClientContext &context; - CopyToFileGlobalState ©_gstate; +void CopyFileLifecycleExecutor::WaitAll(CopyFileLifecycleWaitMode mode) { + if (mode == CopyFileLifecycleWaitMode::DRAIN) { + executor.WorkOnTasks(); + ThrowError(); + return; + } + while (pending_tasks.load(std::memory_order_relaxed) > 0) { + context.InterruptCheck(); + WorkOnTaskOrYield(); + } + ThrowError(); +} - //! Which columns/types to write to the file - vector write_columns; - vector write_types; - vector raw_columns; +void CopyFileLifecycleExecutor::FinishTask() { + --pending_tasks; +} - //! Partition/sort strategy with PhysicalOperator-like interface - const unique_ptr sort_strategy; +void CopyFileLifecycleExecutor::PushError(const std::exception_ptr &error_p) { + lock_guard guard(error_lock); + if (!error) { + error = error_p; + } +} - //! Lock for managing states (sinking_state, flushing_state, flushing flag) - mutable annotated_mutex lock; - //! Whether a flushing state currently exists - atomic flushing; - //! How many threads are active - atomic locals; - //! How many threads did a combine - atomic combined; - //! Whether Finalize has been called - atomic finalized; +bool CopyFileLifecycleExecutor::WorkOnTask(bool throw_error) { + shared_ptr task; + if (!executor.GetTask(task)) { + return false; + } + const auto result = task->Execute(TaskExecutionMode::PROCESS_ALL); + D_ASSERT(result != TaskExecutionResult::TASK_BLOCKED); + task.reset(); + if (throw_error) { + ThrowError(); + } + return true; +} - //! Current sink and combine states - shared_ptr sinking_state DUCKDB_GUARDED_BY(lock); - shared_ptr flushing_state DUCKDB_GUARDED_BY(lock); +void CopyFileLifecycleExecutor::WorkOnTaskOrYield() { + if (!WorkOnTask()) { + TaskScheduler::YieldThread(); + } +} - //! Fine-grained lock for partition write tracking - mutable annotated_mutex active_writes_lock; - //! The active writes per partition (for partitioned write) - vector_of_value_map_t> active_writes DUCKDB_GUARDED_BY(active_writes_lock); - vector_of_value_map_t previous_partitions DUCKDB_GUARDED_BY(active_writes_lock); - idx_t global_offset DUCKDB_GUARDED_BY(active_writes_lock) = 0; +void CopyFileLifecycleExecutor::WaitForTaskSlot(CopyFileLifecycleWaitMode mode) { + while (pending_tasks >= max_pending_tasks) { + if (mode == CopyFileLifecycleWaitMode::INTERRUPTIBLE) { + context.InterruptCheck(); + } + if (async_threads == 0 || mode == CopyFileLifecycleWaitMode::DRAIN) { + WorkOnTaskOrYield(); + } else { + TaskScheduler::YieldThread(); + } + } +} - //! Delayed below-threshold partitions - DelayedPartitionBuffers delayed_partition_buffers; -}; +void CopyFileLifecycleExecutor::ThrowError() { + lock_guard guard(error_lock); + if (error) { + std::rethrow_exception(error); + } +} //===--------------------------------------------------------------------===// -// Partitioned Copy Scoped Guards +// Copy File State Helpers //===--------------------------------------------------------------------===// -class PartitionWriteInfoGuard { -public: - PartitionWriteInfoGuard(PartitionedCopy &partitioned_copy_p, PartitionWriteInfo &write_info_p) - : partitioned_copy(partitioned_copy_p), write_info(write_info_p) { +void CopyDirectoryManager::EnsureDirectory(FileSystem &fs, const string &dir_path) { + bool created_entry = false; + { + std::unique_lock guard(lock); + while (true) { + auto entry = directories.find(dir_path); + if (entry == directories.end()) { + directories.emplace(dir_path, DirectoryEntry()); + created_entry = true; + break; + } + + if (entry->second.state == CopyDirectoryState::COMPLETE) { + return; + } + if (entry->second.state == CopyDirectoryState::FAILED) { + std::rethrow_exception(entry->second.error); + } + condition.wait(guard); + } } - ~PartitionWriteInfoGuard(); + std::exception_ptr error; + try { + if (!fs.DirectoryExists(dir_path)) { + fs.CreateDirectory(dir_path); + } + } catch (...) { + error = std::current_exception(); + } -private: - PartitionedCopy &partitioned_copy; - optional_ptr write_info; -}; + { + lock_guard guard(lock); + auto entry = directories.find(dir_path); + D_ASSERT(entry != directories.end()); + D_ASSERT(created_entry); + entry->second.state = error ? CopyDirectoryState::FAILED : CopyDirectoryState::COMPLETE; + entry->second.error = error; + } + condition.notify_all(); -class DelayedPartitionFlushGuard { -public: - DelayedPartitionFlushGuard(PartitionedCopy &partitioned_copy_p, const vector &values_p) - : partitioned_copy(partitioned_copy_p), values(values_p) { + if (error) { + std::rethrow_exception(error); } +} - ~DelayedPartitionFlushGuard(); - optional Complete(); +PendingFileState CopyOutputFileRegistry::ReserveFile(string output_path, + optional_ptr> partition_values) { + PendingFileState result; + result.output_path = std::move(output_path); -private: - PartitionedCopy &partitioned_copy; - const vector &values; - bool active = true; -}; + if (op.return_type != CopyFunctionReturnType::CHANGED_ROWS) { + result.written_file_info = AddFile(result.output_path); + } + + if (result.written_file_info && !op.partition_columns.empty()) { + D_ASSERT(partition_values); + vector partition_keys; + vector partition_values_as_varchar; + for (idx_t i = 0; i < op.partition_columns.size(); i++) { + const auto &partition_col_name = op.names[op.partition_columns[i]]; + const auto &partition_value = (*partition_values)[i]; + partition_keys.emplace_back(partition_col_name); + partition_values_as_varchar.push_back(partition_value.DefaultCastAs(LogicalType::VARCHAR)); + } + result.written_file_info->partition_keys = + Value::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR, std::move(partition_keys), + std::move(partition_values_as_varchar)); + } + return result; +} + +void CopyOutputFileRegistry::PublishCreatedPath(PendingFileState &pending_file_state, string output_path) { + pending_file_state.output_path = std::move(output_path); + if (pending_file_state.written_file_info) { + pending_file_state.written_file_info->file_path = pending_file_state.output_path; + } + created_files.push_back(pending_file_state.output_path); +} -PartitionWriteInfoGuard::~PartitionWriteInfoGuard() { - if (write_info) { - partitioned_copy.ReleasePartitionWriteInfo(*write_info); +optional_ptr CopyOutputFileRegistry::AddFile(const string &file_name) { + auto file_info = make_uniq(file_name); + if (op.return_type == CopyFunctionReturnType::WRITTEN_FILE_STATISTICS) { + file_info->file_stats = make_uniq(); } + auto result = file_info.get(); + written_files.push_back(std::move(file_info)); + return result; } +static bool PhysicalCopyRotateNow(const PhysicalCopyToFile &op, GlobalFileState &global_state) + DUCKDB_REQUIRES(global_state.lock); + +//===--------------------------------------------------------------------===// +// Partitioned Copy Scoped Guards +//===--------------------------------------------------------------------===// DelayedPartitionFlushGuard::~DelayedPartitionFlushGuard() { if (active) { partitioned_copy.CompleteDelayedPartition(values, false); @@ -881,7 +1688,115 @@ optional DelayedPartitionFlushGuard::Complete() { } //===--------------------------------------------------------------------===// -// Partitioned Hash Group Implementation +// Partition Write Manager +//===--------------------------------------------------------------------===// +PartitionWriteLease::PartitionWriteLease(PartitionWriteManager &manager_p, PartitionWriteInfo &write_info_p) + : manager(manager_p), write_info(write_info_p) { +} + +PartitionWriteLease::PartitionWriteLease(PartitionWriteLease &&other) noexcept + : manager(other.manager), write_info(other.write_info) { + other.manager = nullptr; + other.write_info = nullptr; +} + +PartitionWriteLease &PartitionWriteLease::operator=(PartitionWriteLease &&other) noexcept { + if (this != &other) { + Reset(); + manager = other.manager; + write_info = other.write_info; + other.manager = nullptr; + other.write_info = nullptr; + } + return *this; +} + +PartitionWriteLease::~PartitionWriteLease() { + Reset(); +} + +void PartitionWriteLease::Reset() { + if (manager && write_info) { + manager->Release(*write_info); + } + manager = nullptr; + write_info = nullptr; +} + +PartitionWriteLease PartitionWriteManager::Acquire(const vector &values) { + PartitionWriteInfo *result; + { + annotated_lock_guard guard(lock); + auto active_write_entry = active_writes.find(values); + if (active_write_entry != active_writes.end()) { + active_write_entry->second->active_writes++; + result = active_write_entry->second.get(); + } else { + auto info = make_uniq(); + result = info.get(); + info->active_writes = 1; + active_writes.insert(make_pair(values, std::move(info))); + } + } + return PartitionWriteLease(*this, *result); +} + +PartitionWriteManager::ReservationLock PartitionWriteManager::LockForReservation() { + return ReservationLock(lock); +} + +PartitionFileStateReservation +PartitionWriteManager::ReserveFileState(ReservationLock &reservation_lock, const vector &values, + FileCreationReason reason) DUCKDB_NO_THREAD_SAFETY_ANALYSIS { + D_ASSERT(reservation_lock.guard.owns_lock()); + PartitionFileStateReservation reservation; + if (active_writes.size() >= Settings::Get(context)) { + for (auto it = active_writes.begin(); it != active_writes.end(); ++it) { + if (it->second->active_writes == 0) { + reservation.files_to_finalize.push_back(std::move(it->second->file_state)); + ++previous_partitions[it->first]; + active_writes.erase(it); + break; + } + } + } + + if (op.hive_file_pattern) { + if (reason == FileCreationReason::SORTED_RUN_BOUNDARY || reason == FileCreationReason::ROTATION) { + ++previous_partitions[values]; + } + auto prev_offset = previous_partitions.find(values); + if (prev_offset != previous_partitions.end()) { + reservation.offset = prev_offset->second; + } + } else { + reservation.offset = global_offset++; + } + + return reservation; +} + +vector PartitionWriteManager::TakeOpenFileStates() { + vector files_to_finalize; + { + annotated_lock_guard guard(lock); + for (auto &entry : active_writes) { + D_ASSERT(entry.second->active_writes == 0); + files_to_finalize.push_back(std::move(entry.second->file_state)); + } + active_writes.clear(); + } + return files_to_finalize; +} + +void PartitionWriteManager::Release(PartitionWriteInfo &write_info) { + annotated_lock_guard guard(lock); + D_ASSERT(write_info.active_writes > 0); + write_info.active_writes--; +} + +//===--------------------------------------------------------------------===// +// Partitioned Copy Hash Group //===--------------------------------------------------------------------===// PartitionedCopyHashGroup::PartitionedCopyHashGroup(PartitionedCopy &partitioned_copy, const ChunkRow &chunk_row, idx_t group_idx_p) @@ -932,7 +1847,7 @@ bool PartitionedCopyHashGroup::TryPrepareNextStage() { } return false; case PartitionedCopyStage::PREPARE: - if (prepared == count) { + if (prepared == count && PreparedBatchesReady()) { CompletePreparedBatchStates(); stage = PartitionedCopyStage::FLUSH; return true; @@ -993,7 +1908,7 @@ optional PartitionedCopyHashGroup::TryNextBatchTask() { task.stage = PartitionedCopyStage::BATCH; task.group_idx = group_idx; task.thread_idx = batch_states.size() - 1; - task.batch_idx = batch_state.collections.size(); + task.batch_idx = batch_state.NextCollectionIndex(); task.begin_idx = batch_row_idx; // Find the end_idx @@ -1069,41 +1984,23 @@ optional PartitionedCopyHashGroup::TryNextBatchTask() { optional PartitionedCopyHashGroup::TryNextPrepareTask() { while (prepare_partition_idx < batch_states.size()) { auto &batch_state = *batch_states[prepare_partition_idx]; - if (batch_state.SkipsPrepare()) { + auto action = batch_state.SelectPrepareTask(prepare_batch_idx); + if (action.type == PartitionedCopyPrepareTaskActionType::SKIP_PARTITION) { ++prepare_partition_idx; prepare_batch_idx = 0; continue; } - D_ASSERT(batch_state.mode == PartitionedCopyBatchMode::PREPARING); - if (batch_state.NeedsWriteInfo()) { - if (!batch_state.IsWriteInfoReserved() && batch_state.TryReserveWriteInfo()) { - PartitionedCopyTask task; - task.stage = PartitionedCopyStage::PREPARE; - task.group_idx = group_idx; - task.thread_idx = prepare_partition_idx; - task.batch_idx = DConstants::INVALID_INDEX; - return task; - } + if (action.type == PartitionedCopyPrepareTaskActionType::WAIT_FOR_WRITE_INFO) { return nullopt; } - while (prepare_batch_idx < batch_state.collections.size()) { - if (!batch_state.NeedsPrepare(prepare_batch_idx)) { - ++prepare_batch_idx; - continue; - } - break; - } - if (prepare_batch_idx == batch_state.collections.size()) { - ++prepare_partition_idx; - prepare_batch_idx = 0; - continue; - } PartitionedCopyTask task; task.stage = PartitionedCopyStage::PREPARE; task.group_idx = group_idx; task.thread_idx = prepare_partition_idx; - task.batch_idx = prepare_batch_idx++; + task.batch_idx = action.type == PartitionedCopyPrepareTaskActionType::ACQUIRE_WRITE_INFO + ? DConstants::INVALID_INDEX + : action.batch_idx; return task; } @@ -1124,6 +2021,19 @@ optional PartitionedCopyHashGroup::TryNextFlushTask() { return task; } +bool PartitionedCopyHashGroup::PreparedBatchesReady() const { + for (auto &batch_state_ptr : batch_states) { + if (!batch_state_ptr) { + return false; + } + auto &batch_state = *batch_state_ptr; + if (!batch_state.ReadyForFlush()) { + return false; + } + } + return true; +} + void PartitionedCopyHashGroup::Sort(ExecutionContext &execution_context, GlobalSinkState &sink, InterruptState &interrupt, const PartitionedCopyTask &task) { D_ASSERT(task.stage == PartitionedCopyStage::SORT); @@ -1234,38 +2144,18 @@ void PartitionedCopyHashGroup::Batch(const PartitionedCopyTask &task) { collection->Scan(scan_state, scan_chunk); } - // Get pointer to batch state (under lock) - optional_ptr batch_state; optional_ptr write_info; - vector partition_values; - bool acquire_write_info = false; - bool prepare_batch = false; + const auto flush_threshold = Settings::Get(partitioned_copy.context); + const auto has_delayed_partition = partitioned_copy.HasDelayedPartition(values); + PartitionedCopyBatchAction batch_action; { annotated_lock_guard guard(lock); - batch_state = batch_states[task.thread_idx]; - batch_state->SetValues(std::move(values)); - - if (batch_state->mode == PartitionedCopyBatchMode::BUFFERING) { - const auto flush_threshold = Settings::Get(partitioned_copy.context); - const auto has_delayed_partition = partitioned_copy.HasDelayedPartition(batch_state->values); - if (batch_state->CanStartPreparing(flush_threshold, has_delayed_partition)) { - batch_state->StartPreparing(); - acquire_write_info = batch_state->TryReserveWriteInfo(); - } - } - - if (batch_state->mode == PartitionedCopyBatchMode::PREPARING && batch_state->write_info) { - write_info = batch_state->write_info; - partition_values = batch_state->values; - prepare_batch = true; - } else if (acquire_write_info) { - partition_values = batch_state->values; - prepare_batch = true; - } + auto &batch_state = *batch_states[task.thread_idx]; + batch_action = batch_state.RegisterBatch(std::move(values), flush_threshold, has_delayed_partition); } const auto row_count = task.end_idx - task.begin_idx; - if (!prepare_batch) { + if (batch_action.type == PartitionedCopyBatchActionType::STORE_COLLECTION) { { annotated_lock_guard guard(lock); auto ¤t_batch_state = *batch_states[task.thread_idx]; @@ -1275,18 +2165,20 @@ void PartitionedCopyHashGroup::Batch(const PartitionedCopyTask &task) { return; } - if (acquire_write_info) { - auto &partition_write_info = partitioned_copy.GetPartitionWriteInfo(partition_values); + if (batch_action.type == PartitionedCopyBatchActionType::ACQUIRE_AND_PREPARE) { + auto write_lease = partitioned_copy.partition_writes.Acquire(batch_action.values); { annotated_lock_guard guard(lock); auto ¤t_batch_state = *batch_states[task.thread_idx]; - current_batch_state.SetWriteInfo(partition_write_info); - write_info = current_batch_state.write_info; + write_info = current_batch_state.CompleteWriteInfoAcquisition(std::move(write_lease)); } + } else { + D_ASSERT(batch_action.type == PartitionedCopyBatchActionType::PREPARE_WITH_WRITE_INFO); + write_info = batch_action.write_info; } D_ASSERT(write_info); auto prepared_batch = partitioned_copy.PreparePartitionBatch( - partition_values, *write_info, PartitionedCopyCollection(collection_schema, std::move(batch))); + batch_action.values, *write_info, PartitionedCopyCollection(collection_schema, std::move(batch))); { annotated_lock_guard guard(lock); @@ -1302,23 +2194,8 @@ void PartitionedCopyHashGroup::PrepareBatchStates() { const auto flush_threshold = Settings::Get(partitioned_copy.context); for (auto &batch_state : batch_states) { D_ASSERT(batch_state); - D_ASSERT(!batch_state->values.empty()); - D_ASSERT(batch_state->count > 0); - - if (batch_state->mode == PartitionedCopyBatchMode::PREPARING) { - D_ASSERT(batch_state->write_info); - batch_state->EnsurePreparingBatchSlots(); - continue; - } - D_ASSERT(batch_state->mode == PartitionedCopyBatchMode::BUFFERING); - - if (batch_state->count < flush_threshold || partitioned_copy.HasDelayedPartition(batch_state->values)) { - batch_state->MarkDelayed(); - prepared += batch_state->count; - continue; - } - - batch_state->StartPreparing(); + const auto has_delayed_partition = partitioned_copy.HasDelayedPartition(batch_state->Values()); + prepared += batch_state->FinalizeBatching(flush_threshold, has_delayed_partition); } } @@ -1333,38 +2210,26 @@ void PartitionedCopyHashGroup::Prepare(ExecutionContext &execution_context, Inte const PartitionedCopyTask &task) { D_ASSERT(task.stage == PartitionedCopyStage::PREPARE); - vector values; - optional_ptr write_info; - PartitionedCopyCollection data; - bool acquire_write_info = false; + PartitionedCopyPrepareAction prepare_action; { annotated_lock_guard guard(lock); auto &batch_state = *batch_states[task.thread_idx]; - D_ASSERT(batch_state.mode == PartitionedCopyBatchMode::PREPARING); - values = batch_state.values; - if (task.batch_idx == DConstants::INVALID_INDEX) { - D_ASSERT(batch_state.NeedsWriteInfo()); - D_ASSERT(batch_state.IsWriteInfoReserved()); - acquire_write_info = true; - } else { - D_ASSERT(batch_state.write_info); - D_ASSERT(task.batch_idx < batch_state.collections.size()); - write_info = batch_state.write_info; - data = batch_state.TakeCollection(task.batch_idx); - } + prepare_action = batch_state.BeginPrepareTask(task.batch_idx); } - if (acquire_write_info) { - auto &partition_write_info = partitioned_copy.GetPartitionWriteInfo(values); + if (prepare_action.type == PartitionedCopyPrepareActionType::ACQUIRE_WRITE_INFO) { + auto write_lease = partitioned_copy.partition_writes.Acquire(prepare_action.values); annotated_lock_guard guard(lock); auto &batch_state = *batch_states[task.thread_idx]; - batch_state.SetWriteInfo(partition_write_info); + batch_state.CompleteWriteInfoAcquisition(std::move(write_lease)); return; } - D_ASSERT(data.collection); - const auto row_count = data.Count(); + D_ASSERT(prepare_action.type == PartitionedCopyPrepareActionType::PREPARE_BATCH); + D_ASSERT(prepare_action.data.collection); + const auto row_count = prepare_action.data.Count(); - auto prepared_batch = partitioned_copy.PreparePartitionBatch(values, *write_info, std::move(data)); + auto prepared_batch = partitioned_copy.PreparePartitionBatch(prepare_action.values, *prepare_action.write_info, + std::move(prepare_action.data)); { annotated_lock_guard guard(lock); @@ -1379,31 +2244,19 @@ void PartitionedCopyHashGroup::Flush(ExecutionContext &execution_context, Interr const PartitionedCopyTask &task) { D_ASSERT(task.stage == PartitionedCopyStage::FLUSH); - vector values; - vector collections; - vector> batches; - optional_ptr write_info; - PartitionedCopyBatchMode mode; + PartitionedCopyFlushAction flush_action; { annotated_lock_guard guard(lock); auto &batch_state = *batch_states[task.thread_idx]; - values = batch_state.values; - mode = batch_state.mode; - if (mode == PartitionedCopyBatchMode::DELAYED) { - collections = batch_state.TakeDelayedCollections(); - } else { - D_ASSERT(mode == PartitionedCopyBatchMode::PREPARED); - write_info = batch_state.write_info; - batches = batch_state.TakePreparedBatches(); - } + flush_action = batch_state.TakeFlushAction(); } - D_ASSERT(!values.empty()); + D_ASSERT(!flush_action.values.empty()); - if (mode == PartitionedCopyBatchMode::DELAYED) { + if (flush_action.type == PartitionedCopyFlushActionType::DELAYED_COLLECTIONS) { const auto collection_schema = partitioned_copy.GetPartitionCollectionSchema(); const auto &collection_types = partitioned_copy.GetPartitionCollectionTypes(collection_schema); auto collection = make_uniq(partitioned_copy.context, collection_types); - for (auto &batch : collections) { + for (auto &batch : flush_action.collections) { D_ASSERT(batch.schema == collection_schema); if (batch.collection) { collection->Combine(*batch.collection); @@ -1414,20 +2267,21 @@ void PartitionedCopyHashGroup::Flush(ExecutionContext &execution_context, Interr } auto data = PartitionedCopyCollection(collection_schema, std::move(collection)); - auto ready_partition = partitioned_copy.BufferOrTakeReadyPartition(values, std::move(data), false); + auto ready_partition = partitioned_copy.BufferOrTakeReadyPartition(flush_action.values, std::move(data), false); if (ready_partition) { partitioned_copy.FlushPartitionCollection(execution_context, interrupt_state, std::move(*ready_partition)); } return; } - D_ASSERT(write_info); - PartitionWriteInfoGuard write_guard(partitioned_copy, *write_info); - partitioned_copy.FlushPreparedPartitionRun(values, *write_info, std::move(batches)); + D_ASSERT(flush_action.type == PartitionedCopyFlushActionType::PREPARED_BATCHES); + D_ASSERT(flush_action.write_lease); + partitioned_copy.FlushPreparedPartitionRun(flush_action.values, *flush_action.write_lease, + std::move(flush_action.batches)); } //===--------------------------------------------------------------------===// -// Partitioned Copy State Implementation +// Partitioned Copy State //===--------------------------------------------------------------------===// PartitionedCopyState::PartitionedCopyState(PartitionedCopy &partitioned_copy_p, unique_ptr global_sink_state_p) @@ -1594,20 +2448,13 @@ void PartitionedCopyState::FinishTask(const PartitionedCopyTask &task) { } } -class PartitionedCopyLocalState : public LocalSinkState { -public: - shared_ptr current_state; - unique_ptr sort_strategy_local_state; - idx_t append_count = 0; -}; - //===--------------------------------------------------------------------===// // Partitioned Copy Lifecycle //===--------------------------------------------------------------------===// PartitionedCopy::PartitionedCopy(const PhysicalCopyToFile &op_p, ClientContext &context_p, CopyToFileGlobalState ©_gstate_p) - : op(op_p), context(context_p), copy_gstate(copy_gstate_p), sort_strategy(ConstructSortStrategy()), flushing(false), - locals(0), combined(0), finalized(false) { + : op(op_p), context(context_p), copy_gstate(copy_gstate_p), partition_writes(op_p, context_p), + sort_strategy(ConstructSortStrategy()), flushing(false), locals(0), combined(0), finalized(false) { unordered_set part_col_set(op.partition_columns.begin(), op.partition_columns.end()); for (idx_t col_idx = 0; col_idx < op.expected_types.size(); col_idx++) { raw_columns.push_back(col_idx); @@ -1830,6 +2677,7 @@ class PartitionedCopyFinalizeEvent : public BasePipelineEvent { if (done) { partitioned_copy.FinalizeActiveWrites(); partitioned_copy.copy_gstate.TryFinalizeOwnedFileState(); + partitioned_copy.copy_gstate.WaitForLifecycleTasks(); } } @@ -1859,6 +2707,7 @@ void PartitionedCopy::Finalize(Pipeline &pipeline, Event &event, InterruptState if (should_finalize_writes) { FinalizeActiveWrites(); copy_gstate.TryFinalizeOwnedFileState(); + copy_gstate.WaitForLifecycleTasks(); } } @@ -2048,16 +2897,69 @@ unique_ptr PartitionedCopy::ProjectToWriteColumns(unique_p return result; } +//===--------------------------------------------------------------------===// +// Partition File Request Builder +//===--------------------------------------------------------------------===// +optional PartitionFileRequestBuilder::Build() { + auto reservation_lock = partitioned_copy.partition_writes.LockForReservation(); + annotated_lock_guard global_guard(partitioned_copy.copy_gstate.lock); + if (file_state) { + return nullopt; + } + + reservation = partitioned_copy.partition_writes.ReserveFileState(reservation_lock, values, reason); + + auto &op = partitioned_copy.op; + auto &context = partitioned_copy.context; + auto &fs = FileSystem::GetFileSystem(context); + auto directory = BuildDirectory(op.GetTrimmedPath(context, op.file_path)); + auto full_path = op.filename_pattern.CreateFilename(fs, directory.path, op.file_extension, reservation.offset); + auto pending_file_state_open = + partitioned_copy.copy_gstate.CreatePartitionFileStateOpenLocked(file_state, std::move(full_path), values); + D_ASSERT(pending_file_state_open); + + PartitionFileOpenRequest open_request(std::move(pending_file_state_open), std::move(directory), reservation.offset); + return PartitionFileRequest(std::move(open_request), std::move(reservation.files_to_finalize)); +} + +vector PartitionFileRequestBuilder::TakeFilesToFinalize() { + return std::move(reservation.files_to_finalize); +} + +PartitionDirectory PartitionFileRequestBuilder::BuildDirectory(string path) const { + auto &fs = FileSystem::GetFileSystem(partitioned_copy.context); + PartitionDirectory result; + result.path = std::move(path); + if (partitioned_copy.op.hive_file_pattern) { + for (idx_t i = 0; i < partitioned_copy.op.partition_columns.size(); i++) { + const auto &partition_col_name = partitioned_copy.op.names[partitioned_copy.op.partition_columns[i]]; + const auto &partition_value = values[i]; + string p_dir; + p_dir += HivePartitioning::Escape(partition_col_name.GetIdentifierName()); + p_dir += "="; + if (partition_value.IsNull()) { + p_dir += "__HIVE_DEFAULT_PARTITION__"; + } else { + p_dir += HivePartitioning::Escape(partition_value.ToString()); + } + result.path = fs.JoinPath(result.path, p_dir); + result.directories.push_back(result.path); + } + } + return result; +} + //===--------------------------------------------------------------------===// // Partitioned Write Helpers //===--------------------------------------------------------------------===// unique_ptr PartitionedCopy::PreparePartitionBatch(const vector &values, PartitionWriteInfo &write_info, PartitionedCopyCollection data) { - auto collection = PrepareCollectionForWrite(std::move(data)); - const auto create_file_state_fun = [&]() { - return CreatePartitionFileState(values); + const auto create_file_state_fun = [&](FileStateHandle &file_state) { + RequestPartitionFileState(file_state, values); }; + create_file_state_fun(write_info.file_state); + auto collection = PrepareCollectionForWrite(std::move(data)); auto [batch_analyzer, prepared_batch] = op.PrepareBatch(context, copy_gstate, write_info.file_state, create_file_state_fun, std::move(collection)); return make_uniq(batch_analyzer, std::move(prepared_batch)); @@ -2068,7 +2970,9 @@ void PartitionedCopy::FlushPreparedPartitionRun(const vector &values, Par WithSerializedPartitionWriteRun(write_info, [&]() { EnsureFreshPartitionFileForSortedRun(write_info, values); for (auto &batch : batches) { - D_ASSERT(batch); + if (!batch) { + throw InternalException("Partitioned COPY reached FLUSH with a missing prepared batch"); + } EnsureFreshPartitionFileForRotation(write_info, values); FlushPreparedPartitionBatch(values, write_info, std::move(batch)); } @@ -2078,8 +2982,8 @@ void PartitionedCopy::FlushPreparedPartitionRun(const vector &values, Par void PartitionedCopy::FlushPreparedPartitionBatch(const vector &values, PartitionWriteInfo &write_info, unique_ptr batch) { D_ASSERT(batch); - const auto create_file_state_fun = [&]() { - return CreatePartitionFileState(values); + const auto create_file_state_fun = [&](FileStateHandle &file_state) { + RequestPartitionFileState(file_state, values); }; op.FlushBatch(context, copy_gstate, write_info.file_state, create_file_state_fun, batch->batch_analyzer, std::move(batch->prepared_batch)); @@ -2137,9 +3041,8 @@ void PartitionedCopy::FlushPartitionCollection(ExecutionContext &execution_conte } D_ASSERT(flush.data.schema == PartitionedCopyCollectionSchema::WRITE_SCHEMA); - auto &write_info = GetPartitionWriteInfo(flush.values); - PartitionWriteInfoGuard write_guard(*this, write_info); - FlushDelayedPartitionRun(flush.values, write_info, *flush.data.collection); + auto write_lease = partition_writes.Acquire(flush.values); + FlushDelayedPartitionRun(flush.values, *write_lease, *flush.data.collection); } auto next = delayed_guard.Complete(); @@ -2150,100 +3053,34 @@ void PartitionedCopy::FlushPartitionCollection(ExecutionContext &execution_conte } } -PartitionWriteInfo &PartitionedCopy::GetPartitionWriteInfo(const vector &values) { - PartitionWriteInfo *result; - { - annotated_lock_guard guard(active_writes_lock); - // check if we have already started writing this partition - auto active_write_entry = active_writes.find(values); - if (active_write_entry != active_writes.end()) { - // we have - continue writing in this partition - active_write_entry->second->active_writes++; - result = active_write_entry->second.get(); - } else { - auto info = make_uniq(); - result = info.get(); - - info->active_writes = 1; - // store in active write map - active_writes.insert(make_pair(values, std::move(info))); - } - } - - return *result; -} - -void PartitionedCopy::ReleasePartitionWriteInfo(PartitionWriteInfo &write_info) { - annotated_lock_guard guard(active_writes_lock); - D_ASSERT(write_info.active_writes > 0); - write_info.active_writes--; -} - -unique_ptr PartitionedCopy::CreatePartitionFileState(const vector &values, - FileCreationReason reason) { - PartitionFileStateReservation reservation; - { - annotated_lock_guard guard(active_writes_lock); - reservation = ReservePartitionFileStateLocked(values, reason); - } - FinalizeFileStates(std::move(reservation.files_to_finalize)); - return CreatePartitionFileStateFromReservation(values, reservation.offset); -} - -PartitionFileStateReservation PartitionedCopy::ReservePartitionFileStateLocked(const vector &values, - FileCreationReason reason) { - PartitionFileStateReservation reservation; - // check if we need to close any writers before we can continue - if (active_writes.size() >= Settings::Get(context)) { - // we need to! try to close writers - for (auto it = active_writes.begin(); it != active_writes.end(); ++it) { - if (it->second->active_writes == 0) { - // we can evict this entry - evict the partition - reservation.files_to_finalize.push_back(std::move(it->second->file_state)); - ++previous_partitions[it->first]; - active_writes.erase(it); - break; - } +void PartitionedCopy::RequestPartitionFileState(FileStateHandle &file_state, const vector &values, + FileCreationReason reason) { + PartitionFileRequestBuilder builder(*this, file_state, values, reason); + optional request; + try { + request = builder.Build(); + } catch (...) { + auto error = std::current_exception(); + try { + FinalizeFileStates(builder.TakeFilesToFinalize()); + } catch (...) { } + std::rethrow_exception(error); } - - if (op.hive_file_pattern) { - if (reason == FileCreationReason::SORTED_RUN_BOUNDARY || reason == FileCreationReason::ROTATION) { - ++previous_partitions[values]; - } - auto prev_offset = previous_partitions.find(values); - if (prev_offset != previous_partitions.end()) { - reservation.offset = prev_offset->second; - } - } else { - reservation.offset = global_offset++; + if (!request) { + return; } - return reservation; -} - -unique_ptr PartitionedCopy::CreatePartitionFileStateFromReservation(const vector &values, - idx_t offset) { - // The reservation/eviction decision has already been made under active_writes_lock. This section only - // serializes global file bookkeeping: directory tracking, filename registration, and writer initialization. - annotated_lock_guard guard(copy_gstate.lock); - - // Create a writer for the current file - auto &fs = FileSystem::GetFileSystem(context); - const auto hive_path = GetOrCreateDirectory(op.GetTrimmedPath(context, op.file_path), values); - auto full_path = op.filename_pattern.CreateFilename(fs, hive_path, op.file_extension, offset); - if (op.overwrite_mode == CopyOverwriteMode::COPY_APPEND) { - // when appending, we first check if the file exists - while (fs.FileExists(full_path)) { - // file already exists - re-generate name - if (!op.filename_pattern.HasUUID()) { - throw InternalException("CopyOverwriteMode::COPY_APPEND without {uuid} - and file exists"); - } - full_path = op.filename_pattern.CreateFilename(fs, hive_path, op.file_extension, offset); + auto open_job = request->OpenJob(); + try { + FinalizeFileStates(std::move(request->files_to_finalize)); + copy_gstate.SchedulePartitionFileStateOpen(std::move(request->open_request)); + } catch (...) { + if (open_job && !open_job->IsFinished()) { + open_job->CompleteException(std::current_exception()); } + throw; } - // Initialize write - return copy_gstate.CreateFileStateLocked(full_path, values); } void PartitionedCopy::EnsureFreshPartitionFileForSortedRun(PartitionWriteInfo &write_info, @@ -2266,13 +3103,18 @@ void PartitionedCopy::EnsureFreshPartitionFile(PartitionWriteInfo &write_info, c D_ASSERT(reason == FileCreationReason::SORTED_RUN_BOUNDARY || reason == FileCreationReason::ROTATION); D_ASSERT(RequiresSerializedPartitionWrites()); - optional_ptr old_file_state_ptr; + if (!write_info.file_state) { + return; + } + auto &old_file_state_ref = copy_gstate.EnsureFileStateReady( + write_info.file_state, [&](FileStateHandle &file_state) { RequestPartitionFileState(file_state, values); }); + optional_ptr old_file_state_ptr = old_file_state_ref; { annotated_lock_guard global_guard(copy_gstate.lock); if (!write_info.file_state) { return; } - old_file_state_ptr = write_info.file_state.get(); + D_ASSERT(RefersToSameObject(*old_file_state_ptr.get(), write_info.file_state.GetFileState())); annotated_lock_guard file_guard(old_file_state_ptr->lock); if (reason == FileCreationReason::SORTED_RUN_BOUNDARY && old_file_state_ptr->num_batches == 0) { return; @@ -2282,18 +3124,20 @@ void PartitionedCopy::EnsureFreshPartitionFile(PartitionWriteInfo &write_info, c } } - auto new_file_state = CreatePartitionFileState(values, reason); + FileStateHandle new_file_state; + RequestPartitionFileState(new_file_state, values, reason); - unique_ptr old_file_state; + FileStateHandle old_file_state; { annotated_lock_guard global_guard(copy_gstate.lock); D_ASSERT(write_info.file_state); - D_ASSERT(RefersToSameObject(*old_file_state_ptr.get(), *write_info.file_state)); - annotated_lock_guard file_guard(write_info.file_state->lock); + auto ¤t_file_state = write_info.file_state.GetFileState(); + D_ASSERT(RefersToSameObject(*old_file_state_ptr.get(), current_file_state)); + annotated_lock_guard file_guard(current_file_state.lock); if (reason == FileCreationReason::SORTED_RUN_BOUNDARY) { - D_ASSERT(write_info.file_state->num_batches > 0); + D_ASSERT(current_file_state.num_batches > 0); } else { - D_ASSERT(PhysicalCopyRotateNow(op, *write_info.file_state)); + D_ASSERT(PhysicalCopyRotateNow(op, current_file_state)); } old_file_state = std::move(write_info.file_state); @@ -2304,18 +3148,10 @@ void PartitionedCopy::EnsureFreshPartitionFile(PartitionWriteInfo &write_info, c } void PartitionedCopy::FinalizeActiveWrites() { - vector> files_to_finalize; - { - annotated_lock_guard aw_guard(active_writes_lock); - for (auto &entry : active_writes) { - files_to_finalize.push_back(std::move(entry.second->file_state)); - } - active_writes.clear(); - } - FinalizeFileStates(std::move(files_to_finalize)); + FinalizeFileStates(partition_writes.TakeOpenFileStates()); } -void PartitionedCopy::FinalizeFileStates(vector> files_to_finalize) { +void PartitionedCopy::FinalizeFileStates(vector files_to_finalize) { for (auto &file_state : files_to_finalize) { if (file_state) { copy_gstate.FinalizeFileState(std::move(file_state)); @@ -2323,44 +3159,26 @@ void PartitionedCopy::FinalizeFileStates(vector> fil } } -string PartitionedCopy::GetOrCreateDirectory(string path, const vector &values) { - auto &fs = FileSystem::GetFileSystem(context); - copy_gstate.CreateDir(path); - if (op.hive_file_pattern) { - for (idx_t i = 0; i < op.partition_columns.size(); i++) { - const auto &partition_col_name = op.names[op.partition_columns[i]]; - const auto &partition_value = values[i]; - string p_dir; - p_dir += HivePartitioning::Escape(partition_col_name.GetIdentifierName()); - p_dir += "="; - if (partition_value.IsNull()) { - p_dir += "__HIVE_DEFAULT_PARTITION__"; - } else { - p_dir += HivePartitioning::Escape(partition_value.ToString()); - } - path = fs.JoinPath(path, p_dir); - copy_gstate.CreateDir(path); - } - } - return path; -} - //===--------------------------------------------------------------------===// -// Copy Global State Implementation +// Copy Global State //===--------------------------------------------------------------------===// CopyToFileGlobalState::CopyToFileGlobalState(const PhysicalCopyToFile &op_p, ClientContext &context_p) : op(op_p), context(context_p), initialized(false), finalized(false), prepare_global_state(nullptr), - create_file_state_fun([&]() DUCKDB_EXCLUDES(lock) { return CreateFileState(); }), rows_copied(0), - last_file_offset(0) { + create_file_state_fun([&](FileStateHandle &file_state) DUCKDB_EXCLUDES(lock) { RequestFileState(file_state); }), + lifecycle_executor(context_p), output_files(op_p), rows_copied(0), last_file_offset(0) { } CopyToFileGlobalState::~CopyToFileGlobalState() { - if (!initialized || finalized || created_files.empty()) { + try { + WaitForLifecycleTasks(); + } catch (...) { + } + if (!initialized || finalized || !output_files.HasCreatedFiles()) { return; } // If we reach here, the query failed before Finalize was called auto &fs = FileSystem::GetFileSystem(context); - for (auto &file : created_files) { + for (auto &file : output_files.GetCreatedFiles()) { try { fs.TryRemoveFile(file); } catch (...) { @@ -2374,127 +3192,282 @@ void CopyToFileGlobalState::Initialize() { if (initialized) { return; } - annotated_lock_guard guard(lock); - if (initialized) { - return; - } - // initialize writing to the file - global_state = CreateFileStateLocked(op.file_path); + RequestFileState(global_state, op.file_path); initialized = true; } -void CopyToFileGlobalState::CreateDir(const string &dir_path) { +void CopyToFileGlobalState::PrepareOutputDirectory() { auto &fs = FileSystem::GetFileSystem(context); - if (created_directories.find(dir_path) != created_directories.end()) { - // already attempted to create this directory - return; + if (!fs.IsRemoteFile(op.file_path)) { + if (fs.FileExists(op.file_path)) { + if (op.overwrite_mode == CopyOverwriteMode::COPY_OVERWRITE) { + fs.RemoveFile(op.file_path); + } else { + throw IOException("Cannot write to \"%s\" - it exists and is a file, not a directory! Enable " + "OVERWRITE option to overwrite the file", + op.file_path); + } + } } - if (!fs.DirectoryExists(dir_path)) { - fs.CreateDirectory(dir_path); + + if (!fs.DirectoryExists(op.file_path)) { + fs.CreateDirectory(op.file_path); + } else { + CheckDirectory(fs, op.file_path, op.overwrite_mode); + } +} + +void CopyToFileGlobalState::ScheduleOutputDirectorySetup() { + D_ASSERT(!output_directory_job); + output_directory_job = make_shared_ptr(); + try { + lifecycle_executor.Schedule(output_directory_job, CopyFileLifecycleWaitMode::INTERRUPTIBLE, + [this]() { PrepareOutputDirectory(); }); + } catch (...) { + if (!output_directory_job->IsFinished()) { + output_directory_job->CompleteException(std::current_exception()); + } + throw; + } +} + +void CopyToFileGlobalState::EnsureOutputDirectoryReady() { + if (!output_directory_job) { + return; } - created_directories.insert(dir_path); + lifecycle_executor.WaitForJob(*output_directory_job, CopyFileLifecycleWaitMode::INTERRUPTIBLE); +} + +void CopyToFileGlobalState::EnsureDirectory(const string &dir_path) { + auto &fs = FileSystem::GetFileSystem(context); + directory_manager.EnsureDirectory(fs, dir_path); } -unique_ptr -CopyToFileGlobalState::CreateFileStateLocked(string output_path, optional_ptr> partition_values) { +PendingFileState CopyToFileGlobalState::PrepareFileStateLocked(string output_path, + optional_ptr> partition_values) { auto &fs = FileSystem::GetFileSystem(context); if (output_path.empty()) { output_path = op.filename_pattern.CreateFilename(fs, op.file_path, op.file_extension, last_file_offset++); } - created_files.push_back(output_path); + auto result = output_files.ReserveFile(std::move(output_path), partition_values); + output_files.PublishCreatedPath(result, result.output_path); + return result; +} - optional_ptr written_file_info; - if (op.return_type != CopyFunctionReturnType::CHANGED_ROWS) { - written_file_info = AddFile(output_path); - } - - auto data = op.function.copy_to_initialize_global(context, *op.bind_data, output_path); - if (written_file_info) { - op.function.copy_to_get_written_statistics(context, *op.bind_data, *data, *written_file_info->file_stats); - - if (!op.partition_columns.empty()) { - D_ASSERT(partition_values); - vector partition_keys; - vector partition_values_as_varchar; - for (idx_t i = 0; i < op.partition_columns.size(); i++) { - const auto &partition_col_name = op.names[op.partition_columns[i]]; - const auto &partition_value = (*partition_values)[i]; - partition_keys.emplace_back(partition_col_name); - partition_values_as_varchar.push_back(partition_value.DefaultCastAs(LogicalType::VARCHAR)); - } - written_file_info->partition_keys = - Value::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR, std::move(partition_keys), - std::move(partition_values_as_varchar)); - } - } +void CopyToFileGlobalState::RegisterPendingFileStatePathLocked(PendingFileState &pending_file_state, + string output_path) { + output_files.PublishCreatedPath(pending_file_state, std::move(output_path)); +} +unique_ptr CopyToFileGlobalState::InitializeFileState(PendingFileState pending_file_state) { + auto data = op.function.copy_to_initialize_global(context, *op.bind_data, pending_file_state.output_path); + if (pending_file_state.written_file_info && pending_file_state.written_file_info->file_stats) { + op.function.copy_to_get_written_statistics(context, *op.bind_data, *data, + *pending_file_state.written_file_info->file_stats); + } if (op.function.initialize_operator) { op.function.initialize_operator(*data, op); } - auto res = make_uniq(std::move(data), output_path); + return make_uniq(std::move(data), pending_file_state.output_path); +} + +void CopyToFileGlobalState::RegisterPrepareGlobalStateLocked(GlobalFileState &file_state) { if (!prepare_global_state.load(std::memory_order_acquire)) { - prepare_global_state.store(res, std::memory_order_release); + prepare_global_state.store(file_state, std::memory_order_release); } - return res; } -unique_ptr CopyToFileGlobalState::CreateFileState(string output_path, - optional_ptr> partition_values) { - annotated_lock_guard guard(lock); - return CreateFileStateLocked(std::move(output_path), partition_values); +PendingFileStateOpen +CopyToFileGlobalState::CreateFileStateOpenLocked(FileStateHandle &file_state, string output_path, + optional_ptr> partition_values) { + if (file_state.HasFileState()) { + return PendingFileStateOpen(); + } + PendingFileStateOpen result; + result.pending_file_state = PrepareFileStateLocked(std::move(output_path), partition_values); + result.open_job = make_shared_ptr(); + file_state.open_job = result.open_job; + return result; } -optional_ptr CopyToFileGlobalState::AddFile(const string &file_name) { - auto file_info = make_uniq(file_name); - optional_ptr result; - if (op.return_type == CopyFunctionReturnType::WRITTEN_FILE_STATISTICS) { - file_info->file_stats = make_uniq(); - result = file_info.get(); +PendingFileStateOpen +CopyToFileGlobalState::CreatePartitionFileStateOpenLocked(FileStateHandle &file_state, string output_path, + optional_ptr> partition_values) { + if (file_state.HasFileState()) { + return PendingFileStateOpen(); } - written_files.push_back(std::move(file_info)); + PendingFileStateOpen result; + result.pending_file_state = output_files.ReserveFile(std::move(output_path), partition_values); + result.open_job = make_shared_ptr(); + file_state.open_job = result.open_job; return result; } -unique_ptr CopyToFileGlobalState::FinalizeFileStateLocked(unique_ptr file_state) { +void CopyToFileGlobalState::ScheduleFileStateOpen(PendingFileStateOpen pending_file_state_open) { + D_ASSERT(pending_file_state_open); + auto open_job = pending_file_state_open.open_job; + try { + lifecycle_executor.Schedule( + open_job, CopyFileLifecycleWaitMode::INTERRUPTIBLE, + [this, open_job, pending_file_state = std::move(pending_file_state_open.pending_file_state)]() mutable { + EnsureOutputDirectoryReady(); + open_job->Complete(InitializeFileState(std::move(pending_file_state))); + }); + } catch (...) { + if (!open_job->IsFinished()) { + open_job->CompleteException(std::current_exception()); + } + throw; + } +} + +void CopyToFileGlobalState::SchedulePartitionFileStateOpen(PartitionFileOpenRequest request) { + auto open_job = request.open_job; + D_ASSERT(open_job); + try { + lifecycle_executor.Schedule(open_job, CopyFileLifecycleWaitMode::INTERRUPTIBLE, + [this, request = std::move(request)]() mutable { request.Run(*this); }); + } catch (...) { + if (!open_job->IsFinished()) { + open_job->CompleteException(std::current_exception()); + } + throw; + } +} + +void PartitionFileOpenRequest::Run(CopyToFileGlobalState ©_gstate) { + copy_gstate.EnsureOutputDirectoryReady(); + auto &fs = FileSystem::GetFileSystem(copy_gstate.context); + for (auto &dir : directory.directories) { + copy_gstate.EnsureDirectory(dir); + } + + auto output_path = std::move(pending_file_state.output_path); + if (copy_gstate.op.overwrite_mode == CopyOverwriteMode::COPY_APPEND) { + while (fs.FileExists(output_path)) { + if (!copy_gstate.op.filename_pattern.HasUUID()) { + throw InternalException("CopyOverwriteMode::COPY_APPEND without {uuid} - and file exists"); + } + output_path = copy_gstate.op.filename_pattern.CreateFilename(fs, directory.path, + copy_gstate.op.file_extension, offset); + } + } + + { + annotated_lock_guard guard(copy_gstate.lock); + copy_gstate.RegisterPendingFileStatePathLocked(pending_file_state, std::move(output_path)); + } + open_job->Complete(copy_gstate.InitializeFileState(std::move(pending_file_state))); +} + +void CopyToFileGlobalState::RequestFileState(FileStateHandle &file_state, string output_path, + optional_ptr> partition_values) { + PendingFileStateOpen pending_file_state_open; + { + annotated_lock_guard guard(lock); + pending_file_state_open = CreateFileStateOpenLocked(file_state, std::move(output_path), partition_values); + } + if (pending_file_state_open) { + ScheduleFileStateOpen(std::move(pending_file_state_open)); + } +} + +GlobalFileState & +CopyToFileGlobalState::EnsureFileStateReady(FileStateHandle &file_state, + const std::function &create_file_state_fun) { + while (true) { + shared_ptr open_job; + { + annotated_lock_guard guard(lock); + if (file_state.HasFileState()) { + open_job = file_state.open_job; + } + } + if (!open_job) { + create_file_state_fun(file_state); + continue; + } + lifecycle_executor.WaitForJob(*open_job, CopyFileLifecycleWaitMode::INTERRUPTIBLE); + { + annotated_lock_guard guard(lock); + if (file_state.open_job == open_job) { + auto &result = open_job->GetFileState(); + RegisterPrepareGlobalStateLocked(result); + return result; + } + } + } +} + +FileStateHandle CopyToFileGlobalState::FinalizeFileStateLocked(FileStateHandle file_state) { auto prepare_state = prepare_global_state.load(std::memory_order_acquire); - if (prepare_state && RefersToSameObject(*prepare_state.get(), *file_state)) { + if (prepare_state && RefersToSameObject(*prepare_state.get(), file_state.GetFileState())) { prepare_global_state_owned = std::move(file_state); - return nullptr; + return FileStateHandle(); } return file_state; } -void CopyToFileGlobalState::FinalizeFileState(unique_ptr file_state) { +void CopyToFileGlobalState::FinalizeFileState(FileStateHandle file_state) { + if (!file_state) { + return; + } + lifecycle_executor.WaitForJob(*file_state.open_job, CopyFileLifecycleWaitMode::DRAIN); { annotated_lock_guard guard(lock); file_state = FinalizeFileStateLocked(std::move(file_state)); } if (file_state) { - op.function.copy_to_finalize(context, *op.bind_data, *file_state->data); + auto finalize_job = make_shared_ptr(); + auto state = file_state.TakeFileState(); + auto state_holder = make_shared_ptr>(std::move(state)); + auto finalize = op.function.copy_to_finalize; + auto &context_ref = context; + auto &bind_data = *op.bind_data; + try { + lifecycle_executor.Schedule(finalize_job, CopyFileLifecycleWaitMode::DRAIN, + [finalize, &context_ref, &bind_data, state_holder]() mutable { + FinalizeLifecycleFileState(context_ref, finalize, bind_data, + std::move(*state_holder)); + }); + } catch (...) { + if (!finalize_job->IsFinished() && state_holder && *state_holder) { + try { + FinalizeLifecycleFileState(context_ref, finalize, bind_data, std::move(*state_holder)); + } catch (...) { + } + } + throw; + } } } -unique_ptr CopyToFileGlobalState::TryFinalizeOwnedFileStateLocked() { +FileStateHandle CopyToFileGlobalState::TryFinalizeOwnedFileStateLocked() { if (prepare_global_state_owned) { + prepare_global_state.store(nullptr, std::memory_order_release); return std::move(prepare_global_state_owned); } - return nullptr; + return FileStateHandle(); } void CopyToFileGlobalState::TryFinalizeOwnedFileState() { - unique_ptr file_state; + FileStateHandle file_state; { annotated_lock_guard guard(lock); file_state = TryFinalizeOwnedFileStateLocked(); } if (file_state) { - op.function.copy_to_finalize(context, *op.bind_data, *file_state->data); + FinalizeFileState(std::move(file_state)); } } +void CopyToFileGlobalState::WaitForLifecycleTasks() { + lifecycle_executor.WaitAll(CopyFileLifecycleWaitMode::DRAIN); +} + //===--------------------------------------------------------------------===// -// Copy Local State Implementation +// Copy Local State //===--------------------------------------------------------------------===// CopyToFileLocalState::CopyToFileLocalState(const PhysicalCopyToFile &op_p, ExecutionContext &context_p, CopyToFileGlobalState &gstate_p) @@ -2514,6 +3487,10 @@ PhysicalCopyToFile::PhysicalCopyToFile(PhysicalPlan &physical_plan, vector PhysicalCopyToFile::ParamsToString() const { InsertionOrderPreservingMap result; result["FORMAT"] = StringUtil::Upper(function.name.GetIdentifierName()); @@ -2598,32 +3575,10 @@ static bool PhysicalCopyRotateNow(const PhysicalCopyToFile &op, GlobalFileState //===--------------------------------------------------------------------===// unique_ptr PhysicalCopyToFile::GetGlobalSinkState(ClientContext &context) const { if (partition_output || per_thread_output || Rotate()) { - auto &fs = FileSystem::GetFileSystem(context); - if (!fs.IsRemoteFile(file_path)) { - if (fs.FileExists(file_path)) { - // the target file exists AND is a file (not a directory) - // for local files we can remove the file if OVERWRITE_OR_IGNORE is enabled - if (overwrite_mode == CopyOverwriteMode::COPY_OVERWRITE) { - fs.RemoveFile(file_path); - } else { - throw IOException("Cannot write to \"%s\" - it exists and is a file, not a directory! Enable " - "OVERWRITE option to overwrite the file", - file_path); - } - } - } - - // what if the target exists and is a directory - if (!fs.DirectoryExists(file_path)) { - fs.CreateDirectory(file_path); - } else { - CheckDirectory(fs, file_path, overwrite_mode); - } - auto state = make_uniq(*this, context); + state->ScheduleOutputDirectorySetup(); if (!partition_output && !per_thread_output && Rotate() && write_empty_file) { - annotated_lock_guard guard(state->lock); - state->global_state = state->CreateFileStateLocked(); + state->RequestFileState(state->global_state); } if (partition_output) { @@ -2661,6 +3616,15 @@ SinkResultType PhysicalCopyToFile::Sink(ExecutionContext &context, DataChunk &ch return SinkResultType::NEED_MORE_INPUT; } + if (per_thread_output) { + if (!lstate.global_file_state) { + gstate.create_file_state_fun(lstate.global_file_state); + } + } else { + gstate.create_file_state_fun(gstate.global_state); + } + auto &file_state = per_thread_output ? lstate.global_file_state : gstate.global_state; + if (!lstate.batch) { lstate.batch = make_uniq(context.client, expected_types); lstate.batch->InitializeAppend(lstate.batch_append_state); @@ -2670,9 +3634,7 @@ SinkResultType PhysicalCopyToFile::Sink(ExecutionContext &context, DataChunk &ch const CopyFunctionBatchAnalyzer batch_analyzer(*lstate.batch, batch_size, batch_size_bytes); if (batch_analyzer.MeetsFlushCriteria()) { lstate.batch_append_state.current_chunk_state.handles.clear(); - auto &file_state_ptr = per_thread_output ? lstate.global_file_state : gstate.global_state; - PrepareAndFlushBatch(context.client, gstate, file_state_ptr, gstate.create_file_state_fun, - std::move(lstate.batch)); + PrepareAndFlushBatch(context.client, gstate, file_state, gstate.create_file_state_fun, std::move(lstate.batch)); } return SinkResultType::NEED_MORE_INPUT; @@ -2762,39 +3724,39 @@ SinkFinalizeType PhysicalCopyToFile::Finalize(Pipeline &pipeline, Event &event, // already happened in combine if (NumericCast(gstate.rows_copied.load()) == 0 && sink_state != nullptr) { // no rows from source, write schema to file - annotated_lock_guard guard(gstate.lock); - gstate.global_state = gstate.CreateFileStateLocked(); + gstate.RequestFileState(gstate.global_state); } } if (gstate.global_state) { gstate.FinalizeFileState(std::move(gstate.global_state)); - - if (use_tmp_file) { - D_ASSERT(!per_thread_output); - D_ASSERT(!partition_output); - D_ASSERT(!file_size_bytes.IsValid()); - D_ASSERT(!Rotate()); - MoveTmpFile(context, file_path); - } } gstate.TryFinalizeOwnedFileState(); + gstate.WaitForLifecycleTasks(); + + if (use_tmp_file) { + D_ASSERT(!per_thread_output); + D_ASSERT(!partition_output); + D_ASSERT(!file_size_bytes.IsValid()); + D_ASSERT(!Rotate()); + MoveTmpFile(context, file_path); + } return SinkFinalizeType::READY; } void PhysicalCopyToFile::PrepareAndFlushBatch(ClientContext &context, GlobalSinkState &gstate_p, - unique_ptr &file_state_ptr, - const std::function()> &create_file_state_fun, + FileStateHandle &file_state, + const std::function &create_file_state_fun, unique_ptr batch) const { auto [batch_analyzer, prepared_batch] = - PrepareBatch(context, gstate_p, file_state_ptr, create_file_state_fun, std::move(batch)); - FlushBatch(context, gstate_p, file_state_ptr, create_file_state_fun, batch_analyzer, std::move(prepared_batch)); + PrepareBatch(context, gstate_p, file_state, create_file_state_fun, std::move(batch)); + FlushBatch(context, gstate_p, file_state, create_file_state_fun, batch_analyzer, std::move(prepared_batch)); } //===--------------------------------------------------------------------===// -// Legacy +// Legacy Batch API //===--------------------------------------------------------------------===// struct LegacyCopyPreparedBatch : public PreparedBatchData { explicit LegacyCopyPreparedBatch(unique_ptr collection_p) @@ -2835,58 +3797,11 @@ static void FlushLegacyCopyBatch(ClientContext &context, const CopyFunction &fun } //===--------------------------------------------------------------------===// -// Prepare/Flush Batch +// Batch Interface //===--------------------------------------------------------------------===// -static void EnsureFileState(CopyToFileGlobalState &gstate, unique_ptr &file_state_ptr, - const std::function()> &create_file_state_fun) { - auto file_state_key = &file_state_ptr; - while (true) { - bool create_file_state = false; - { - annotated_lock_guard guard(gstate.lock); - if (file_state_ptr) { - return; - } - if (gstate.creating_file_states.insert(file_state_key).second) { - create_file_state = true; - } - } - - if (!create_file_state) { - TaskScheduler::YieldThread(); - continue; - } - - unique_ptr new_file_state; - try { - new_file_state = create_file_state_fun(); - } catch (...) { - annotated_lock_guard guard(gstate.lock); - gstate.creating_file_states.erase(file_state_key); - throw; - } - - unique_ptr unused_file_state; - { - annotated_lock_guard guard(gstate.lock); - if (!file_state_ptr) { - file_state_ptr = std::move(new_file_state); - } else { - unused_file_state = std::move(new_file_state); - } - gstate.creating_file_states.erase(file_state_key); - } - if (unused_file_state) { - gstate.FinalizeFileState(std::move(unused_file_state)); - } - return; - } -} - pair> -PhysicalCopyToFile::PrepareBatch(ClientContext &context, GlobalSinkState &gstate_p, - unique_ptr &file_state_ptr, - const std::function()> &create_file_state_fun, +PhysicalCopyToFile::PrepareBatch(ClientContext &context, GlobalSinkState &gstate_p, FileStateHandle &file_state, + const std::function &create_file_state_fun, unique_ptr batch) const { auto &gstate = gstate_p.Cast(); const CopyFunctionBatchAnalyzer batch_analyzer(*batch, batch_size, batch_size_bytes); @@ -2898,8 +3813,7 @@ PhysicalCopyToFile::PrepareBatch(ClientContext &context, GlobalSinkState &gstate // Ensure we have a global state for prepares auto prepare_global_state = gstate.prepare_global_state.load(std::memory_order_acquire); if (!prepare_global_state) { - D_ASSERT(!file_state_ptr); - EnsureFileState(gstate, file_state_ptr, create_file_state_fun); + gstate.EnsureFileStateReady(file_state, create_file_state_fun); prepare_global_state = gstate.prepare_global_state.load(std::memory_order_acquire); D_ASSERT(prepare_global_state); } @@ -2908,48 +3822,54 @@ PhysicalCopyToFile::PrepareBatch(ClientContext &context, GlobalSinkState &gstate return {batch_analyzer, function.prepare_batch(context, *bind_data, *prepare_global_state->data, std::move(batch))}; } -void PhysicalCopyToFile::FlushBatch(ClientContext &context, GlobalSinkState &gstate_p, - unique_ptr &file_state_ptr, - const std::function()> &create_file_state_fun, +void PhysicalCopyToFile::FlushBatch(ClientContext &context, GlobalSinkState &gstate_p, FileStateHandle &file_state, + const std::function &create_file_state_fun, const CopyFunctionBatchAnalyzer &batch_analyzer, unique_ptr prepared_batch) const { auto &gstate = gstate_p.Cast(); while (true) { - EnsureFileState(gstate, file_state_ptr, create_file_state_fun); + gstate.EnsureFileStateReady(file_state, create_file_state_fun); // Decide which file to flush to annotated_unique_lock global_guard(gstate.lock); - if (!file_state_ptr) { + // Another thread may have rotated the file state since EnsureFileStateReady, so re-check readiness + if (!file_state.IsReady()) { global_guard.unlock(); - TaskScheduler::YieldThread(); + gstate.lifecycle_executor.WorkOnTaskOrYield(); continue; } - annotated_unique_lock file_guard(file_state_ptr->lock); - if (PhysicalCopyRotateNow(*this, *file_state_ptr)) { + auto ready_file_state = file_state.GetFileStatePtr(); + if (!ready_file_state) { + global_guard.unlock(); + gstate.lifecycle_executor.WorkOnTaskOrYield(); + continue; + } + annotated_unique_lock file_guard(ready_file_state->lock); + if (PhysicalCopyRotateNow(*this, *ready_file_state)) { // Global state must be rotated. Move to local scope, create an new one, and immediately release global lock - auto owned_file_state = std::move(file_state_ptr); + auto owned_file_state = std::move(file_state); file_guard.unlock(); global_guard.unlock(); - EnsureFileState(gstate, file_state_ptr, create_file_state_fun); + gstate.EnsureFileStateReady(file_state, create_file_state_fun); // Finalize this file! gstate.FinalizeFileState(std::move(owned_file_state)); } else { global_guard.unlock(); - file_state_ptr->num_batches++; + ready_file_state->num_batches++; DUCKDB_LOG(context, PhysicalOperatorLogType, *this, "PhysicalCopyToFile", "FlushBatch", - {{"file", file_state_ptr->path}, + {{"file", ready_file_state->path}, {"rows", to_string(batch_analyzer.current_batch_size)}, {"size", to_string(batch_analyzer.current_batch_size_bytes)}, {"reason", EnumUtil::ToString(batch_analyzer.ToReason())}}); if (UsesLegacyCopyBatchAPI(function)) { - FlushLegacyCopyBatch(context, function, *bind_data, *file_state_ptr->data, *prepared_batch); + FlushLegacyCopyBatch(context, function, *bind_data, *ready_file_state->data, *prepared_batch); } else { - function.flush_batch(context, *bind_data, *file_state_ptr->data, *prepared_batch); + function.flush_batch(context, *bind_data, *ready_file_state->data, *prepared_batch); } break; } @@ -2978,21 +3898,23 @@ unique_ptr PhysicalCopyToFile::GetGlobalSourceState(ClientCon SourceResultType PhysicalCopyToFile::GetDataInternal(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { auto &gstate = sink_state->Cast(); + gstate.WaitForLifecycleTasks(); annotated_lock_guard global_guard(gstate.lock); if (return_type == CopyFunctionReturnType::WRITTEN_FILE_STATISTICS) { auto &source_state = input.global_state.Cast(); - idx_t next_end = MinValue(source_state.offset + STANDARD_VECTOR_SIZE, gstate.written_files.size()); + idx_t next_end = + MinValue(source_state.offset + STANDARD_VECTOR_SIZE, gstate.output_files.WrittenFileCount()); idx_t count = next_end - source_state.offset; for (idx_t i = 0; i < count; i++) { - auto &file_entry = *gstate.written_files[source_state.offset + i]; + auto &file_entry = gstate.output_files.GetWrittenFile(source_state.offset + i); if (use_tmp_file) { file_entry.file_path = GetNonTmpFile(context.client, file_entry.file_path); } ReturnStatistics(chunk, file_entry); } source_state.offset += count; - return source_state.offset < gstate.written_files.size() ? SourceResultType::HAVE_MORE_OUTPUT - : SourceResultType::FINISHED; + return source_state.offset < gstate.output_files.WrittenFileCount() ? SourceResultType::HAVE_MORE_OUTPUT + : SourceResultType::FINISHED; } switch (return_type) { @@ -3002,7 +3924,7 @@ SourceResultType PhysicalCopyToFile::GetDataInternal(ExecutionContext &context, case CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST: { chunk.data[0].Append(Value::BIGINT(NumericCast(gstate.rows_copied.load()))); vector file_name_list; - for (auto &file_info : gstate.written_files) { + for (auto &file_info : gstate.output_files.GetWrittenFiles()) { if (use_tmp_file) { file_name_list.emplace_back(GetNonTmpFile(context.client, file_info->file_path)); } else { diff --git a/src/duckdb/src/execution/physical_operator.cpp b/src/duckdb/src/execution/physical_operator.cpp index e8c5066df..f725ae0cd 100644 --- a/src/duckdb/src/execution/physical_operator.cpp +++ b/src/duckdb/src/execution/physical_operator.cpp @@ -1,4 +1,5 @@ #include "duckdb/execution/physical_operator.hpp" +#include "duckdb/common/vector/dictionary_vector.hpp" #include "duckdb/function/table_function.hpp" #include "duckdb/common/printer.hpp" @@ -340,10 +341,8 @@ enum class CachingPhysicalOperatorExecuteMode : uint8_t { RETURN_CACHED }; -static CachingPhysicalOperatorExecuteMode SelectExecutionMode(const DataChunk &chunk, - const OperatorResultType child_result, - CachingOperatorState &state, - ClientContext &client_context) { +static CachingPhysicalOperatorExecuteMode +SelectExecutionMode(const DataChunk &chunk, const OperatorResultType child_result, CachingOperatorState &state) { if (state.can_cache_chunk == OperatorCachingMode::NONE) { return CachingPhysicalOperatorExecuteMode::RETURN_CHUNK; } @@ -378,12 +377,7 @@ static CachingPhysicalOperatorExecuteMode SelectExecutionMode(const DataChunk &c return CachingPhysicalOperatorExecuteMode::RETURN_CHUNK; } else if (chunk.size() <= CachingPhysicalOperator::CACHE_THRESHOLD && !needs_continuation_chunk) { // We have filtered out a significant amount of tuples - - if (!state.cached_chunk) { - // Initialize cached_chunk - state.cached_chunk = make_uniq(); - state.cached_chunk->Initialize(Allocator::Get(client_context), chunk.GetTypes()); - } + // The cache is materialised lazily by AppendToCache on first use if (has_space_for_chunk_in_cache) { // We can just append, do and return empty chunk @@ -422,6 +416,110 @@ static CachingPhysicalOperatorExecuteMode SelectExecutionMode(const DataChunk &c return CachingPhysicalOperatorExecuteMode::RETURN_CHUNK; } +static bool ChunkHasGlobalDictionary(const DataChunk &chunk) { + for (idx_t col_idx = 0; col_idx < chunk.ColumnCount(); col_idx++) { + if (DictionaryVector::IsGlobalDictionary(chunk.data[col_idx])) { + return true; + } + } + return false; +} + +//! Switch the empty cache to dictionary mode: pin each global dictionary column's upstream entry and allocate its +//! sel accumulator. Columns keep a real (resettable) cache so a flushed dict column flattens on the next Reset. +static void SeedDictCache(CachingOperatorState &state, DataChunk &source) { + const idx_t col_count = source.ColumnCount(); + state.dict_columns.clear(); + state.dict_columns.resize(col_count); + for (idx_t col_idx = 0; col_idx < col_count; col_idx++) { + if (!DictionaryVector::IsGlobalDictionary(source.data[col_idx])) { + continue; + } + auto &slot = state.dict_columns[col_idx]; + slot.entry = source.data[col_idx].BufferMutable().Cast().GetEntryPtr(); + slot.accumulated_sel.Initialize(STANDARD_VECTOR_SIZE); + } + state.dict_cache_active = true; +} + +//! Append source into the cache (created lazily). On the first append into an empty cache, detect +//! global dictionary columns; those concatenate their selection indices instead of flattening. +static void AppendToCache(CachingOperatorState &state, DataChunk &source, ClientContext &client_context) { + if (!state.cached_chunk) { + state.cached_chunk = make_uniq(); + state.cached_chunk->Initialize(Allocator::Get(client_context), source.GetTypes()); + } + auto &cache = *state.cached_chunk; + if (cache.size() == 0 && !state.dict_cache_active && ChunkHasGlobalDictionary(source)) { + SeedDictCache(state, source); + } + if (!state.dict_cache_active) { + // no dict columns: plain flat append + cache.Append(source); + return; + } + const idx_t base = cache.size(); + const idx_t added = source.size(); + // accumulated_sel is sized STANDARD_VECTOR_SIZE and the caching state machine guarantees base + added stays + // within it. Index accumulation has no overrun guard of its own (unlike the flat Append), so assert it. + D_ASSERT(base + added <= STANDARD_VECTOR_SIZE); + for (idx_t col_idx = 0; col_idx < cache.ColumnCount(); col_idx++) { + auto &slot = state.dict_columns[col_idx]; + if (slot.entry) { + // dict column: every later chunk must be the same global dictionary. Throw (not D_ASSERT) + // because the Cast below is UB on a non-dict vector in release, accumulating foreign bytes as indices. + auto &source_col = source.data[col_idx]; + if (source_col.GetVectorType() != VectorType::DICTIONARY_VECTOR || + DictionaryVector::DictionaryId(source_col).empty() || + !DictionaryVector::IsGlobalDictionary(source_col)) { + throw InternalException("dict-surviving cache: column %llu received a non-global-dictionary " + "chunk after being seeded for dictionary caching", + static_cast(col_idx)); + } + // An id mismatch past the encoding check is a producer bug (never user-triggerable), so assert. + D_ASSERT(source_col.Buffer().Cast().GetEntry().id == slot.entry->id); + const auto &source_sel = DictionaryVector::SelVector(source_col); + for (idx_t row = 0; row < added; row++) { + slot.accumulated_sel.set_index(base + row, source_sel.get_index(row)); + } + } else { + // flat column: append per column. The D_ASSERT catches a refactor that routes a dict placeholder here. + D_ASSERT(!slot.entry); + FlatVector::SetSize(cache.data[col_idx], base); + cache.data[col_idx].Append(source.data[col_idx], added, VectorAppendMode::ERROR_ON_NO_SPACE); + } + } + // dict columns are rewrapped on flush, flat columns already sized; this only sets the cardinality + cache.SetChildCardinality(base + added); +} + +//! After moving the cache into chunk, re-wrap each dict column as a DICTIONARY_VECTOR over the +//! pinned upstream entry, carrying its id and global_dictionary flag through unchanged. +static void RewrapDictColumns(CachingOperatorState &state, DataChunk &chunk, idx_t count) { + for (idx_t col_idx = 0; col_idx < chunk.ColumnCount(); col_idx++) { + auto &slot = state.dict_columns[col_idx]; + if (!slot.entry) { + continue; + } + chunk.data[col_idx].Dictionary(slot.entry, slot.accumulated_sel, count); + } +} + +//! Move the cache into chunk (reconstructing dict columns) and re-initialize an empty cache for the next +//! batch. With no dict columns this is the plain flat flush. +static void FlushCacheToChunk(CachingOperatorState &state, DataChunk &chunk, ClientContext &client_context) { + if (!state.dict_cache_active) { + chunk.Move(*state.cached_chunk); + state.cached_chunk->Initialize(Allocator::Get(client_context), chunk.GetTypes()); + return; + } + const idx_t count = state.cached_chunk->size(); + chunk.Move(*state.cached_chunk); + RewrapDictColumns(state, chunk, count); + state.cached_chunk->Initialize(Allocator::Get(client_context), chunk.GetTypes()); + state.ResetDictCache(); +} + OperatorResultType CachingPhysicalOperator::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, GlobalOperatorState &gstate, OperatorState &state_p) const { auto &state = state_p.Cast(); @@ -452,33 +550,32 @@ OperatorResultType CachingPhysicalOperator::Execute(ExecutionContext &context, D } } - const auto execution_mode = SelectExecutionMode(chunk, child_result, state, context.client); + const auto execution_mode = SelectExecutionMode(chunk, child_result, state); + // Appends and flushes MUST route through AppendToCache / FlushCacheToChunk: a raw Append flattens a zero-width + // dict placeholder and a raw flush drops the dict. (The continuation case below Moves a fresh chunk in -- a + // full replacement, not an append -- so raw dict columns pass through verbatim.) switch (execution_mode) { case CachingPhysicalOperatorExecuteMode::RETURN_CACHED_APPEND_CHUNK: { auto tmp = make_uniq(); tmp->Move(chunk); - chunk.Move(*state.cached_chunk); - state.cached_chunk->Initialize(Allocator::Get(context.client), chunk.GetTypes()); - state.cached_chunk->Append(*tmp); + FlushCacheToChunk(state, chunk, context.client); + AppendToCache(state, *tmp, context.client); break; } case CachingPhysicalOperatorExecuteMode::RETURN_CACHED_PLUS_CHUNK: - state.cached_chunk->Append(chunk); - chunk.Move(*state.cached_chunk); - state.cached_chunk->Initialize(Allocator::Get(context.client), chunk.GetTypes()); + AppendToCache(state, chunk, context.client); + FlushCacheToChunk(state, chunk, context.client); break; case CachingPhysicalOperatorExecuteMode::RETURN_CACHED: D_ASSERT(chunk.size() == 0); - chunk.Move(*state.cached_chunk); - state.cached_chunk->Initialize(Allocator::Get(context.client), chunk.GetTypes()); + FlushCacheToChunk(state, chunk, context.client); break; case CachingPhysicalOperatorExecuteMode::RETURN_CACHED_THEN_CHUNK_VIA_CONTINUATION: { // Swap chunk and *state.cached_chunk auto tmp = make_uniq(); tmp->Move(chunk); - chunk.Move(*state.cached_chunk); - state.cached_chunk->Initialize(Allocator::Get(context.client), chunk.GetTypes()); + FlushCacheToChunk(state, chunk, context.client); state.cached_chunk->Move(*tmp); // Now chunk holds what was in (*state.cached_chunk), and it's returned directly @@ -488,7 +585,7 @@ OperatorResultType CachingPhysicalOperator::Execute(ExecutionContext &context, D return OperatorResultType::HAVE_MORE_OUTPUT; } case CachingPhysicalOperatorExecuteMode::APPEND_CHUNK: { - state.cached_chunk->Append(chunk); + AppendToCache(state, chunk, context.client); chunk.Reset(); break; } @@ -496,6 +593,17 @@ OperatorResultType CachingPhysicalOperator::Execute(ExecutionContext &context, D break; } + // A flushed/reset dict column can leave the reused output chunk holding a DICTIONARY_VECTOR over a null cache + // slot that Reset cannot flatten; on an empty result that desyncs the chunk, so flatten stale columns to flat. + if (chunk.size() == 0) { + for (auto &vector : chunk.data) { + const auto vector_type = vector.GetVectorType(); + if (vector_type != VectorType::FLAT_VECTOR && vector_type != VectorType::CONSTANT_VECTOR) { + vector.Initialize(); + } + } + } + return child_result; } @@ -504,7 +612,13 @@ OperatorFinalizeResultType CachingPhysicalOperator::FinalExecute(ExecutionContex OperatorState &state_p) const { auto &state = state_p.Cast(); if (state.cached_chunk) { + const idx_t count = state.cached_chunk->size(); + const bool dict_cache_active = state.dict_cache_active; chunk.Move(*state.cached_chunk); + if (dict_cache_active) { + RewrapDictColumns(state, chunk, count); + state.ResetDictCache(); + } state.cached_chunk.reset(); } return OperatorFinalizeResultType::FINISHED; diff --git a/src/duckdb/src/execution/physical_plan/plan_window.cpp b/src/duckdb/src/execution/physical_plan/plan_window.cpp index a85dbcba8..08ef06297 100644 --- a/src/duckdb/src/execution/physical_plan/plan_window.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_window.cpp @@ -112,15 +112,17 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalWindow &op) { continue; } - // CSE Elimination: Search for a previous match + // CSE Elimination: Search for a previous match (volatile expressions must not be deduplicated) bool cse = false; - for (idx_t i = 0; i < matching.size(); ++i) { - const auto match_idx = matching[i]; - auto &match_expr = op.expressions[match_idx]->Cast(); - if (wexpr.Equals(match_expr)) { - projection_map[input_width + expr_idx] = output_pos + i; - cse = true; - break; + if (!wexpr.IsVolatile()) { + for (idx_t i = 0; i < matching.size(); ++i) { + const auto match_idx = matching[i]; + auto &match_expr = op.expressions[match_idx]->Cast(); + if (wexpr.Equals(match_expr)) { + projection_map[input_width + expr_idx] = output_pos + i; + cse = true; + break; + } } } if (cse) { diff --git a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp index e02bfc322..11419bed3 100644 --- a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp +++ b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp @@ -6,10 +6,13 @@ #include "duckdb/common/types/list_segment.hpp" #include "duckdb/function/aggregate/list_aggregate.hpp" #include "duckdb/function/aggregate_function.hpp" +#include "duckdb/function/aggregate_state_layout.hpp" #include "duckdb/function/function_binder.hpp" +#include "duckdb/function/scalar/generic_common.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" #include "duckdb/storage/buffer_manager.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/parser/expression_map.hpp" #include "duckdb/parallel/thread_context.hpp" @@ -19,6 +22,42 @@ namespace duckdb { namespace { + +//! Maps each ORDER BY key to a buffered column, reusing an argument column when the key matches one. Returns the +//! buffered column types and fills `orders` with the (column, modifiers) per key. +vector MapSortedColumns(const vector> &children, + const vector &order_bys, + vector &orders) { + vector column_types; + for (const auto &child : children) { + column_types.emplace_back(child->GetReturnType()); + } + for (const auto &order : order_bys) { + idx_t column = DConstants::INVALID_INDEX; + for (idx_t arg = 0; arg < children.size(); ++arg) { + if (children[arg]->Equals(*order.expression)) { + column = arg; + break; + } + } + if (column == DConstants::INVALID_INDEX) { + column = column_types.size(); + column_types.emplace_back(order.expression->GetReturnType()); + } + orders.push_back(SortedAggregateStateOrder {column, order.type, order.null_order}); + } + return column_types; +} + +//! The struct type of a buffered row: the buffered columns named v0, v1, ... +LogicalType BufferStructType(const vector &column_types) { + child_list_t children; + for (idx_t i = 0; i < column_types.size(); i++) { + children.emplace_back("v" + to_string(i), column_types[i]); + } + return LogicalType::STRUCT(std::move(children)); +} + struct SortedAggregateBindData : public FunctionData { using Expressions = vector>; using BindInfoPtr = unique_ptr; @@ -28,58 +67,23 @@ struct SortedAggregateBindData : public FunctionData { BindInfoPtr &bind_info, OrderBys &order_bys) : context(context), function(aggregate), bind_info(std::move(bind_info)), threshold(Settings::Get(context)) { - // Describe the arguments. - for (const auto &child : children) { - buffered_cols.emplace_back(buffered_cols.size()); - buffered_types.emplace_back(child->GetReturnType()); - - // Column 0 in the sort data is the group number - scan_cols.emplace_back(buffered_cols.size()); + vector order_spec; + buffered_types = MapSortedColumns(children, order_bys, order_spec); + const idx_t argument_count = children.size(); + // the arguments are the leading buffered columns (referencing their input slot); appended sort keys follow + buffered_cols.resize(buffered_types.size()); + for (idx_t i = 0; i < argument_count; i++) { + buffered_cols[i] = i; + scan_cols.emplace_back(i + 1); + scan_types.emplace_back(buffered_types[i]); } - scan_types = buffered_types; - - // The first sort column is the group number. It is prefixed onto the buffered data - sort_types.emplace_back(LogicalType::USMALLINT); - orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, - make_uniq(sort_types.back(), 0U)); - - // Determine whether we are sorted on all the arguments. - // Even if we are not, we want to share inputs for sorting. - for (idx_t ord_idx = 0; ord_idx < order_bys.size(); ++ord_idx) { - auto order = order_bys[ord_idx].Copy(); - bool matched = false; - const auto &type = order.expression->GetReturnType(); - - for (idx_t arg_idx = 0; arg_idx < children.size(); ++arg_idx) { - auto &child = children[arg_idx]; - if (child->Equals(*order.expression)) { - order.expression = make_uniq(type, arg_idx + 1); - matched = true; - break; - } - } - - if (!matched) { - sorted_on_args = false; - buffered_cols.emplace_back(children.size() + ord_idx); - buffered_types.emplace_back(type); - order.expression = make_uniq(type, buffered_cols.size()); + for (idx_t o = 0; o < order_spec.size(); o++) { + if (order_spec[o].column >= argument_count) { + buffered_cols[order_spec[o].column] = argument_count + o; } - - orders.emplace_back(std::move(order)); - } - - // The buffered rows are stored in a linked list of structs - child_list_t buffered_children; - for (idx_t i = 0; i < buffered_types.size(); i++) { - buffered_children.emplace_back("v" + to_string(i), buffered_types[i]); - sort_types.emplace_back(buffered_types[i]); } - buffered_struct_type = LogicalType::STRUCT(std::move(buffered_children)); - GetSegmentDataFunctions(buffered_funcs, buffered_struct_type); - - // Only scan the argument columns after sorting - sort = make_uniq(context, orders, sort_types, scan_cols); + sorted_on_args = (buffered_types.size() == argument_count); + BuildSort(order_spec); } SortedAggregateBindData(ClientContext &context, BoundAggregateExpression &expr) @@ -92,6 +96,44 @@ struct SortedAggregateBindData : public FunctionData { expr.ArgOrdersMutable()) { } + //! Reconstruct from an exported buffer state - the buffer struct, the per-key (column, modifiers) and the leading + //! argument count fully describe the layout, no original expressions are needed. + SortedAggregateBindData(ClientContext &context, const BoundAggregateFunction &inner_function, + unique_ptr inner_bind_info, const LogicalType &buffer_struct, + const vector &order_spec, idx_t argument_count) + : context(context), function(inner_function), bind_info(std::move(inner_bind_info)), + threshold(Settings::Get(context)) { + for (auto &child : StructType::GetChildTypes(buffer_struct)) { + buffered_cols.emplace_back(buffered_cols.size()); + buffered_types.emplace_back(child.second); + } + for (idx_t i = 0; i < argument_count; i++) { + scan_cols.emplace_back(i + 1); + scan_types.emplace_back(buffered_types[i]); + } + sorted_on_args = (argument_count == buffered_types.size()); + BuildSort(order_spec); + } + + //! Builds the sort once the buffered columns and per-key (column, modifiers) are known: prefixes the group number, + //! lays out the buffered struct and creates the sort. Sort keys reference buffered columns, offset by the prefix. + void BuildSort(const vector &order_spec) { + sort_types.emplace_back(LogicalType::USMALLINT); + orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, + make_uniq(LogicalType::USMALLINT, 0U)); + for (const auto &buffered_type : buffered_types) { + sort_types.emplace_back(buffered_type); + } + buffered_struct_type = BufferStructType(buffered_types); + GetSegmentDataFunctions(buffered_funcs, buffered_struct_type); + for (const auto &entry : order_spec) { + orders.emplace_back(entry.order_type, entry.null_order, + make_uniq(buffered_types[entry.column], + UnsafeNumericCast(entry.column + 1))); + } + sort = make_uniq(context, orders, sort_types, scan_cols); + } + SortedAggregateBindData(const SortedAggregateBindData &other) : context(other.context), function(other.function), sort_types(other.sort_types), scan_cols(other.scan_cols), scan_types(other.scan_types), buffered_cols(other.buffered_cols), buffered_types(other.buffered_types), @@ -471,8 +513,56 @@ struct SortedAggregateFunction { } }; +//! The exported state of a sorted aggregate is its buffer of values: a LIST. +AggregateStateLayout SortedAggregateGetStateType(AggregateLayoutInput &input) { + auto &bind_data = input.bind_data->Cast(); + AggregateStateLayout layout; + layout.type = LogicalType::LIST(bind_data.buffered_struct_type); + layout.total_state_size = AlignValue(sizeof(SortedAggregateState)); + layout.field = BuildStateField>(); + AggregateStateField::PopulateListFunctions(layout.type, layout.field); + return layout; +} + +//! Builds the sorted aggregate wrapper AggregateFunction (shared by the forward export path and the re-bind path). +AggregateFunction CreateSortedAggregateWrapper(const Identifier &name, const vector &arguments, + const LogicalType &return_type, FunctionNullHandling null_handling) { + AggregateFunction ordered_aggregate( + name, arguments, return_type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, SortedAggregateFunction::ScatterUpdate, + ListCombineFunction, SortedAggregateFunction::Finalize, null_handling, nullptr, + nullptr, nullptr, nullptr, SortedAggregateFunction::WindowBatch); + ordered_aggregate.SetInitLocalStateFinalizeCallback(SortedAggregateFinalizeState::Init); + ordered_aggregate.SetStructStateExport(SortedAggregateGetStateType); + return ordered_aggregate; +} + } // namespace +void FunctionBinder::GetSortedAggregateStateLayout(const BoundAggregateExpression &expr, LogicalType &buffer_struct, + vector &orders, idx_t &argument_count) { + D_ASSERT(expr.GetOrderBys()); + argument_count = expr.GetChildren().size(); + buffer_struct = BufferStructType(MapSortedColumns(expr.GetChildren(), expr.GetOrderBys()->orders, orders)); +} + +pair> +FunctionBinder::BindSortedAggregateState(ClientContext &context, const BoundAggregateFunction &inner_function, + unique_ptr inner_bind_info, const LogicalType &buffer_struct, + const vector &orders, idx_t argument_count) { + const auto null_handling = inner_function.GetProperties().GetNullHandling(); + auto bind_data = make_uniq(context, inner_function, std::move(inner_bind_info), + buffer_struct, orders, argument_count); + // the wrapper consumes the buffered columns (the struct fields) + vector arguments; + for (const auto &child : StructType::GetChildTypes(buffer_struct)) { + arguments.emplace_back(child.second); + } + auto result_function = CreateSortedAggregateWrapper(inner_function.GetName(), arguments, + inner_function.GetReturnType(), null_handling); + return make_pair(std::move(result_function), std::move(bind_data)); +} + void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateExpression &expr, const vector> &groups, optional_ptr> grouping_sets) { @@ -480,8 +570,9 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateE // not a sorted aggregate: return return; } - // Remove unnecessary ORDER BY clauses and return if nothing remains - if (Settings::Get(context)) { + // the exported state buffers the values, so the ORDER BY must be preserved verbatim - skip the simplification + const bool state_export = expr.StateExportMode() == AggregateStateExportMode::STATE_EXPORT; + if (!state_export && Settings::Get(context)) { if (expr.GetOrderBysMutable()->Simplify(groups, grouping_sets)) { expr.GetOrderBysMutable().reset(); return; @@ -490,6 +581,29 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateE auto &bound_function = expr.Function(); auto &children = expr.GetChildrenMutable(); auto &order_bys = *expr.GetOrderBysMutable(); + + if (state_export) { + // the statistics optimizer may have narrowed the buffered column types since the exported type was fixed at + // bind time (e.g. a small-range group column INTEGER->TINYINT) - cast them back so the buffer matches its + // declared type. The widening preserves sort order; casting reused columns to the same type keeps matching. + LogicalType plan_struct; + vector order_columns; + idx_t argument_count; + GetSortedAggregateStateLayout(expr, plan_struct, order_columns, argument_count); + auto &logical_fields = StructType::GetChildTypes(ListType::GetChildType(expr.GetReturnType())); + auto cast_to = [&](unique_ptr &e, const LogicalType &type) { + if (e->GetReturnType() != type) { + e = BoundCastExpression::AddCastToType(context, std::move(e), type); + } + }; + for (idx_t i = 0; i < children.size() && i < logical_fields.size(); i++) { + cast_to(children[i], logical_fields[i].second); + } + for (idx_t o = 0; o < order_bys.orders.size() && o < order_columns.size(); o++) { + cast_to(order_bys.orders[o].expression, logical_fields[order_columns[o].column].second); + } + } + auto sorted_bind = make_uniq(context, expr); if (!sorted_bind->sorted_on_args) { @@ -506,18 +620,18 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateE } // Replace the aggregate with the wrapper - AggregateFunction ordered_aggregate(bound_function.GetName(), arguments, bound_function.GetReturnType(), - AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - SortedAggregateFunction::ScatterUpdate, - ListCombineFunction, SortedAggregateFunction::Finalize, - bound_function.GetProperties().GetNullHandling(), nullptr, nullptr, nullptr, - nullptr, SortedAggregateFunction::WindowBatch); - ordered_aggregate.SetInitLocalStateFinalizeCallback(SortedAggregateFinalizeState::Init); + auto ordered_aggregate = + CreateSortedAggregateWrapper(bound_function.GetName(), arguments, bound_function.GetReturnType(), + bound_function.GetProperties().GetNullHandling()); expr.FunctionMutable().ReplaceImplementation(ordered_aggregate); expr.BindInfoMutable() = std::move(sorted_bind); expr.GetOrderBysMutable().reset(); + + if (state_export) { + // wire the export onto the wrapper - the AGGREGATE_STATE return type was already set at bind time + ExportAggregateFunction::SetStateExport(expr, expr.GetReturnType()); + } } void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundWindowExpression &expr) { diff --git a/src/duckdb/src/function/cast/variant/from_variant.cpp b/src/duckdb/src/function/cast/variant/from_variant.cpp index 852708a96..bc0d79d3f 100644 --- a/src/duckdb/src/function/cast/variant/from_variant.cpp +++ b/src/duckdb/src/function/cast/variant/from_variant.cpp @@ -330,7 +330,9 @@ static bool ConvertVariantToArray(FromVariantConversionData &conversion_data, Ve } FindValues(conversion_data.variant, row_index, new_sel, child_data_entry); - CastVariant(conversion_data, child, new_sel, total_offset, array_size, row_index); + if (!CastVariant(conversion_data, child, new_sel, total_offset, array_size, row_index)) { + return false; + } total_offset += array_size; } return true; diff --git a/src/duckdb/src/function/function_binder.cpp b/src/duckdb/src/function/function_binder.cpp index 58fba7064..c729e0e77 100644 --- a/src/duckdb/src/function/function_binder.cpp +++ b/src/duckdb/src/function/function_binder.cpp @@ -628,6 +628,35 @@ static bool RequiresCollationPropagation(const LogicalType &type) { return type.id() == LogicalTypeId::VARCHAR && !type.HasAlias(); } +//! Recursively extracts the collation of a (possibly nested) type, e.g. the element collation of a LIST(VARCHAR). +static string ExtractCollationFromType(const LogicalType &type) { + switch (type.id()) { + case LogicalTypeId::VARCHAR: + return RequiresCollationPropagation(type) ? StringType::GetCollation(type) : string(); + case LogicalTypeId::LIST: + return ExtractCollationFromType(ListType::GetChildType(type)); + case LogicalTypeId::ARRAY: + return ExtractCollationFromType(ArrayType::GetChildType(type)); + default: + return string(); + } +} + +//! Returns a copy of the type with the collation applied to every (nested) VARCHAR leaf. +static LogicalType ApplyCollationToType(const LogicalType &type, const LogicalType &collation_type) { + switch (type.id()) { + case LogicalTypeId::VARCHAR: + return RequiresCollationPropagation(type) ? collation_type : type; + case LogicalTypeId::LIST: + return LogicalType::LIST(ApplyCollationToType(ListType::GetChildType(type), collation_type)); + case LogicalTypeId::ARRAY: + return LogicalType::ARRAY(ApplyCollationToType(ArrayType::GetChildType(type), collation_type), + ArrayType::GetSize(type)); + default: + return type; + } +} + static string ExtractCollation(const vector> &children) { string collation; for (auto &arg : children) { @@ -645,6 +674,20 @@ static string ExtractCollation(const vector> &children) { return collation; } +//! Like ExtractCollation, but also considers the collation of nested (LIST/ARRAY) VARCHAR elements. +static string ExtractNestedCollation(const vector> &children) { + string collation; + for (auto &arg : children) { + auto child_collation = ExtractCollationFromType(arg->GetReturnType()); + if (collation.empty()) { + collation = child_collation; + } else if (!child_collation.empty() && collation != child_collation) { + throw BinderException("Cannot combine types with different collation!"); + } + } + return collation; +} + static void PropagateCollations(ClientContext &, BoundSimpleFunction &bound_function, vector> &children) { if (!RequiresCollationPropagation(bound_function.GetReturnType())) { @@ -663,7 +706,7 @@ static void PropagateCollations(ClientContext &, BoundSimpleFunction &bound_func static void PushCollations(ClientContext &context, BoundSimpleFunction &bound_function, vector> &children, CollationType type) { - auto collation = ExtractCollation(children); + auto collation = ExtractNestedCollation(children); if (collation.empty()) { // no collation to push return; @@ -675,12 +718,14 @@ static void PushCollations(ClientContext &context, BoundSimpleFunction &bound_fu } // push collations to the children for (auto &arg : children) { + // apply the collation to the (possibly nested) varchar leaves of the argument type + auto collated_type = ApplyCollationToType(arg->GetReturnType(), collation_type); if (RequiresCollationPropagation(arg->GetReturnType())) { // if this is a varchar type - propagate the collation arg->SetReturnType(collation_type); } // now push the actual collation handling - ExpressionBinder::PushCollation(context, arg, arg->GetReturnType(), type); + ExpressionBinder::PushCollation(context, arg, collated_type, type); } } diff --git a/src/duckdb/src/function/function_list.cpp b/src/duckdb/src/function/function_list.cpp index a5f9c40fb..53284b4d9 100644 --- a/src/duckdb/src/function/function_list.cpp +++ b/src/duckdb/src/function/function_list.cpp @@ -101,6 +101,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION(IsDistinctFromFun), DUCKDB_SCALAR_FUNCTION(IsNotDistinctFromFun), DUCKDB_SCALAR_FUNCTION(BetweenFun), + DUCKDB_SCALAR_FUNCTION(InternalCompressGeometryPointFun), DUCKDB_SCALAR_FUNCTION_SET(InternalCompressIntegralUbigintFun), DUCKDB_SCALAR_FUNCTION_SET(InternalCompressIntegralUintegerFun), DUCKDB_SCALAR_FUNCTION_SET(InternalCompressIntegralUsmallintFun), @@ -111,6 +112,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION(InternalCompressStringUintegerFun), DUCKDB_SCALAR_FUNCTION(InternalCompressStringUsmallintFun), DUCKDB_SCALAR_FUNCTION(InternalCompressStringUtinyintFun), + DUCKDB_SCALAR_FUNCTION(InternalDecompressGeometryPointFun), DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressIntegralBigintFun), DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressIntegralHugeintFun), DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressIntegralIntegerFun), diff --git a/src/duckdb/src/function/scalar/compressed_materialization/compress_geometry.cpp b/src/duckdb/src/function/scalar/compressed_materialization/compress_geometry.cpp new file mode 100644 index 000000000..e8d70c6c2 --- /dev/null +++ b/src/duckdb/src/function/scalar/compressed_materialization/compress_geometry.cpp @@ -0,0 +1,134 @@ +#include "duckdb/common/bswap.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/function/scalar/compressed_materialization_functions.hpp" +#include "duckdb/function/scalar/compressed_materialization_utils.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" + +namespace duckdb { + +namespace { + +// An internally stored non-empty POINT-XY geometry is always exactly these sizes: +// 1 byte little-endian byte order + 4 byte WKB type/meta (= 1) + 2 doubles (X, Y) +constexpr auto POINT_XY_HEADER_SIZE = sizeof(uint8_t) + sizeof(uint32_t); +constexpr auto POINT_XY_COORD_SIZE = sizeof(double) + sizeof(double); +constexpr auto POINT_XY_BLOB_SIZE = POINT_XY_HEADER_SIZE + POINT_XY_COORD_SIZE; + +uhugeint_t GeometryPointCompress(const string_t &input) { + D_ASSERT(input.GetSize() == POINT_XY_BLOB_SIZE); + + // The 16 coordinate bytes (X, Y) follow the fixed 5-byte WKB header (byte order + type) + const auto coords = const_data_ptr_cast(input.GetData()) + POINT_XY_HEADER_SIZE; + + // Pack the coordinates into a UHUGEINT, reversing the byte order so that an unsigned integer + // comparison of the result matches the lexicographic (memcmp) byte comparison. + uhugeint_t result; + const auto result_ptr = data_ptr_cast(&result); + for (idx_t i = 0; i < POINT_XY_COORD_SIZE; i++) { + result_ptr[i] = coords[POINT_XY_COORD_SIZE - 1 - i]; + } + return BSwapIfBE(result); +} + +void GeometryPointCompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::Execute(args.data[0], result, GeometryPointCompress, + FunctionErrors::CANNOT_ERROR); +} + +void GeometryPointDecompress(const uhugeint_t &input, const data_ptr_t output) { + const auto le_input = BSwapIfBE(input); + const auto le_input_ptr = const_data_ptr_cast(&le_input); + + // Reconstruct the fixed POINT-XY WKB header: little-endian byte order (1) + WKB type POINT (1) + Store(1, output); + Store(1, output + sizeof(uint8_t)); + + // Restore the coordinate bytes (inverse of the byte reversal in GeometryPointCompress) + const auto coords = output + POINT_XY_HEADER_SIZE; + for (idx_t i = 0; i < POINT_XY_COORD_SIZE; i++) { + coords[POINT_XY_COORD_SIZE - 1 - i] = le_input_ptr[i]; + } +} + +void GeometryPointDecompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { + D_ASSERT(args.ColumnCount() == 1); + const auto count = args.size(); + + UnifiedVectorFormat input_data; + args.data[0].ToUnifiedFormat(input_data); + const auto input = UnifiedVectorFormat::GetData(input_data); + + result.SetVectorType(VectorType::FLAT_VECTOR); + auto result_data = FlatVector::GetDataMutable(result); + auto &result_mask = FlatVector::ValidityMutable(result); + + for (idx_t i = 0; i < count; i++) { + const auto idx = input_data.sel->get_index(i); + if (!input_data.validity.RowIsValid(idx)) { + result_mask.SetInvalid(i); + continue; + } + auto blob = StringVector::EmptyString(result, POINT_XY_BLOB_SIZE); + GeometryPointDecompress(input[idx], data_ptr_cast(blob.GetDataWriteable())); + blob.Finalize(); + result_data[i] = blob; + } +} + +void CMGeometryPointCompressSerialize(Serializer &serializer, const optional_ptr bind_data, + const BoundScalarFunction &function) { + serializer.WriteProperty(100, "arguments", function.GetArguments()); + serializer.WriteProperty(101, "return_type", function.GetReturnType()); +} + +unique_ptr CMGeometryPointCompressDeserialize(Deserializer &deserializer, BoundScalarFunction &function) { + function.GetArguments() = deserializer.ReadProperty>(100, "arguments"); + function.SetReturnType(deserializer.ReadProperty(101, "return_type")); + function.SetFunctionCallback(GeometryPointCompressFunction); + return nullptr; +} + +void CMGeometryPointDecompressSerialize(Serializer &serializer, const optional_ptr bind_data, + const BoundScalarFunction &function) { + serializer.WriteProperty(100, "arguments", function.GetArguments()); + serializer.WriteProperty(101, "return_type", function.GetReturnType()); +} + +unique_ptr CMGeometryPointDecompressDeserialize(Deserializer &deserializer, + BoundScalarFunction &function) { + function.GetArguments() = deserializer.ReadProperty>(100, "arguments"); + function.SetReturnType(deserializer.ReadProperty(101, "return_type")); + function.SetFunctionCallback(GeometryPointDecompressFunction); + return nullptr; +} + +} // namespace + +ScalarFunction CMGeometryPointCompressFun::GetFunction() { + ScalarFunction result(Identifier("__internal_compress_geometry_point"), {LogicalType::GEOMETRY()}, + LogicalType::UHUGEINT, GeometryPointCompressFunction, CMUtils::Bind); + result.SetSerializeCallback(CMGeometryPointCompressSerialize); + result.SetDeserializeCallback(CMGeometryPointCompressDeserialize); + result.SetErrorMode(FunctionErrors::CANNOT_ERROR); + return result; +} + +ScalarFunction CMGeometryPointDecompressFun::GetFunction() { + ScalarFunction result(Identifier("__internal_decompress_geometry_point"), {LogicalType::UHUGEINT}, + LogicalType::GEOMETRY(), GeometryPointDecompressFunction, CMUtils::Bind); + result.SetSerializeCallback(CMGeometryPointDecompressSerialize); + result.SetDeserializeCallback(CMGeometryPointDecompressDeserialize); + result.SetErrorMode(FunctionErrors::CANNOT_ERROR); + return result; +} + +ScalarFunction InternalCompressGeometryPointFun::GetFunction() { + return CMGeometryPointCompressFun::GetFunction(); +} + +ScalarFunction InternalDecompressGeometryPointFun::GetFunction() { + return CMGeometryPointDecompressFun::GetFunction(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp b/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp index 9842b0171..990229bd9 100644 --- a/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp +++ b/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp @@ -4,6 +4,8 @@ #include "duckdb/common/vector_operations/binary_executor.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/storage/statistics/geometry_stats.hpp" +#include "duckdb/storage/statistics/string_stats.hpp" namespace duckdb { @@ -11,8 +13,82 @@ static void FromWKBFunction(DataChunk &input, ExpressionState &state, Vector &re Geometry::FromBinary(input.data[0], result, input.size(), true); } +static auto FromWKBStats(ClientContext &context, FunctionStatisticsInput &input) -> unique_ptr { + const auto &child_stats = input.child_stats[0]; + + // Start from fully unknown geometry stats and copy over the validity (null-ness) of the input. + auto result = GeometryStats::CreateUnknown(input.expr.GetReturnType()); + result.CopyValidity(child_stats); + + if (!StringStats::HasMinMax(child_stats)) { + // No min/max available, we can't say anything about the types + return result.ToUnique(); + } + + // The lexicographically smallest and largest WKB blobs. Any prefix shared by both is shared by every row, + // so if the first 5 bytes (byte order + type meta) match, all rows have the exact same geometry type and ZM flags. + const auto min_blob = StringStats::Min(child_stats); + const auto max_blob = StringStats::Max(child_stats); + + constexpr idx_t WKB_HEADER_SIZE = sizeof(uint8_t) + sizeof(uint32_t); + if (min_blob.size() < WKB_HEADER_SIZE || max_blob.size() < WKB_HEADER_SIZE) { + return result.ToUnique(); + } + if (memcmp(min_blob.data(), max_blob.data(), WKB_HEADER_SIZE) != 0) { + // The headers differ, so multiple geometry types may be present + return result.ToUnique(); + } + + // Now parse the 5-byte WKB header (byte order + type meta) into a geometry/vertex type. + const auto header = const_data_ptr_cast(min_blob.data()); + const auto le = Load(header); + + auto meta = Load(header + sizeof(uint8_t)); + if (!le) { + meta = BSwap(meta); + } + + const auto type_id = (meta & 0x0000FFFF) % 1000; + const auto flag_id = (meta & 0x0000FFFF) / 1000; + if (type_id < 1 || type_id > 7 || flag_id > 3) { + // Unsupported or invalid geometry type, we can't say anything about the types + return result.ToUnique(); + } + + // Z/M may also be signalled through the Extended-WKB high bits (matching Geometry::FromBinary) + const auto has_z = ((flag_id & 0x01) != 0) || ((meta & 0x80000000) != 0); + const auto has_m = ((flag_id & 0x02) != 0) || ((meta & 0x40000000) != 0); + + const auto geom_type = static_cast(type_id); + const auto vert_type = static_cast((has_z ? 1 : 0) | (has_m ? 2 : 0)); + + // The single inferred type implies a minimum body length: a POINT serializes all of its ordinates inline + // (even when empty, encoded as NaNs); every other type serializes at least a 4-byte element count. If the + // shortest row (the exact minimum over all rows) can't hold that, the column has a truncated blob, so bail. + const idx_t vert_dims = 2 + (has_z ? 1 : 0) + (has_m ? 1 : 0); + const auto min_valid_size = geom_type == GeometryType::POINT ? WKB_HEADER_SIZE + vert_dims * sizeof(double) + : WKB_HEADER_SIZE + sizeof(uint32_t); + + const auto min_len = StringStats::MinStringLength(child_stats); + if (!min_len.IsValid() || min_len.GetIndex() < min_valid_size) { + // Only bounds the byte envelope, not that a declared element count matches the body. + // In general, we cant guarantee that the WKB is valid without parsing every row. + // But we're generally OK with optimizing assuming valid input - anything else is essentially UB. + return result.ToUnique(); + } + + // All rows share this single geometry type. The extent and emptiness can't be inferred + // from the truncated string stats, so those remain unknown. + auto &types = GeometryStats::GetTypes(result); + types = GeometryTypeSet::Empty(); + types.Add(geom_type, vert_type); + + return result.ToUnique(); +} + ScalarFunction StGeomfromwkbFun::GetFunction() { ScalarFunction function({LogicalType::BLOB}, LogicalType::GEOMETRY(), FromWKBFunction); + function.SetStatisticsCallback(FromWKBStats); return function; } diff --git a/src/duckdb/src/function/scalar/list/contains_or_position.cpp b/src/duckdb/src/function/scalar/list/contains_or_position.cpp index 510f1d426..47353cb44 100644 --- a/src/duckdb/src/function/scalar/list/contains_or_position.cpp +++ b/src/duckdb/src/function/scalar/list/contains_or_position.cpp @@ -22,14 +22,17 @@ static void ListSearchFunction(DataChunk &input, ExpressionState &state, Vector } ScalarFunction ListContainsFun::GetFunction() { - return ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::TEMPLATE("T")}, - LogicalType::BOOLEAN, ListSearchFunction); + auto fun = ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::TEMPLATE("T")}, + LogicalType::BOOLEAN, ListSearchFunction); + fun.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); + return fun; } ScalarFunction ListPositionFun::GetFunction() { auto fun = ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::TEMPLATE("T")}, LogicalType::INTEGER, ListSearchFunction); fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + fun.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return fun; } diff --git a/src/duckdb/src/function/scalar/operator/arithmetic.cpp b/src/duckdb/src/function/scalar/operator/arithmetic.cpp index 51e840dc5..41bc8dd81 100644 --- a/src/duckdb/src/function/scalar/operator/arithmetic.cpp +++ b/src/duckdb/src/function/scalar/operator/arithmetic.cpp @@ -96,11 +96,11 @@ static Value NumericStatsValue(const LogicalType &type, T value) { D_ASSERT(type.IsNumeric()); switch (type.InternalType()) { case PhysicalType::FLOAT: - return Value::FLOAT(value); + return Value::FLOAT(static_cast(value)); case PhysicalType::DOUBLE: - return Value::DOUBLE(value); + return Value::DOUBLE(static_cast(value)); default: - return Value::Numeric(type, value); + return Value::Numeric(type, static_cast(value)); } } diff --git a/src/duckdb/src/function/scalar/string/like.cpp b/src/duckdb/src/function/scalar/string/like.cpp index c728dba20..b945fdeb4 100644 --- a/src/duckdb/src/function/scalar/string/like.cpp +++ b/src/duckdb/src/function/scalar/string/like.cpp @@ -432,6 +432,10 @@ bool ILikeOperatorFunction(string_t &str, string_t &pattern, char escape = '\0') LowerCase(pat_data, pat_size, pat_ldata.get()); string_t str_lcase(str_ldata.get(), UnsafeNumericCast(str_llength)); string_t pat_lcase(pat_ldata.get(), UnsafeNumericCast(pat_llength)); + // '\0' is the "no escape" sentinel: use the non-escape matcher so embedded NUL bytes are matched literally + if (escape == '\0') { + return LikeOperatorFunction(str_lcase, pat_lcase); + } return LikeOperatorFunction(str_lcase, pat_lcase, escape); } @@ -504,6 +508,67 @@ void LikeEscapeFunction(DataChunk &args, ExpressionState &state, Vector &result) str, pattern, escape, result, FUNC::template Operation); } +// Execution function for ILIKE / NOT ILIKE ... ESCAPE. Mirrors ILikeFunction: when the pattern and escape are +// both constant, lowercase the pattern once and reuse a scratch buffer for the per-row string lowercasing. If the +// escape character does not occur in the pattern, escape handling is a no-op, so we build the case-folded +// LikeMatcher and use the fast SIMD contains path; otherwise we fall back to the generic escape-aware matcher on +// the lowercased values. Non-constant pattern/escape falls back to the per-row ternary path. +template +void ILikeEscapeFunction(DataChunk &args, ExpressionState &state, Vector &result) { + auto &str_vec = args.data[0]; + auto &pattern_vec = args.data[1]; + auto &escape_vec = args.data[2]; + + if (pattern_vec.GetVectorType() == VectorType::CONSTANT_VECTOR && !ConstantVector::IsNull(pattern_vec) && + escape_vec.GetVectorType() == VectorType::CONSTANT_VECTOR && !ConstantVector::IsNull(escape_vec)) { + auto pattern = *ConstantVector::GetData(pattern_vec); + auto escape = *ConstantVector::GetData(escape_vec); + char escape_char = GetEscapeChar(escape); + + // lowercase the pattern exactly once, up front + idx_t pat_llength = LowerLength(pattern.GetData(), pattern.GetSize()); + auto pat_ldata = make_unsafe_uniq_array_uninitialized(pat_llength); + LowerCase(pattern.GetData(), pattern.GetSize(), pat_ldata.get()); + string_t pat_lcase(pat_ldata.get(), UnsafeNumericCast(pat_llength)); + + // the matcher cannot honor escape semantics, so only use it when the escape char never appears in the + // (lowercased) pattern, in which case escape is irrelevant and the pattern is a plain LIKE pattern + unique_ptr matcher; + bool escape_active = + escape_char != '\0' && memchr(pat_lcase.GetData(), escape_char, pat_lcase.GetSize()) != nullptr; + if (!escape_active) { + matcher = LikeMatcher::CreateLikeMatcher(string(pat_lcase.GetData(), pat_lcase.GetSize())); + } + + // reusable scratch buffer for lowercasing each string value (grown on demand) + idx_t scratch_size = 0; + unsafe_unique_array scratch; + UnaryExecutor::Execute(str_vec, result, args.size(), [&](string_t str) { + idx_t str_llength = LowerLength(str.GetData(), str.GetSize()); + if (str_llength > scratch_size) { + scratch = make_unsafe_uniq_array_uninitialized(str_llength); + scratch_size = str_llength; + } + LowerCase(str.GetData(), str.GetSize(), scratch.get()); + string_t str_lcase(scratch.get(), UnsafeNumericCast(str_llength)); + // '\0' escape means no escape: use the non-escape matcher so embedded NUL bytes are matched literally + bool match = matcher ? matcher->Match(str_lcase) + : (escape_char == '\0' ? LikeOperatorFunction(str_lcase, pat_lcase) + : LikeOperatorFunction(str_lcase, pat_lcase, escape_char)); + return INVERT ? !match : match; + }); + return; + } + // non-constant pattern/escape: fall back to the generic per-row implementation + if (INVERT) { + TernaryExecutor::Execute( + str_vec, pattern_vec, escape_vec, result, NotILikeEscapeOperator::Operation); + } else { + TernaryExecutor::Execute( + str_vec, pattern_vec, escape_vec, result, ILikeEscapeOperator::Operation); + } +} + template unique_ptr ILikePropagateStats(ClientContext &context, FunctionStatisticsInput &input) { auto &child_stats = input.child_stats; @@ -516,6 +581,48 @@ unique_ptr ILikePropagateStats(ClientContext &context, FunctionS return nullptr; } +// Execution function for ILIKE / NOT ILIKE on the (possibly) Unicode path. +// When the pattern is constant we lowercase it exactly once instead of once per row, and we reuse a single +// scratch buffer to lowercase each string value instead of heap-allocating per row. This avoids two heap +// allocations + a redundant pattern case-fold on every row that the generic ILikeOperatorFunction incurs. +// (The ASCII-only fast path installed by ILikePropagateStats already avoids allocations, so it is unaffected.) +template +void ILikeFunction(DataChunk &input, ExpressionState &state, Vector &result) { + auto &pattern_vec = input.data[1]; + if (pattern_vec.GetVectorType() == VectorType::CONSTANT_VECTOR && !ConstantVector::IsNull(pattern_vec)) { + // constant pattern: lowercase it exactly once, up front + auto pattern = *ConstantVector::GetData(pattern_vec); + idx_t pat_llength = LowerLength(pattern.GetData(), pattern.GetSize()); + auto pat_ldata = make_unsafe_uniq_array_uninitialized(pat_llength); + LowerCase(pattern.GetData(), pattern.GetSize(), pat_ldata.get()); + string_t pat_lcase(pat_ldata.get(), UnsafeNumericCast(pat_llength)); + + // Build a case-insensitive matcher from the lowercased pattern. Because both the pattern and each string + // value are lowercased, a case-sensitive LikeMatcher over lowercased data is equivalent to ILIKE, and it + // uses the fast SIMD FindStrInStr/memcmp contains path. Returns null for patterns with '_' or only '%'; + // in that case we fall back to the generic recursive matcher on the lowercased values. + auto matcher = LikeMatcher::CreateLikeMatcher(string(pat_lcase.GetData(), pat_lcase.GetSize())); + + // reusable scratch buffer for lowercasing each string value (grown on demand) + idx_t scratch_size = 0; + unsafe_unique_array scratch; + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](string_t str) { + idx_t str_llength = LowerLength(str.GetData(), str.GetSize()); + if (str_llength > scratch_size) { + scratch = make_unsafe_uniq_array_uninitialized(str_llength); + scratch_size = str_llength; + } + LowerCase(str.GetData(), str.GetSize(), scratch.get()); + string_t str_lcase(scratch.get(), UnsafeNumericCast(str_llength)); + bool match = matcher ? matcher->Match(str_lcase) : LikeOperatorFunction(str_lcase, pat_lcase); + return INVERT ? !match : match; + }); + return; + } + // non-constant pattern: fall back to the generic per-row implementation + BinaryExecutor::ExecuteStandard(input.data[0], input.data[1], result, input.size()); +} + template void RegularLikeFunction(DataChunk &input, ExpressionState &state, Vector &result) { auto &func_expr = state.expr.Cast(); @@ -549,15 +656,14 @@ ScalarFunction GlobPatternFun::GetFunction() { ScalarFunction ILikeFun::GetFunction() { ScalarFunction ilike("~~*", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - ScalarFunction::BinaryFunction, nullptr, - ILikePropagateStats); + ILikeFunction, nullptr, ILikePropagateStats); ilike.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return ilike; } ScalarFunction NotILikeFun::GetFunction() { ScalarFunction not_ilike("!~~*", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - ScalarFunction::BinaryFunction, nullptr, + ILikeFunction, nullptr, ILikePropagateStats); not_ilike.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return not_ilike; @@ -580,7 +686,7 @@ ScalarFunction NotLikeEscapeFun::GetFunction() { ScalarFunction IlikeEscapeFun::GetFunction() { ScalarFunction ilike_escape("ilike_escape", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::BOOLEAN, LikeEscapeFunction); + LogicalType::BOOLEAN, ILikeEscapeFunction); ilike_escape.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return ilike_escape; } @@ -588,7 +694,7 @@ ScalarFunction IlikeEscapeFun::GetFunction() { ScalarFunction NotIlikeEscapeFun::GetFunction() { ScalarFunction not_ilike_escape("not_ilike_escape", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::BOOLEAN, LikeEscapeFunction); + LogicalType::BOOLEAN, ILikeEscapeFunction); not_ilike_escape.SetCollationHandling(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS); return not_ilike_escape; } diff --git a/src/duckdb/src/function/scalar/system/aggregate_export.cpp b/src/duckdb/src/function/scalar/system/aggregate_export.cpp index f84bbdddd..d07ecc33c 100644 --- a/src/duckdb/src/function/scalar/system/aggregate_export.cpp +++ b/src/duckdb/src/function/scalar/system/aggregate_export.cpp @@ -1,9 +1,7 @@ #include "duckdb/common/vector/flat_vector.hpp" #include "duckdb/common/vector/list_vector.hpp" #include "duckdb/common/vector/struct_vector.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/common/types/list_segment.hpp" -#include "duckdb/common/types/variant_value.hpp" #include "duckdb/function/aggregate_state_layout.hpp" #include "duckdb/function/create_sort_key.hpp" #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" @@ -133,39 +131,39 @@ struct LoadOp { } }; -// Store rows from the packed binary state buffer into a result vector. +// Store rows from the packed binary state buffer into a result vector at [offset, offset + count). struct StoreOp { template - static void Operation(Vector &result, idx_t count, const data_ptr_t *sources, idx_t field_offset) { - auto dst = FlatVector::Writer(result, count); + static void Operation(Vector &result, idx_t count, const data_ptr_t *sources, idx_t field_offset, idx_t offset) { + auto dst = FlatVector::Writer(result, count, offset); for (idx_t i = 0; i < count; i++) { dst.WriteValue(Load(sources[i] + field_offset)); } } }; -// Recursively serialize a state field to a result vector. +// Recursively serialize a state field to a result vector, writing the `count` rows at [offset, offset + count). // base: accumulated byte offset from the state slot start to this field's parent base. // Each child's field_offset is relative to that parent base. static void SerializeField(const LogicalType &type, const AggregateStateField &field, Vector &result, idx_t count, - const data_ptr_t *addresses, idx_t base) { + const data_ptr_t *addresses, idx_t base, idx_t offset) { switch (field.kind) { case AggregateFieldKind::OPTIONAL_VALUE: D_ASSERT(field.children.size() == 1); for (idx_t i = 0; i < count; i++) { if (!Load(addresses[i] + base + field.field_offset)) { - FlatVector::SetNull(result, i, true); + FlatVector::SetNull(result, offset + i, true); } } - SerializeField(type, field.children[0], result, count, addresses, base); + SerializeField(type, field.children[0], result, count, addresses, base, offset); break; case AggregateFieldKind::SORT_KEY: for (idx_t i = 0; i < count; i++) { - if (!FlatVector::Validity(result).RowIsValid(i)) { + if (!FlatVector::Validity(result).RowIsValid(offset + i)) { continue; } const string_t sort_key = Load(addresses[i] + base + field.field_offset); - CreateSortKeyHelpers::DecodeSortKey(sort_key, result, i, + CreateSortKeyHelpers::DecodeSortKey(sort_key, result, offset + i, OrderModifiers(field.sort_key_order, OrderByNullType::NULLS_LAST)); } break; @@ -175,12 +173,12 @@ static void SerializeField(const LogicalType &type, const AggregateStateField &f const idx_t new_base = base + field.field_offset; for (idx_t field_idx = 0; field_idx < field.children.size(); field_idx++) { SerializeField(child_types[field_idx].second, field.children[field_idx], struct_entries[field_idx], count, - addresses, new_base); + addresses, new_base, offset); } break; } case AggregateFieldKind::PRIMITIVE: - TemplateDispatch(type.InternalType(), result, count, addresses, base + field.field_offset); + TemplateDispatch(type.InternalType(), result, count, addresses, base + field.field_offset, offset); break; case AggregateFieldKind::LIST: { // linked list field: build the result LIST vector from each state's linked list @@ -195,7 +193,8 @@ static void SerializeField(const LogicalType &type, const AggregateStateField &f const auto &element = field.children[0]; if (element.kind != AggregateFieldKind::SORT_KEY) { // elements are stored directly - build the result LIST vector from each state's linked list - field.list_functions.BuildLists(linked_lists, result, 0); + // (BuildLists appends to the result's child, writing the list entries at [offset, offset + count)) + field.list_functions.BuildLists(linked_lists, result, offset); break; } // the elements are sort keys: build the physically stored (BLOB) elements into a temporary LIST vector, then @@ -203,14 +202,15 @@ static void SerializeField(const LogicalType &type, const AggregateStateField &f Vector physical_list(LogicalType::LIST(LogicalType::BLOB), count); field.list_functions.BuildLists(linked_lists, physical_list, 0); - ListVector::Reserve(result, ListVector::GetListSize(physical_list)); + // append to the result child, starting after any rows already written at a lower offset + idx_t child_offset = ListVector::GetListSize(result); + ListVector::Reserve(result, child_offset + ListVector::GetListSize(physical_list)); auto &result_child = ListVector::GetChildMutable(result); auto result_entries = FlatVector::GetDataMutable(result); const OrderModifiers modifiers(element.sort_key_order, OrderByNullType::NULLS_LAST); - idx_t child_offset = 0; for (const auto list_entry : physical_list.Values>()) { - const auto row = list_entry.GetIndex(); + const auto row = offset + list_entry.GetIndex(); if (!list_entry.IsValid()) { // an empty linked list is exported as NULL, matching the finalize semantics of list aggregates FlatVector::SetNull(result, row, true); @@ -324,18 +324,18 @@ static void DeserializeState(const BoundAggregateFunction &aggr, const Aggregate DeserializeField(layout.type, layout.field, input_vec, count, dest_buffer, layout.total_state_size, 0, allocator); } -static void SerializeState(const AggregateStateLayout &layout, Vector &result, idx_t count, - const data_ptr_t *addresses) { - SerializeField(layout.type, layout.field, result, count, addresses, 0); +static void SerializeState(const AggregateStateLayout &layout, Vector &result, idx_t count, const data_ptr_t *addresses, + idx_t offset) { + SerializeField(layout.type, layout.field, result, count, addresses, 0, offset); } static void SerializeState(const BoundAggregateFunction &aggr, optional_ptr bind_data, const AggregateStateLayout &layout, Vector &states, idx_t count, Vector &result, - ArenaAllocator &allocator) { + ArenaAllocator &allocator, idx_t offset) { if (aggr.HasExportAggregateStateCallback()) { - // the aggregate explicitly serializes its own states + // the aggregate explicitly serializes its own states, writing the count rows at [offset, offset + count) AggregateFinalizeInputData aggr_input_data(aggr, bind_data, allocator); - aggr.GetExportAggregateStateCallback()(states, aggr_input_data, result, count, 0); + aggr.GetExportAggregateStateCallback()(states, aggr_input_data, result, count, offset); return; } const data_ptr_t *addresses; @@ -344,7 +344,7 @@ static void SerializeState(const BoundAggregateFunction &aggr, optional_ptr(states); } - SerializeState(layout, result, count, addresses); + SerializeState(layout, result, count, addresses, offset); } // destroys the temporary underlying-aggregate states referenced by `states` (no-op if the aggregate has no @@ -489,7 +489,7 @@ void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Vector &r bind_data.aggr.GetStateCombineCallback()(local_state.addresses0, local_state.addresses1, aggr_input_data, count); SerializeState(bind_data.aggr, bind_data.bind_data.get(), layout, local_state.addresses1, count, result, - local_state.allocator); + local_state.allocator, 0); } catch (...) { destroy_states(); throw; @@ -593,6 +593,9 @@ void ParseStateParameters(const Value ¶meters, vector &argument } } +// parses an ORDER BY spec (a list of {column, order} structs) - shared by the re-bind path and to_aggregate_state +void ParseOrderBys(const Value &order_value, idx_t column_count, vector &orders); + unique_ptr BindAggregateStateInternal(ClientContext &context, BoundSimpleFunction &function, vector> &arguments) { auto &arg_return_type = arguments[0]->GetReturnType(); @@ -619,7 +622,27 @@ unique_ptr BindAggregateStateInternal(ClientContext &co map constant_parameters; ParseStateParameters(entry->second, argument_types, constant_parameters); - return BindExportedAggregate(context, function_name, argument_types, constant_parameters); + auto inner = BindExportedAggregate(context, function_name, argument_types, constant_parameters); + + auto order_entry = ext_info->properties.find("order_bys"); + if (order_entry == ext_info->properties.end()) { + // a plain (non-ordered) aggregate state + return inner; + } + // an ordered aggregate state: the value is the buffer (LIST) - reconstruct the sorted wrapper + // around the inner aggregate so finalize sorts the buffer and combine concatenates buffers + const auto buffer_struct = ListType::GetChildType(arg_return_type); + const idx_t column_count = StructType::GetChildTypes(buffer_struct).size(); + vector orders; + ParseOrderBys(order_entry->second, column_count, orders); + // the leading buffered columns are the inner aggregate's bound arguments (post constant-erasure) + const idx_t argument_count = inner->aggr.GetArguments().size(); + + auto reconstructed = FunctionBinder::BindSortedAggregateState(context, inner->aggr, std::move(inner->bind_data), + buffer_struct, orders, argument_count); + BoundAggregateFunction wrapper(reconstructed.first); + const auto state_size = wrapper.GetStateSizeCallback()(wrapper); + return make_uniq(std::move(wrapper), std::move(reconstructed.second), state_size); } unique_ptr BindAggregateState(BindScalarFunctionInput &input) { @@ -678,15 +701,12 @@ void ExportAggregateFinalize(Vector &state, AggregateFinalizeInputData &aggr_inp auto layout = GetLayout(aggr_input_data.function, aggr_input_data.bind_data); + // only flatten the result the first time it is written - ordered aggregates finalize one group at a time at + // increasing offsets, appending into the (already flat) result if (offset == 0) { - SerializeState(layout, result, count, addresses_ptrs); - return; + result.Flatten(); } - // finalizing at a non-zero offset (e.g. ordered aggregates) - serialize into a temporary vector and copy the - // result into place so the rest of the result vector is left untouched - Vector temp(result.GetType(), count); - SerializeState(layout, temp, count, addresses_ptrs); - VectorOperations::Copy(temp, result, count, 0, offset); + SerializeState(layout, result, count, addresses_ptrs, offset); } // the executor invokes this callback with combine_aggr's own bind data (ExportAggregateBindData) - the underlying @@ -772,30 +792,20 @@ void CombineAggrFinalize(Vector &state, AggregateFinalizeInputData &aggr_input_d auto layout = GetLayout(underlying_aggr, bind_data.bind_data.get()); + // only flatten the result the first time it is written - ordered aggregates finalize one group at a time at + // increasing offsets, appending into the (already flat) result if (offset == 0) { - SerializeState(underlying_aggr, bind_data.bind_data.get(), layout, state, count, result, - aggr_input_data.allocator); - return; + result.Flatten(); } - // finalizing at a non-zero offset (e.g. ordered aggregates) - serialize into a temporary vector and copy the - // result into place so the rest of the result vector is left untouched - Vector temp(result.GetType(), count); - SerializeState(underlying_aggr, bind_data.bind_data.get(), layout, state, count, temp, aggr_input_data.allocator); - VectorOperations::Copy(temp, result, count, 0, offset); + SerializeState(underlying_aggr, bind_data.bind_data.get(), layout, state, count, result, aggr_input_data.allocator, + offset); } -// constructs the AGGREGATE_STATE type for the given bound aggregate function -// the state layout (a struct) is aliased to AGGREGATE_STATE, with the function name and signature stored in the -// extension type info so that the aggregate can be re-bound later (e.g. by FINALIZE/COMBINE) -LogicalType CreateAggregateStateType(const BoundAggregateFunction &bound_function, - optional_ptr bind_data) { - auto layout = bound_function.GetStateType(bind_data); - // deep copy the type before modifying it - SetAlias/SetExtensionInfo modify the (shared) extra type info in - // place, and the state layout type can share its type info with e.g. the aggregate's input expressions - LogicalType state_layout = layout.type.DeepCopy(); - state_layout.SetAlias("AGGREGATE_STATE"); - auto ext_info = make_uniq(); - ext_info->properties.emplace("function_name", bound_function.GetName()); +// stores the function name and signature (with any constant parameters) into the extension type info, so that the +// aggregate can be re-bound later (e.g. by FINALIZE/COMBINE) +void EncodeStateParameters(ExtensionTypeInfo &ext_info, const BoundAggregateFunction &bound_function, + const AggregateStateLayout &layout) { + ext_info.properties.emplace("function_name", bound_function.GetName()); auto &original_arguments = bound_function.GetOriginalArguments().empty() ? bound_function.GetArguments() : bound_function.GetOriginalArguments(); vector arguments; @@ -804,7 +814,7 @@ LogicalType CreateAggregateStateType(const BoundAggregateFunction &bound_functio for (auto &arg : original_arguments) { arguments.push_back(Value::TYPE(arg)); } - ext_info->properties.emplace("parameters", Value::LIST(LogicalType::TYPE(), std::move(arguments))); + ext_info.properties.emplace("parameters", Value::LIST(LogicalType::TYPE(), std::move(arguments))); } else { // some parameters were bound to a constant (e.g. string_agg's separator) - store the parameters as a list of // (type, value) pairs, where the value holds the constant the parameter must be re-bound with @@ -820,8 +830,55 @@ LogicalType CreateAggregateStateType(const BoundAggregateFunction &bound_functio arguments.push_back(Value::STRUCT(std::move(children))); } auto entry_type = LogicalType::STRUCT({{"type", LogicalType::TYPE()}, {"value", LogicalType::VARIANT()}}); - ext_info->properties.emplace("parameters", Value::LIST(entry_type, std::move(arguments))); + ext_info.properties.emplace("parameters", Value::LIST(entry_type, std::move(arguments))); } +} + +// the STRUCT type of a single ORDER BY entry encoded into an ordered aggregate state's extension info. It matches the +// shape of to_aggregate_state's ORDER BY argument: the buffered column the key sorts on, and the modifier string. +LogicalType OrderByEntryType() { + return LogicalType::STRUCT({{"column", LogicalType::UINTEGER}, {"order", LogicalType::VARCHAR}}); +} + +// renders an order modifier (e.g. ASC NULLS LAST) into the string form parsed by OrderModifiers::Parse +string OrderModifierToString(OrderType order_type, OrderByNullType null_order) { + return StringUtil::Format("%s %s", EnumUtil::ToChars(order_type), EnumUtil::ToChars(null_order)); +} + +// constructs the AGGREGATE_STATE type for the given bound aggregate function +// the state layout (a struct) is aliased to AGGREGATE_STATE, with the function name and signature stored in the +// extension type info so that the aggregate can be re-bound later (e.g. by FINALIZE/COMBINE) +LogicalType CreateAggregateStateType(const BoundAggregateFunction &bound_function, + optional_ptr bind_data) { + auto layout = bound_function.GetStateType(bind_data); + // deep copy the type before modifying it - SetAlias/SetExtensionInfo modify the (shared) extra type info in + // place, and the state layout type can share its type info with e.g. the aggregate's input expressions + LogicalType state_layout = layout.type.DeepCopy(); + state_layout.SetAlias("AGGREGATE_STATE"); + auto ext_info = make_uniq(); + EncodeStateParameters(*ext_info, bound_function, layout); + state_layout.SetExtensionInfo(std::move(ext_info)); + return state_layout; +} + +// constructs the AGGREGATE_STATE type for an ordered aggregate - the state is the buffer of values +// (LIST), with the inner signature and the ORDER BY spec stored in the extension info +LogicalType CreateSortedAggregateStateType(const BoundAggregateFunction &inner_function, + optional_ptr inner_bind_data, const LogicalType &buffer_struct, + const vector &orders) { + LogicalType state_layout = LogicalType::LIST(buffer_struct); + state_layout.SetAlias("AGGREGATE_STATE"); + auto ext_info = make_uniq(); + EncodeStateParameters(*ext_info, inner_function, inner_function.GetStateType(inner_bind_data)); + // per key: the buffered column it sorts on and the modifier string (the argument count is re-derived on re-bind) + vector order_values; + for (auto &order : orders) { + child_list_t children; + children.emplace_back("column", Value::UINTEGER(UnsafeNumericCast(order.column))); + children.emplace_back("order", Value(OrderModifierToString(order.order_type, order.null_order))); + order_values.push_back(Value::STRUCT(std::move(children))); + } + ext_info->properties.emplace("order_bys", Value::LIST(OrderByEntryType(), std::move(order_values))); state_layout.SetExtensionInfo(std::move(ext_info)); return state_layout; } @@ -868,6 +925,49 @@ void ParseConstantParameters(const Value &constants, idx_t argument_count, map &orders) { + if (order_value.IsNull()) { + return; + } + if (order_value.type().id() != LogicalTypeId::LIST) { + throw BinderException( + "to_aggregate_state: the ORDER BY argument must be a list of {column, order} structs, e.g. " + "[{'column': 1, 'order': 'DESC NULLS LAST'}]"); + } + for (auto &entry : ListValue::GetChildren(order_value)) { + if (entry.IsNull() || entry.type().id() != LogicalTypeId::STRUCT) { + throw BinderException("to_aggregate_state: each ORDER BY entry must be a {column, order} struct"); + } + auto &field_types = StructType::GetChildTypes(entry.type()); + auto &field_values = StructValue::GetChildren(entry); + Value column, order; + for (idx_t f = 0; f < field_types.size(); f++) { + if (field_types[f].first == "column") { + column = field_values[f]; + } else if (field_types[f].first == "order") { + order = field_values[f]; + } + } + if (column.IsNull() || order.IsNull()) { + throw BinderException("to_aggregate_state: each ORDER BY entry must have a non-NULL 'column' and 'order'"); + } + SortedAggregateStateOrder state_order; + state_order.column = column.GetValue(); + if (state_order.column >= column_count) { + throw BinderException( + "to_aggregate_state: ORDER BY column %llu is out of range (the state has %llu columns)", + (uint64_t)state_order.column, (uint64_t)column_count); + } + auto modifiers = OrderModifiers::Parse(StringValue::Get(order)); + state_order.order_type = modifiers.order_type; + state_order.null_order = modifiers.null_type; + orders.push_back(state_order); + } +} + unique_ptr ToAggregateStateBind(BindScalarFunctionInput &input) { auto &bound_function = input.GetBoundFunction(); auto &arguments = input.GetArguments(); @@ -916,6 +1016,29 @@ unique_ptr ToAggregateStateBind(BindScalarFunctionInput &input) { "Aggregate function \"%s\" does not have a state type callback defined - cannot convert to its state", function_name); } + + if (arguments.size() > 4) { + // an ordered aggregate state: the value is the buffer (LIST), the fifth argument the ORDER BY + auto &state_type = arguments[0]->GetReturnType(); + if (state_type.id() != LogicalTypeId::LIST || + ListType::GetChildType(state_type).id() != LogicalTypeId::STRUCT) { + throw BinderException("to_aggregate_state: an ordered aggregate state value must be a LIST of STRUCTs (the " + "buffer of values), e.g. [{'v0': ...}, ...]"); + } + const auto buffer_struct = ListType::GetChildType(state_type); + const idx_t column_count = StructType::GetChildTypes(buffer_struct).size(); + vector orders; + auto order_value = ExpressionExecutor::EvaluateScalar(context, *arguments[4]); + ParseOrderBys(order_value, column_count, orders); + if (orders.empty()) { + throw BinderException("to_aggregate_state: an ordered aggregate state must have at least one ORDER BY key"); + } + bound_function.GetArguments()[0] = LogicalType::LIST(buffer_struct); + bound_function.SetReturnType( + CreateSortedAggregateStateType(aggr, bind_data->bind_data.get(), buffer_struct, orders)); + return std::move(bind_data); + } + auto state_layout = aggr.GetStateType(bind_data->bind_data.get()).type; bound_function.GetArguments()[0] = state_layout; bound_function.SetReturnType(CreateAggregateStateType(aggr, bind_data->bind_data.get())); @@ -963,6 +1086,17 @@ ExportAggregateFunction::Bind(unique_ptr child_aggrega D_ASSERT(bound_function.HasStateSizeCallback()); D_ASSERT(bound_function.HasStateFinalizeCallback()); D_ASSERT(child_aggregate->Function().GetReturnType().id() != LogicalTypeId::INVALID); + if (child_aggregate->GetOrderBys() && !child_aggregate->GetOrderBys()->orders.empty()) { + // ordered aggregate: export the buffer of values. The sorted wrapper is built later (physical planning); here + // we only fix the AGGREGATE_STATE type so downstream binding (finalize/combine) sees it + LogicalType buffer_struct; + vector orders; + idx_t argument_count; // re-derived from the inner aggregate on re-bind + FunctionBinder::GetSortedAggregateStateLayout(*child_aggregate, buffer_struct, orders, argument_count); + SetStateExport(*child_aggregate, CreateSortedAggregateStateType( + bound_function, child_aggregate->BindInfo().get(), buffer_struct, orders)); + return child_aggregate; + } SetStateExport(*child_aggregate, CreateAggregateStateType(bound_function, child_aggregate->BindInfo().get())); return child_aggregate; } @@ -999,11 +1133,15 @@ ScalarFunction CombineFun::GetFunction() { ScalarFunctionSet ToAggregateStateFun::GetFunctions() { ScalarFunctionSet set("to_aggregate_state"); vector arguments {LogicalTypeId::ANY, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::ANY)}; - for (idx_t constant_params = 0; constant_params < 2; constant_params++) { - if (constant_params) { + for (idx_t optional_args = 0; optional_args < 3; optional_args++) { + if (optional_args == 1) { // optional fourth argument: constant parameter values as a list with one entry per argument (e.g. // [NULL, ',']) arguments.emplace_back(LogicalTypeId::ANY); + } else if (optional_args == 2) { + // optional fifth argument: the ORDER BY of an ordered aggregate, as a list of {column, order} structs + // (e.g. [{'column': 1, 'order': 'DESC NULLS LAST'}]) + arguments.emplace_back(LogicalTypeId::ANY); } ScalarFunction function("to_aggregate_state", arguments, LogicalTypeId::ANY, ToAggregateStateFunction, ToAggregateStateBind); diff --git a/src/duckdb/src/function/scalar/variant/variant_normalize.cpp b/src/duckdb/src/function/scalar/variant/variant_normalize.cpp index c44975bb0..a8780956f 100644 --- a/src/duckdb/src/function/scalar/variant/variant_normalize.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_normalize.cpp @@ -260,9 +260,15 @@ static void VariantNormalizeFunction(DataChunk &input, ExpressionState &state, V VariantNormalizer::Normalize(variant_vec, result); } +static unique_ptr VariantNormalizeStats(ClientContext &context, FunctionStatisticsInput &input) { + // variant_normalize re-encodes the VARIANT in a canonical binary form, but does not change any values + return input.child_stats[0].ToUnique(); +} + ScalarFunction VariantNormalizeFun::GetFunction() { auto variant_type = LogicalType::VARIANT(); - return ScalarFunction("variant_normalize", {variant_type}, variant_type, VariantNormalizeFunction); + return ScalarFunction("variant_normalize", {variant_type}, variant_type, VariantNormalizeFunction, nullptr, + VariantNormalizeStats); } } // namespace duckdb diff --git a/src/duckdb/src/function/scalar/variant/variant_utils.cpp b/src/duckdb/src/function/scalar/variant/variant_utils.cpp index b246ab2c4..d192437dd 100644 --- a/src/duckdb/src/function/scalar/variant/variant_utils.cpp +++ b/src/duckdb/src/function/scalar/variant/variant_utils.cpp @@ -11,6 +11,7 @@ #include "duckdb/function/variant/variant_value_convert.hpp" #include "duckdb/common/type_visitor.hpp" #include "duckdb/function/variant/variant_shredding.hpp" +#include "duckdb/common/types/variant/variant_builder.hpp" namespace duckdb { @@ -437,4 +438,35 @@ bool VariantUtils::VariantSupportsType(const LogicalType &type) { } } +//===--------------------------------------------------------------------===// +// ToVariant sources +//===--------------------------------------------------------------------===// +// The single-pass build machinery (VariantBuilder / EmitIterator / BuildVariant) lives in +// variant_builder.hpp so it can be shared with the parquet reader. A "source" just implements +// 'bool Emit(idx_t row, VariantBuilder &builder)' (returning whether the row is a SQL NULL). +namespace { + +struct VariantIteratorSource { + explicit VariantIteratorSource(const VariantIterator &state) : state(state) { + } + bool Emit(idx_t row, VariantBuilder &builder) const { + auto root = state.Root(row); + //! Root() resolves a missing/absent root to a SQL NULL + if (root.IsNull()) { + return true; + } + EmitIterator(root, builder); + return false; + } + + const VariantIterator &state; +}; + +} // namespace + +void VariantUtils::ToVariant(const VariantIterator &state, idx_t count, Vector &result) { + VariantIteratorSource source(state); + BuildVariant(source, count, result); +} + } // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_variables.cpp b/src/duckdb/src/function/table/system/duckdb_variables.cpp index 4694de73c..d31194313 100644 --- a/src/duckdb/src/function/table/system/duckdb_variables.cpp +++ b/src/duckdb/src/function/table/system/duckdb_variables.cpp @@ -44,7 +44,7 @@ unique_ptr DuckDBVariablesInit(ClientContext &context, for (auto &entry : config.user_variables) { VariableData data; - data.name = entry.first; + data.name = entry.first.GetIdentifierName(); data.value = entry.second; result->variables.push_back(std::move(data)); } diff --git a/src/duckdb/src/function/table/system/logging_utils.cpp b/src/duckdb/src/function/table/system/logging_utils.cpp index aa1c34e73..e90bab231 100644 --- a/src/duckdb/src/function/table/system/logging_utils.cpp +++ b/src/duckdb/src/function/table/system/logging_utils.cpp @@ -92,6 +92,24 @@ static unique_ptr BindEnableLogging(ClientContext &context, TableF } } + // File logging requires a path. Reject switching to file storage without one before mutating any + // state, so the active storage is preserved instead of becoming a path-less storage that throws + // on every later flush (end-of-query and shutdown included). + if (StringUtil::Lower(result->config.storage) == LogConfig::FILE_STORAGE_NAME) { + auto current_storage = StringUtil::Lower(context.db->GetLogManager().GetConfig().storage); + // Already-active file storage keeps its existing path; only guard a fresh switch. + if (current_storage != LogConfig::FILE_STORAGE_NAME) { + auto path_entry = result->storage_config.find("path"); + bool has_usable_path = path_entry != result->storage_config.end() && !path_entry->second.IsNull() && + !path_entry->second.ToString().empty(); + if (!has_usable_path) { + throw InvalidInputException( + "Cannot enable 'file' log storage without a valid path. Provide one via storage_path, " + "e.g. CALL enable_logging(storage='file', storage_path='mylog.csv');"); + } + } + } + // Process positional params if (!input.inputs.empty()) { if (input.inputs[0].type() == LogicalType::VARCHAR) { diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index 423ef4229..6ebad1c71 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev9045" +#define DUCKDB_PATCH_VERSION "0-dev9556" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 6 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.6.0-dev9045" +#define DUCKDB_VERSION "v1.6.0-dev9556" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "17ba7dd2b6" +#define DUCKDB_SOURCE_ID "3518c60315" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" diff --git a/src/duckdb/src/include/duckdb.h b/src/duckdb/src/include/duckdb.h index fa3cf5688..e29f9f1e0 100644 --- a/src/duckdb/src/include/duckdb.h +++ b/src/duckdb/src/include/duckdb.h @@ -504,7 +504,7 @@ typedef struct { } duckdb_bit; //! BIGNUMs are composed of a byte pointer, a size, and an `is_negative` bool. -//! The absolute value of the number is stored in `data` in little endian format. +//! The absolute value of the number is stored in `data` in big endian format. //! You must free `data` with `duckdb_free`. typedef struct { uint8_t *data; diff --git a/src/duckdb/src/include/duckdb/catalog/dependency_manager.hpp b/src/duckdb/src/include/duckdb/catalog/dependency_manager.hpp index 560d369dd..3940de109 100644 --- a/src/duckdb/src/include/duckdb/catalog/dependency_manager.hpp +++ b/src/duckdb/src/include/duckdb/catalog/dependency_manager.hpp @@ -29,6 +29,8 @@ struct DependencySubject { CatalogEntryInfo entry; //! The type of dependency this is (e.g, ownership) DependencySubjectFlags flags; + //! The oid of the subject entry when the dependency was created + optional_idx oid; }; // The entry that relies on the other entry @@ -104,6 +106,7 @@ class DependencyManager { bool IsSystemEntry(CatalogEntry &entry) const; optional_ptr LookupEntry(CatalogTransaction transaction, const LogicalDependency &dependency); optional_ptr LookupEntry(CatalogTransaction transaction, CatalogEntry &dependency); + optional_ptr LookupEntry(CatalogTransaction transaction, const CatalogEntryInfo &info); string CollectDependents(CatalogTransaction transaction, catalog_entry_set_t &entries, CatalogEntryInfo &info); void CleanupDependencies(CatalogTransaction transaction, CatalogEntry &entry); diff --git a/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp b/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp index 3e92a0b26..a1466dd63 100644 --- a/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp +++ b/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp @@ -19,7 +19,7 @@ namespace duckdb_adbc { class AppenderWrapper { public: AppenderWrapper(duckdb_connection conn, const char *catalog, const char *schema, const char *table) - : appender(nullptr) { + : appender(nullptr), create_error_type(DUCKDB_ERROR_UNKNOWN_TYPE) { // Note: duckdb_appender_create_ext allocates an internal wrapper even on failure. // If creation fails, make sure to destroy it to avoid leaking. auto created = duckdb_appender(nullptr); @@ -30,6 +30,7 @@ class AppenderWrapper { if (error_message) { create_error = error_message; } + create_error_type = duckdb_error_data_error_type(error_data); duckdb_destroy_error_data(&error_data); duckdb_appender_destroy(&created); } @@ -52,10 +53,14 @@ class AppenderWrapper { const std::string &CreateError() const { return create_error; } + duckdb_error_type CreateErrorType() const { + return create_error_type; + } private: duckdb_appender appender; std::string create_error; + duckdb_error_type create_error_type; }; class DataChunkWrapper { @@ -191,6 +196,9 @@ AdbcStatusCode StatementSetOptionDouble(struct AdbcStatement *statement, const c const AdbcError *ErrorFromArrayStream(struct ArrowArrayStream *stream, AdbcStatusCode *status); +int ErrorGetDetailCount(const struct AdbcError *error); +struct AdbcErrorDetail ErrorGetDetail(const struct AdbcError *error, int index); + AdbcStatusCode StatementNew(struct AdbcConnection *connection, struct AdbcStatement *statement, struct AdbcError *error); diff --git a/src/duckdb/src/include/duckdb/common/allocator.hpp b/src/duckdb/src/include/duckdb/common/allocator.hpp index 2d7d35ead..9f0a72c11 100644 --- a/src/duckdb/src/include/duckdb/common/allocator.hpp +++ b/src/duckdb/src/include/duckdb/common/allocator.hpp @@ -130,6 +130,10 @@ class Allocator { static void SetBackgroundThreads(bool enable); private: + //! Returns free memory in the system heap (glibc) to the OS via malloc_trim, rate-limited to once + //! per 100ms. No-op on non-glibc platforms. Shared by the flush paths of both allocator backends. + static void MallocTrim(idx_t pad); + allocate_function_ptr_t allocate_function; free_function_ptr_t free_function; reallocate_function_ptr_t reallocate_function; diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp index c0b60abf9..0ce7493e0 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp @@ -94,6 +94,13 @@ struct ArrowAppendData { } void AppendValidity(UnifiedVectorFormat &format, idx_t from, idx_t to); + //! Append a (child) vector, routing it through the Arrow extension's duckdb_to_arrow + //! conversion first when one is set. Container appenders must call this instead of + //! append_vector so nested extension types (e.g. arrow.bool8 BOOLEAN) get the same + //! conversion the top-level appender applies, keeping the data layout in sync with + //! the schema declared by SetArrowFormat. + void AppendChild(const Vector &input, idx_t from, idx_t to, idx_t input_size); + public: idx_t row_count = 0; idx_t null_count = 0; diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp index 612937c70..211fe9a1a 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp @@ -38,7 +38,7 @@ struct ArrowListData { auto child_size = child_indices.size(); Vector child_copy(child.GetType()); child_copy.Slice(child, child_sel, child_size); - append_data.child_data[0]->append_vector(*append_data.child_data[0], child_copy, 0, child_size, child_size); + append_data.child_data[0]->AppendChild(child_copy, 0, child_size, child_size); append_data.row_count += size; } diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/list_view_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/list_view_data.hpp index cdd5d7d48..f0ed68367 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/list_view_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/list_view_data.hpp @@ -40,7 +40,7 @@ struct ArrowListViewData { auto child_size = child_indices.size(); Vector child_copy(child.GetType()); child_copy.Slice(child, child_sel, child_size); - append_data.child_data[0]->append_vector(*append_data.child_data[0], child_copy, 0, child_size, child_size); + append_data.child_data[0]->AppendChild(child_copy, 0, child_size, child_size); append_data.row_count += size; } diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp index 560e4c3e4..77116a678 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp @@ -57,8 +57,8 @@ struct ArrowMapData { key_vector_copy.Slice(key_vector, child_sel, list_size); Vector value_vector_copy(value_vector.GetType()); value_vector_copy.Slice(value_vector, child_sel, list_size); - key_data.append_vector(key_data, key_vector_copy, 0, list_size, list_size); - value_data.append_vector(value_data, value_vector_copy, 0, list_size, list_size); + key_data.AppendChild(key_vector_copy, 0, list_size, list_size); + value_data.AppendChild(value_vector_copy, 0, list_size, list_size); append_data.row_count += size; struct_data.row_count += size; diff --git a/src/duckdb/src/include/duckdb/common/enum_util.hpp b/src/duckdb/src/include/duckdb/common/enum_util.hpp index f28d89e1c..83554cd72 100644 --- a/src/duckdb/src/include/duckdb/common/enum_util.hpp +++ b/src/duckdb/src/include/duckdb/common/enum_util.hpp @@ -542,8 +542,6 @@ enum class VariantLogicalType : uint8_t; enum class VariantStatsShreddingState : uint8_t; -enum class VariantValueType : uint8_t; - enum class VectorBufferType : uint8_t; enum class VectorType : uint8_t; @@ -1328,9 +1326,6 @@ const char* EnumUtil::ToChars(VariantLogicalType value); template<> const char* EnumUtil::ToChars(VariantStatsShreddingState value); -template<> -const char* EnumUtil::ToChars(VariantValueType value); - template<> const char* EnumUtil::ToChars(VectorBufferType value); @@ -2124,9 +2119,6 @@ VariantLogicalType EnumUtil::FromString(const char *value); template<> VariantStatsShreddingState EnumUtil::FromString(const char *value); -template<> -VariantValueType EnumUtil::FromString(const char *value); - template<> VectorBufferType EnumUtil::FromString(const char *value); diff --git a/src/duckdb/src/include/duckdb/common/http_util.hpp b/src/duckdb/src/include/duckdb/common/http_util.hpp index baecd2172..dbe4ae311 100644 --- a/src/duckdb/src/include/duckdb/common/http_util.hpp +++ b/src/duckdb/src/include/duckdb/common/http_util.hpp @@ -153,6 +153,8 @@ struct BaseRequest { bool have_request_timing = false; timestamp_t request_start; timestamp_t request_end; + //! Request body size in bytes (the Content-Length we send). Only set for PUT/POST. + idx_t request_body_length = 0; //! Optional per-request network measurements, populated by clients that measure them. bool have_time_to_fst_byte = false; @@ -194,6 +196,7 @@ struct PutRequestInfo : public BaseRequest { idx_t buffer_in_len, const string &content_type) : BaseRequest(RequestType::PUT_REQUEST, path, headers, params), buffer_in(buffer_in), buffer_in_len(buffer_in_len), content_type(content_type) { + request_body_length = buffer_in_len; } const_data_ptr_t buffer_in; @@ -224,6 +227,7 @@ struct PostRequestInfo : public BaseRequest { idx_t buffer_in_len) : BaseRequest(RequestType::POST_REQUEST, path, headers, params), buffer_in(buffer_in), buffer_in_len(buffer_in_len) { + request_body_length = buffer_in_len; } const_data_ptr_t buffer_in; diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp index 6c021a29c..0665fca15 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp @@ -851,7 +851,7 @@ class MultiFileFunction : public TableFunction { static unique_ptr MultiFileCardinality(ClientContext &context, const FunctionData *bind_data) { auto &data = bind_data->Cast(); if (IsEmptyResult(data)) { - return make_uniq(0); + return make_uniq(idx_t(0)); } auto file_list_cardinality_estimate = data.file_list->GetCardinality(context); if (file_list_cardinality_estimate) { diff --git a/src/duckdb/src/include/duckdb/common/serializer/async_memory_governor.hpp b/src/duckdb/src/include/duckdb/common/serializer/async_memory_governor.hpp new file mode 100644 index 000000000..fd4174356 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/async_memory_governor.hpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/async_memory_governor.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" + +namespace duckdb { + +class ClientContext; +class TemporaryMemoryState; + +//! Memory policy shared by the managed async queues. +struct ManagedAsyncMemoryConfig { + //! Maximum queued async bytes retained per regular execution thread. + static constexpr idx_t MAX_PENDING_BYTES_PER_THREAD = 64ULL * 1024ULL * 1024ULL; + //! Minimum async reservation requested per regular execution thread. + static constexpr idx_t MIN_PENDING_BYTES_PER_THREAD = 8ULL * 1024ULL * 1024ULL; + //! Below this reservation, do not retain an async backlog (behave close to synchronous draining). + static constexpr idx_t MIN_RESERVATION_FOR_BACKLOG = 8ULL * 1024ULL * 1024ULL; +}; + +//! Shared TemporaryMemoryManager reservation governor for the managed async queues. +//! Encapsulates the coarse-growth reservation heuristic and backpressure budget so each queue bounds its +//! queued and in-flight backlog through the same memory logic. +class ManagedAsyncMemoryGovernor { +public: + explicit ManagedAsyncMemoryGovernor(ClientContext &client_context); + ~ManagedAsyncMemoryGovernor(); + + ManagedAsyncMemoryGovernor(const ManagedAsyncMemoryGovernor &) = delete; + ManagedAsyncMemoryGovernor &operator=(const ManagedAsyncMemoryGovernor &) = delete; + + //! Whether a TemporaryMemoryState reservation is tracking this queue's backlog. + bool IsActive() const; + //! Grow the reservation coarsely until it covers current_pending_bytes; released only on Release(). + void UpdateReservation(idx_t current_pending_bytes); + //! Current async backlog budget after applying the fixed cap, or 0 when memory is too tight to retain a backlog. + idx_t BackpressureBudget() const; + //! Release the reservation; further UpdateReservation calls may grow it again. + void Release(); + +private: + ClientContext &client_context; + //! Temporary memory reservation state used to limit queued async data. Absent when draining synchronously. + unique_ptr memory_state; + //! Last remaining-size request sent to TemporaryMemoryManager. Grows monotonically until Release(). + idx_t memory_request_bytes = 0; + //! Minimum TemporaryMemoryManager reservation while work is outstanding. + idx_t min_pending_bytes = 0; + //! Hard cap over the TemporaryMemoryState reservation. + idx_t max_pending_bytes = 0; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/async_task_queue.hpp b/src/duckdb/src/include/duckdb/common/serializer/async_task_queue.hpp new file mode 100644 index 000000000..b9962aa43 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/async_task_queue.hpp @@ -0,0 +1,207 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/async_task_queue.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/deque.hpp" +#include "duckdb/common/error_data.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/serializer/async_memory_governor.hpp" +#include "duckdb/parallel/async_result.hpp" + +#include + +namespace duckdb { + +class ClientContext; +class TaskExecutor; + +//! Completion callback for one generic async task. The error is set when the task failed. +using AsyncTaskCompletionCallback = std::function error)>; + +//! One unit of generic async work, tagged with a byte size used for memory accounting. +class AsyncTaskRequest { +public: + AsyncTaskRequest() = default; + AsyncTaskRequest(unique_ptr task, idx_t size, AsyncTaskCompletionCallback completion = nullptr); + + //! The byte size reported for this task (used to bound queued/in-flight memory). + idx_t Size() const; + + unique_ptr task; + idx_t size = 0; + AsyncTaskCompletionCallback completion; +}; + +//! Minimal generic async task scheduler. +//! Each drain task executes exactly one request, so up to max_active_tasks requests run concurrently. +//! Memory policy lives in the ManagedAsyncTaskQueue wrapper; this is the bare scheduling/draining layer. +class AsyncTaskQueue { + friend class AsyncTaskQueueTask; + friend class AsyncTaskQueueTaskGuard; + +public: + DUCKDB_API explicit AsyncTaskQueue(ClientContext &client_context, idx_t max_active_tasks = 0); + DUCKDB_API ~AsyncTaskQueue(); + + AsyncTaskQueue(const AsyncTaskQueue &) = delete; + AsyncTaskQueue &operator=(const AsyncTaskQueue &) = delete; + +public: + //! Return whether tasks are drained by async scheduler tasks. If false, Submit runs the task synchronously. + DUCKDB_API bool IsAsync() const; + //! Return whether the async task executor has captured an error. + DUCKDB_API bool HasError(); + //! Submit one owned task to the configured sync/async path. + DUCKDB_API void Submit(AsyncTaskRequest request); + //! Return queued/in-flight bytes whose tasks have not completed yet. + DUCKDB_API idx_t PendingBytes(); + //! Execute one queued task owned by this queue, or yield if the tasks are already running. + DUCKDB_API void WorkOnPendingTask(); + //! Wait until all submitted tasks have completed. + DUCKDB_API void Flush(); + //! Wait for all tasks and close the queue. + DUCKDB_API void Close(); + //! Surface an error thrown by an async drain task. + DUCKDB_API void RethrowTaskError(); + +private: + //! Schedule one drain task per still-unclaimed pending request, up to max_active_tasks. + void ScheduleTasksInternal(); + //! Release one scheduled/running task slot and its in-flight byte accounting. + void FinishTask(idx_t task_size); + //! Release a task slot for a scheduled task that never started because another task failed. + void CancelScheduledTask(); + //! Release multiple reserved task slots that were never scheduled. + void CancelScheduledTasks(idx_t task_count); + + //! Async task entry point that drains exactly one pending request. + void DrainRequest(); + //! Run one task and invoke its completion callback. + void ExecuteRequest(AsyncTaskRequest request); + //! Invoke a completion callback outside the queue lock. + void CompleteRequest(AsyncTaskRequest &request, idx_t size, optional_ptr error); + //! Throw if a mutating API is used after Close(). + void VerifyOpen() const; + //! Throw if the queue still owns registered or scheduled work. + void VerifyDrained() const; + //! Fail and discard queued tasks after an async failure once all scheduled tasks have stopped. + void CancelPendingRequestsAfterFailure(const ErrorData &error) noexcept; + +private: + ClientContext &client_context; + //! Maximum scheduled/running tasks for this queue. + idx_t max_active_tasks = 1; + + //! Protects state shared between the submitting thread and async tasks. + mutex lock; + //! Tasks waiting for an async drain task. + deque pending_requests; + //! Bytes queued in pending_requests that have not been claimed by a task yet. + idx_t pending_bytes = 0; + //! Bytes owned by running tasks that have not completed yet. + idx_t in_flight_bytes = 0; + //! Scheduled or running tasks for this queue. + idx_t active_tasks = 0; + //! Scheduled tasks that have not yet claimed a request. + idx_t pending_tasks = 0; + //! Set after Close() has drained the queue. Further submissions are rejected. + bool closed = false; + + //! Async task executor. If absent, tasks are executed synchronously on submission. + //! Keep this after task-accounting fields so queued task destructors can still release slots. + unique_ptr executor; +}; + +//! Generic, memory-managed, multi-producer async task queue. +//! Tasks are drained on the ASYNC TaskScheduler pool; queued and in-flight bytes are bounded by a shared +//! TemporaryMemoryManager reservation. Falls back to synchronous execution when async_threads == 0. +//! +//! Contract: +//! - Register is safe to call concurrently from multiple threads. +//! - Up to max_active_tasks tasks run concurrently; each drain task executes one task. +//! - With async_threads == 0, Register executes the task inline on the caller. +//! - The first task error is captured, further scheduling stops, and it is rethrown from WaitAll/Close. +//! Partial completion is possible on error. +//! - Call WaitAll then Close in the consumer's finalize step; the destructor asserts the queue is drained. +class ManagedAsyncTaskQueue { +public: + //! max_active_tasks == 0 -> use TaskScheduler::NumberOfAsyncThreads(). + DUCKDB_API explicit ManagedAsyncTaskQueue(ClientContext &client_context, idx_t max_active_tasks = 0); + DUCKDB_API ~ManagedAsyncTaskQueue(); + + ManagedAsyncTaskQueue(const ManagedAsyncTaskQueue &) = delete; + ManagedAsyncTaskQueue &operator=(const ManagedAsyncTaskQueue &) = delete; + +public: + //! Whether tasks are drained asynchronously (false => Register runs the task synchronously). + DUCKDB_API bool IsAsync() const; + //! Return whether the async task executor has captured an error. + DUCKDB_API bool HasError(); + //! Hand off one unit of work. byte_size feeds the memory accounting (use the serialized payload size). + //! The task's Execute() runs on an ASYNC-pool thread (or synchronously if !IsAsync()); it MUST throw on + //! failure (the first error is captured, scheduling stops, and it is rethrown from WaitAll/Close). + DUCKDB_API void Register(unique_ptr task, idx_t byte_size); + //! Block the calling (producer) thread, helping drain, until queued+in-flight bytes fall under the budget. + DUCKDB_API void ApplyBackpressure(); + //! Drain everything; after WaitAll returns (no error) all registered tasks have completed. + DUCKDB_API void WaitAll(); + //! WaitAll + release the memory reservation + reject further Register calls. + DUCKDB_API void Close(); + //! Surface an error thrown by an async drain task. + DUCKDB_API void RethrowTaskError(); + +private: + //! Whether scheduling should respect the in-flight window, or force all pending tasks to drain. + enum class SchedulePolicy : uint8_t { THRESHOLD, FORCE }; + +private: + //! Submit pending tasks to the low-level queue, bounded by the in-flight window unless forced. + void SchedulePendingTasks(SchedulePolicy policy = SchedulePolicy::THRESHOLD); + //! Grow the shared reservation coarsely to cover the current backlog. + void UpdateMemoryState(); + //! Return queued + submitted bytes that have not completed yet. Caller must hold lock. + idx_t TotalPendingBytes() const; + //! Move one pending task into a submission to the low-level queue. + bool TakePendingTaskRequest(AsyncTaskRequest &request, SchedulePolicy policy); + //! Wrap a request with submitted-byte accounting that runs when the task completes. + void AddCompletionAccounting(AsyncTaskRequest &request); + //! Release byte accounting for one submitted task and refill the in-flight window. + void CompleteSubmittedTask(idx_t size, optional_ptr error); + //! Throw if a mutating API is used after Close(). + void VerifyOpen() const; + //! Throw if the queue still owns registered or submitted work. + void VerifyDrained() const; + //! Fail and discard queued tasks after an async failure once all submitted tasks have stopped. + void CancelPendingTasksAfterFailure(const ErrorData &error) noexcept; + +private: + ClientContext &client_context; + //! Low-level one-unit-per-task scheduler. + unique_ptr task_queue; + //! Shared TemporaryMemoryManager reservation governor bounding queued async task data. + ManagedAsyncMemoryGovernor memory_governor; + //! Maximum number of submitted/running tasks for this queue (the in-flight window). + idx_t max_active_drain_tasks = 1; + + //! Protects state shared between registering threads and async completion callbacks. + mutex lock; + //! Tasks queued for submission to the low-level queue. + deque pending_requests; + //! Bytes queued in pending_requests that have not been submitted yet. + idx_t pending_bytes = 0; + //! Bytes submitted to the low-level queue that have not completed yet. + idx_t submitted_bytes = 0; + //! Submitted tasks that have not completed yet. + idx_t submitted_requests = 0; + //! Set after Close() has drained the queue. Further registration is rejected. + bool closed = false; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/async_write_queue.hpp b/src/duckdb/src/include/duckdb/common/serializer/async_write_queue.hpp index cd7720fb3..0350cbdec 100644 --- a/src/duckdb/src/include/duckdb/common/serializer/async_write_queue.hpp +++ b/src/duckdb/src/include/duckdb/common/serializer/async_write_queue.hpp @@ -13,6 +13,7 @@ #include "duckdb/common/error_data.hpp" #include "duckdb/common/mutex.hpp" #include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/serializer/async_memory_governor.hpp" #include @@ -20,7 +21,6 @@ namespace duckdb { class ClientContext; class TaskExecutor; -class TemporaryMemoryState; //! Compile-time policy used by the async write layers. struct AsyncWriteConfig { @@ -32,10 +32,6 @@ struct AsyncWriteConfig { static constexpr idx_t LOCAL_COALESCE_THRESHOLD = 4096; //! Remote file systems benefit from fewer round trips, so coalesce contiguous small buffers more aggressively. static constexpr idx_t REMOTE_COALESCE_THRESHOLD = 8ULL * 1024ULL * 1024ULL; - //! Maximum queued async bytes retained per regular execution thread. - static constexpr idx_t MAX_PENDING_BYTES_PER_THREAD = 64ULL * 1024ULL * 1024ULL; - //! Minimum async write reservation requested per regular execution thread. - static constexpr idx_t MIN_PENDING_BYTES_PER_THREAD = 8ULL * 1024ULL * 1024ULL; //! Maximum bytes a single managed stream request should submit before yielding scheduler capacity. static constexpr idx_t DRAIN_TASK_BYTE_BUDGET = 16ULL * 1024ULL * 1024ULL; }; @@ -299,16 +295,10 @@ class ManagedAsyncWriteQueue : private AsyncWriteTarget { //! Low-level positional request scheduler. unique_ptr write_queue; - //! Temporary memory reservation state used to limit queued async write data. - unique_ptr memory_state; - //! Last remaining-size request sent to TemporaryMemoryManager. Grows monotonically until close. - idx_t memory_request_bytes = 0; + //! Shared TemporaryMemoryManager reservation governor bounding queued async write data. + ManagedAsyncMemoryGovernor memory_governor; //! Maximum number of submitted/running drain requests for this queue. idx_t max_active_drain_tasks = 1; - //! Minimum TemporaryMemoryManager reservation while writes are outstanding. - idx_t min_pending_bytes = 0; - //! Hard cap over the TemporaryMemoryState reservation. - idx_t max_pending_bytes = 0; //! Maximum bytes one managed async request should submit before yielding scheduler capacity. idx_t drain_task_byte_budget = AsyncWriteConfig::DRAIN_TASK_BYTE_BUDGET; @@ -459,6 +449,8 @@ class ManagedAsyncWriteStreamQueue : private AsyncWriteTarget { idx_t submitted_requests = 0; //! Nested batch depth. While non-zero, async draining and backpressure are delayed. idx_t batch_depth = 0; + //! Whether completion-driven refills should ignore the normal first-task threshold. + bool force_completion_refill = false; //! Next logical offset expected by RegisterWrite. Enforces v1 contiguous-registration semantics. idx_t next_registration_offset = 0; //! Set after Close() has drained the queue. Further write registration is rejected. diff --git a/src/duckdb/src/include/duckdb/common/types/geometry.hpp b/src/duckdb/src/include/duckdb/common/types/geometry.hpp index 75a39f2b7..fe45f1a34 100644 --- a/src/duckdb/src/include/duckdb/common/types/geometry.hpp +++ b/src/duckdb/src/include/duckdb/common/types/geometry.hpp @@ -119,10 +119,27 @@ class GeometryExtent { return GeometryExtent {EMPTY_MIN, EMPTY_MIN, EMPTY_MIN, EMPTY_MIN, EMPTY_MAX, EMPTY_MAX, EMPTY_MAX, EMPTY_MAX}; } - // Does this extent have any X/Y values set? - // In other words, is the range of the x/y axes not empty and not unknown? + // Does this extent have the X axis set? + // In other words, is the range of the x-axis not empty and not unknown? + bool HasX() const { + return std::isfinite(x_min) && std::isfinite(x_max); + } + // Does this extent have the Y axis set? + // In other words, is the range of the y-axis not empty and not unknown? + bool HasY() const { + return std::isfinite(y_min) && std::isfinite(y_max); + } + // Does this extent have both X and Y axes set? + // In other words, are the ranges of both the x and y axes not empty and not unknown? + // Used to gate serialization, where a non-finite axis cannot be represented. bool HasXY() const { - return std::isfinite(x_min) && std::isfinite(y_min) && std::isfinite(x_max) && std::isfinite(y_max); + return HasX() && HasY(); + } + // Can this extent be used for X/Y zonemap pruning? + // A single finite axis is enough: an unknown axis is treated as an infinite range, + // which intersects everything, so pruning simply degrades to the finite axis. + bool CanPruneXY() const { + return HasX() || HasY(); } // Does this extent have any Z values set? // In other words, is the range of the Z-axis not empty and not unknown? diff --git a/src/duckdb/src/include/duckdb/common/types/list_segment.hpp b/src/duckdb/src/include/duckdb/common/types/list_segment.hpp index f43947e4b..557f4117d 100644 --- a/src/duckdb/src/include/duckdb/common/types/list_segment.hpp +++ b/src/duckdb/src/include/duckdb/common/types/list_segment.hpp @@ -47,8 +47,8 @@ struct ListSegmentFunctions; typedef ListSegment *(*create_segment_t)(const ListSegmentFunctions &functions, ArenaAllocator &allocator, uint16_t capacity); typedef void (*write_data_to_segment_t)(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, - idx_t &entry_idx); + ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t offset, + idx_t count); //! Scans up to "count" rows from the state's current position into the result vector at result_offset, //! advancing the state - returns the number of rows scanned (less than "count" only if the scan is exhausted) typedef idx_t (*scan_data_t)(const ListSegmentFunctions &functions, ListSegmentScanState &state, idx_t count, @@ -62,8 +62,9 @@ struct ListSegmentFunctions { vector child_functions; - void AppendRow(ArenaAllocator &allocator, LinkedList &linked_list, RecursiveUnifiedVectorFormat &input_data, - idx_t &entry_idx) const; + //! Append "count" rows of input_data (starting at "offset") to the linked list + void AppendRows(ArenaAllocator &allocator, LinkedList &linked_list, RecursiveUnifiedVectorFormat &input_data, + idx_t offset, idx_t count) const; //! Append all rows of the given list entry (indexing into child_data) to the linked list void AppendListEntry(ArenaAllocator &allocator, LinkedList &linked_list, RecursiveUnifiedVectorFormat &child_data, const list_entry_t &list_entry) const; diff --git a/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp b/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp index 0ffa291d9..96f987b66 100644 --- a/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp +++ b/src/duckdb/src/include/duckdb/common/types/selection_vector.hpp @@ -19,6 +19,11 @@ class VectorBuffer; struct SelectionData { DUCKDB_API explicit SelectionData(idx_t count); + // Out-of-line destructor: prevents GCC IPA-ICF from folding + // _Sp_counted_ptr_inplace::_M_dispose with the + // corresponding instantiation for TemplatedValidityData, which produces + // a spurious -Warray-bounds with g++ >= 14. + DUCKDB_API ~SelectionData(); AllocatedData owned_data; }; diff --git a/src/duckdb/src/include/duckdb/common/types/string_type.hpp b/src/duckdb/src/include/duckdb/common/types/string_type.hpp index 69b43e946..dd0cc8b92 100644 --- a/src/duckdb/src/include/duckdb/common/types/string_type.hpp +++ b/src/duckdb/src/include/duckdb/common/types/string_type.hpp @@ -40,6 +40,7 @@ struct string_t { string_t() = default; explicit string_t(uint32_t len) { value.inlined.length = len; + memset(value.inlined.inlined, 0, INLINE_BYTES); } string_t(const char *data, uint32_t len) { value.inlined.length = len; diff --git a/src/duckdb/src/include/duckdb/common/types/value.hpp b/src/duckdb/src/include/duckdb/common/types/value.hpp index 4ecce53b6..fecd35f84 100644 --- a/src/duckdb/src/include/duckdb/common/types/value.hpp +++ b/src/duckdb/src/include/duckdb/common/types/value.hpp @@ -485,6 +485,11 @@ struct TypeValue { DUCKDB_API static LogicalType GetType(const Value &value); }; +struct VariantValue { + //! Convert a (non-null) VARIANT-typed Value back to a plain Value + DUCKDB_API static Value GetValue(const Value &variant_val); +}; + //! Return the internal integral value for any type that is stored as an integral value internally //! This can be used on values of type integer, uinteger, but also date, timestamp, decimal, etc struct IntegralValue { diff --git a/src/duckdb/src/include/duckdb/common/types/variant.hpp b/src/duckdb/src/include/duckdb/common/types/variant.hpp index 2b7219de4..619b62c42 100644 --- a/src/duckdb/src/include/duckdb/common/types/variant.hpp +++ b/src/duckdb/src/include/duckdb/common/types/variant.hpp @@ -72,6 +72,16 @@ struct VariantNestedData { uint32_t children_idx; }; +//! The (width, scale) of a DECIMAL value - the physical storage type (and hence the payload) follows +//! from the width (see VariantDecimalData::GetPhysicalType) +struct VariantDecimalProperties { + VariantDecimalProperties(uint32_t width, uint32_t scale) : width(width), scale(scale) { + } + + uint32_t width; + uint32_t scale; +}; + struct VariantDecimalData { public: VariantDecimalData(uint32_t width, uint32_t scale, const_data_ptr_t value_ptr) diff --git a/src/duckdb/src/include/duckdb/common/types/variant/variant_builder.hpp b/src/duckdb/src/include/duckdb/common/types/variant/variant_builder.hpp new file mode 100644 index 000000000..2d94fdbc6 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/types/variant/variant_builder.hpp @@ -0,0 +1,590 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/types/variant/variant_builder.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/vector/flat_vector.hpp" +#include "duckdb/common/vector/vector_writer.hpp" +#include "duckdb/common/vector/list_vector.hpp" +#include "duckdb/common/vector/string_vector.hpp" +#include "duckdb/common/vector/variant_vector.hpp" +#include "duckdb/common/types/variant_iterator.hpp" +#include "duckdb/common/serializer/varint.hpp" +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/types/time.hpp" +#include "duckdb/common/types/datetime.hpp" +#include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/types/date.hpp" +#include "duckdb/common/types/interval.hpp" +#include "duckdb/common/types/decimal.hpp" +#include "duckdb/common/types/variant.hpp" +#include "duckdb/common/hugeint.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/owning_string_map.hpp" +#include "duckdb/common/numeric_utils.hpp" +#include "duckdb/common/limits.hpp" + +#include + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// Building a VARIANT in a single pass +//===--------------------------------------------------------------------===// +// The canonical (unshredded) VARIANT layout is built directly while traversing the source tree once. +// Rather than a separate "analyze sizes" pass followed by an in-place "convert" pass, the blob bytes +// and the (values / children / keys) entries are accumulated into growable buffers, and copied into +// the result vectors once the per-row sizes are known. The data is sourced either from a +// vector, a VariantIterator (unshredding), or a ParquetVariantIterator (the parquet reader). + +//! Sentinel marking an array child (whose 'key_id' is NULL) +constexpr uint32_t VARIANT_INVALID_KEY = NumericLimits::Maximum(); + +inline void VariantBuilderAppendVarint(string &blob, uint32_t value) { + auto size = GetVarintSize(value); + auto pos = blob.size(); + blob.resize(pos + size); + VarintEncode(value, data_ptr_cast(blob.data()) + pos); +} + +template +void VariantBuilderAppendFixed(string &blob, T value) { + auto pos = blob.size(); + blob.resize(pos + sizeof(T)); + Store(value, data_ptr_cast(blob.data()) + pos); +} + +inline void VariantBuilderAppendBytes(string &blob, const_data_ptr_t data, idx_t size) { + blob.append(const_char_ptr_cast(data), size); +} + +inline uint32_t VariantBuilderGetOrCreateIndex(OrderedOwningStringMap &dictionary, const string_t &key) { + auto unsorted_idx = dictionary.size(); + //! This will later be remapped to the sorted idx (see FinalizeVariantKeys in 'to_variant.cpp') + return dictionary.emplace(std::make_pair(key, unsorted_idx)).first->second; +} + +//! The physical storage type of a DECIMAL of the given width (matches VariantDecimalData::GetPhysicalType) +inline PhysicalType VariantDecimalPhysicalType(uint32_t width) { + if (width > DecimalWidth::max) { + return PhysicalType::INT128; + } else if (width > DecimalWidth::max) { + return PhysicalType::INT64; + } else if (width > DecimalWidth::max) { + return PhysicalType::INT32; + } + return PhysicalType::INT16; +} + +//! Accumulates the canonical representation of a single chunk while traversing the source once. +struct VariantBuilder { + explicit VariantBuilder(OrderedOwningStringMap &dictionary) : dictionary(dictionary) { + } + + //! values: (type_id, byte_offset) in pre-order + vector type_ids; + vector byte_offsets; + //! children: (key_id, value_id) - key_id is VARIANT_INVALID_KEY for array elements + vector child_key_ids; + vector child_value_ids; + //! one (unsorted) dictionary index per object-child key slot + vector key_slots; + //! the blob of the row currently being built (reused across rows) + string blob; + //! maps a key string to its (unsorted) dictionary index, owned by the result's keys vector + OrderedOwningStringMap &dictionary; + + //! the offsets at which the current row's entries begin + idx_t row_values = 0; + idx_t row_children = 0; + idx_t row_keys = 0; + + void BeginRow() { + row_values = type_ids.size(); + row_children = child_value_ids.size(); + row_keys = key_slots.size(); + blob.clear(); + } + //! The current value / child / key index, relative to the start of the row + uint32_t LocalValue() const { + return NumericCast(type_ids.size() - row_values); + } + uint32_t LocalChild() const { + return NumericCast(child_value_ids.size() - row_children); + } + uint32_t LocalKey() const { + return NumericCast(key_slots.size() - row_keys); + } + + //! Emit a VARIANT_NULL value + void EmitNull() { + type_ids.push_back(static_cast(VariantLogicalType::VARIANT_NULL)); + byte_offsets.push_back(NumericCast(blob.size())); + } + + //! Emit an OBJECT value with 'n' children (assumed to be in lexicographic key order). 'key_fn(i)' + //! returns the (string_t) key of child i; 'emit_fn(i)' must emit exactly one value for child i. + template + void EmitObject(idx_t n, KEY_FN &&key_fn, EMIT_FN &&emit_fn) { + auto byte_offset = NumericCast(blob.size()); + type_ids.push_back(static_cast(VariantLogicalType::OBJECT)); + byte_offsets.push_back(byte_offset); + VariantBuilderAppendVarint(blob, NumericCast(n)); + if (!n) { + return; + } + VariantBuilderAppendVarint(blob, LocalChild()); + auto block = child_value_ids.size(); + child_value_ids.resize(block + n); + child_key_ids.resize(block + n); + for (idx_t i = 0; i < n; i++) { + child_value_ids[block + i] = LocalValue(); + child_key_ids[block + i] = LocalKey(); + key_slots.push_back(VariantBuilderGetOrCreateIndex(dictionary, key_fn(i))); + emit_fn(i); + } + } + + //! Emit an ARRAY value with 'n' elements. 'emit_fn(i)' must emit exactly one value for element i. + template + void EmitArray(idx_t n, EMIT_FN &&emit_fn) { + auto byte_offset = NumericCast(blob.size()); + type_ids.push_back(static_cast(VariantLogicalType::ARRAY)); + byte_offsets.push_back(byte_offset); + VariantBuilderAppendVarint(blob, NumericCast(n)); + if (!n) { + return; + } + VariantBuilderAppendVarint(blob, LocalChild()); + auto block = child_value_ids.size(); + child_value_ids.resize(block + n); + child_key_ids.resize(block + n); + for (idx_t i = 0; i < n; i++) { + child_value_ids[block + i] = LocalValue(); + child_key_ids[block + i] = VARIANT_INVALID_KEY; + emit_fn(i); + } + } + + //! Emit a plain (non-nested) Value as a primitive variant value + void EmitPrimitive(const Value &primitive, uint32_t byte_offset) { + auto type_id = primitive.type().id(); + VariantLogicalType variant_type; + switch (type_id) { + case LogicalTypeId::BOOLEAN: + variant_type = primitive.GetValue() ? VariantLogicalType::BOOL_TRUE : VariantLogicalType::BOOL_FALSE; + break; + case LogicalTypeId::SQLNULL: + variant_type = VariantLogicalType::VARIANT_NULL; + break; + case LogicalTypeId::TINYINT: + variant_type = VariantLogicalType::INT8; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::SMALLINT: + variant_type = VariantLogicalType::INT16; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::INTEGER: + variant_type = VariantLogicalType::INT32; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::BIGINT: + variant_type = VariantLogicalType::INT64; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::HUGEINT: + variant_type = VariantLogicalType::INT128; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::UTINYINT: + variant_type = VariantLogicalType::UINT8; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::USMALLINT: + variant_type = VariantLogicalType::UINT16; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::UINTEGER: + variant_type = VariantLogicalType::UINT32; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::UBIGINT: + variant_type = VariantLogicalType::UINT64; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::UHUGEINT: + variant_type = VariantLogicalType::UINT128; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::DOUBLE: + variant_type = VariantLogicalType::DOUBLE; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::FLOAT: + variant_type = VariantLogicalType::FLOAT; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::DATE: + variant_type = VariantLogicalType::DATE; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::TIMESTAMP_TZ: + variant_type = VariantLogicalType::TIMESTAMP_MICROS_TZ; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::TIMESTAMP_TZ_NS: + variant_type = VariantLogicalType::TIMESTAMP_NANOS_TZ; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::TIMESTAMP: + variant_type = VariantLogicalType::TIMESTAMP_MICROS; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::TIMESTAMP_SEC: + variant_type = VariantLogicalType::TIMESTAMP_SEC; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::TIMESTAMP_MS: + variant_type = VariantLogicalType::TIMESTAMP_MILIS; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::TIME: + variant_type = VariantLogicalType::TIME_MICROS; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::TIME_NS: + variant_type = VariantLogicalType::TIME_NANOS; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::TIME_TZ: + variant_type = VariantLogicalType::TIME_MICROS_TZ; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::TIMESTAMP_NS: + variant_type = VariantLogicalType::TIMESTAMP_NANOS; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::INTERVAL: + variant_type = VariantLogicalType::INTERVAL; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::UUID: + variant_type = VariantLogicalType::UUID; + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case LogicalTypeId::DECIMAL: { + variant_type = VariantLogicalType::DECIMAL; + auto &type = primitive.type(); + uint8_t width; + uint8_t scale; + type.GetDecimalProperties(width, scale); + VariantBuilderAppendVarint(blob, width); + VariantBuilderAppendVarint(blob, scale); + switch (type.InternalType()) { + case PhysicalType::INT16: + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case PhysicalType::INT32: + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case PhysicalType::INT64: + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + case PhysicalType::INT128: + VariantBuilderAppendFixed(blob, primitive.GetValueUnsafe()); + break; + default: + throw InternalException("Unexpected physical type for Decimal value: %s", + EnumUtil::ToString(type.InternalType())); + } + break; + } + case LogicalTypeId::BLOB: + case LogicalTypeId::BIGNUM: + case LogicalTypeId::BIT: + case LogicalTypeId::GEOMETRY: + case LogicalTypeId::VARCHAR: { + if (type_id == LogicalTypeId::BLOB) { + variant_type = VariantLogicalType::BLOB; + } else if (type_id == LogicalTypeId::BIGNUM) { + variant_type = VariantLogicalType::BIGNUM; + } else if (type_id == LogicalTypeId::BIT) { + variant_type = VariantLogicalType::BITSTRING; + } else if (type_id == LogicalTypeId::GEOMETRY) { + variant_type = VariantLogicalType::GEOMETRY; + } else { + variant_type = VariantLogicalType::VARCHAR; + } + auto string_data = primitive.GetValueUnsafe(); + VariantBuilderAppendVarint(blob, NumericCast(string_data.GetSize())); + VariantBuilderAppendBytes(blob, const_data_ptr_cast(string_data.GetData()), string_data.GetSize()); + break; + } + default: + throw InternalException("Encountered unrecognized LogicalType in EmitPrimitive: %s", + primitive.type().ToString()); + } + type_ids.push_back(static_cast(variant_type)); + byte_offsets.push_back(byte_offset); + } + + //! Emit a primitive value sourced from a VariantNode-like cursor. The fixed-width payload is fetched by + //! value via 'it.GetData()' (chosen by 'type_id'); strings via 'it.GetString()'; decimals via + //! 'it.GetDecimalProperties()' followed by 'it.GetData()' at the physical type implied by the width. + template + void EmitPrimitiveNode(const NODE &it, VariantLogicalType type_id) { + auto byte_offset = NumericCast(blob.size()); + type_ids.push_back(static_cast(type_id)); + byte_offsets.push_back(byte_offset); + switch (type_id) { + case VariantLogicalType::VARIANT_NULL: + case VariantLogicalType::BOOL_TRUE: + case VariantLogicalType::BOOL_FALSE: + break; + case VariantLogicalType::INT8: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::INT16: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::INT32: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::INT64: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::INT128: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::UINT8: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::UINT16: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::UINT32: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::UINT64: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::UINT128: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::FLOAT: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::DOUBLE: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::UUID: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::DATE: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::TIME_MICROS: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::TIME_NANOS: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::TIME_MICROS_TZ: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::TIMESTAMP_SEC: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::TIMESTAMP_MILIS: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::TIMESTAMP_MICROS: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::TIMESTAMP_NANOS: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::TIMESTAMP_MICROS_TZ: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::TIMESTAMP_NANOS_TZ: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::INTERVAL: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case VariantLogicalType::DECIMAL: { + auto properties = it.GetDecimalProperties(); + VariantBuilderAppendVarint(blob, properties.width); + VariantBuilderAppendVarint(blob, properties.scale); + switch (VariantDecimalPhysicalType(properties.width)) { + case PhysicalType::INT16: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case PhysicalType::INT32: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + case PhysicalType::INT64: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + default: + VariantBuilderAppendFixed(blob, it.template GetData()); + break; + } + break; + } + case VariantLogicalType::VARCHAR: + case VariantLogicalType::BLOB: + case VariantLogicalType::BIGNUM: + case VariantLogicalType::BITSTRING: + case VariantLogicalType::GEOMETRY: { + auto str = it.GetString(); + VariantBuilderAppendVarint(blob, NumericCast(str.GetSize())); + VariantBuilderAppendBytes(blob, const_data_ptr_cast(str.GetData()), str.GetSize()); + break; + } + default: + throw InternalException("EmitPrimitiveNode: unhandled VariantLogicalType (%d)", static_cast(type_id)); + } + } +}; + +//===--------------------------------------------------------------------===// +// Emit (source: a VariantNode-like cursor) +//===--------------------------------------------------------------------===// +//! Collect the (non-missing) object children of a node in lexicographic key order +template +auto CollectObjectChildren(const NODE &it) { + auto object = it.GetObjectChildren(VariantIterationOrder::LEXICOGRAPHIC); + using EntryT = std::decay_t; + vector children; + for (auto &entry : object) { + children.push_back(entry); + } + return children; +} + +//! Traverse a VariantNode-like cursor 'it' (any type exposing the node concept) into the builder. +template +void EmitIterator(const NODE &it, VariantBuilder &builder) { + if (it.IsNull() || it.IsMissing()) { + builder.EmitNull(); + return; + } + + auto type_id = it.GetTypeId(); + switch (type_id) { + case VariantLogicalType::OBJECT: { + auto children = CollectObjectChildren(it); + builder.EmitObject( + children.size(), [&](idx_t i) { return children[i].key; }, + [&](idx_t i) { EmitIterator(children[i].value, builder); }); + break; + } + case VariantLogicalType::ARRAY: { + auto array = it.GetArrayChildren(); + builder.EmitArray(array.size(), [&](idx_t i) { EmitIterator(array[i], builder); }); + break; + } + default: + builder.EmitPrimitiveNode(it, type_id); + break; + } +} + +//===--------------------------------------------------------------------===// +// Build driver +//===--------------------------------------------------------------------===// +//! Build the canonical (unshredded) VARIANT 'result' vector for 'count' rows by emitting each row of +//! 'source' (which provides 'bool Emit(idx_t row, VariantBuilder &builder)' returning whether the row +//! is a SQL NULL) into a shared VariantBuilder, then materializing the accumulated buffers. +template +void BuildVariant(SOURCE &source, idx_t count, Vector &result) { + if (count == 0) { + return; + } + + auto &keys = VariantVector::GetKeys(result); + auto &keys_entry = ListVector::GetChildMutable(keys); + auto &children = VariantVector::GetChildren(result); + auto &values = VariantVector::GetValues(result); + auto &blob_vector = VariantVector::GetData(result); + auto blob_writer = FlatVector::Writer(blob_vector, count); + + //! The dictionary is backed by the keys vector's string allocator so the finalized keys are owned + //! by the result (see FinalizeVariantKeys). + OrderedOwningStringMap dictionary(StringVector::GetStringAllocator(keys_entry)); + VariantBuilder builder(dictionary); + + vector keys_entries(count); + vector children_entries(count); + vector values_entries(count); + + for (idx_t row = 0; row < count; row++) { + builder.BeginRow(); + bool is_null = source.Emit(row, builder); + blob_writer.WriteValue(string_t(builder.blob.data(), NumericCast(builder.blob.size()))); + if (is_null) { + //! SPEC: If a Variant is missing in a context where a value is required, readers must return a Variant null + FlatVector::SetNull(result, row, true); + } + keys_entries[row] = list_entry_t(builder.row_keys, builder.LocalKey()); + children_entries[row] = list_entry_t(builder.row_children, builder.LocalChild()); + values_entries[row] = list_entry_t(builder.row_values, builder.LocalValue()); + } + + auto total_keys = builder.key_slots.size(); + auto total_children = builder.child_value_ids.size(); + auto total_values = builder.type_ids.size(); + + //! Size the list child vectors now that the totals are known + ListVector::Reserve(keys, total_keys); + ListVector::SetListSize(keys, total_keys); + ListVector::Reserve(children, total_children); + ListVector::SetListSize(children, total_children); + ListVector::Reserve(values, total_values); + ListVector::SetListSize(values, total_values); + + VariantVectorData variant_data(result); + for (idx_t row = 0; row < count; row++) { + variant_data.keys_data[row] = keys_entries[row]; + variant_data.children_data[row] = children_entries[row]; + variant_data.values_data[row] = values_entries[row]; + } + + //! values + if (total_values) { + memcpy(variant_data.type_ids_data, builder.type_ids.data(), total_values * sizeof(uint8_t)); + memcpy(variant_data.byte_offset_data, builder.byte_offsets.data(), total_values * sizeof(uint32_t)); + } + + //! children + for (idx_t i = 0; i < total_children; i++) { + variant_data.values_index_data[i] = builder.child_value_ids[i]; + if (builder.child_key_ids[i] == VARIANT_INVALID_KEY) { + variant_data.keys_index_validity.SetInvalid(i); + } else { + variant_data.keys_index_data[i] = builder.child_key_ids[i]; + } + } + + //! keys: map each key slot to its (unsorted) dictionary index, then finalize (sort + remap) + SelectionVector keys_selvec(total_keys); + for (idx_t i = 0; i < total_keys; i++) { + keys_selvec.set_index(i, builder.key_slots[i]); + } + VariantUtils::FinalizeVariantKeys(result, dictionary, keys_selvec, total_keys); + keys_entry.Slice(keys_selvec, total_keys); + + FlatVector::SetSize(result, count); + result.Verify(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/variant_iterator.hpp b/src/duckdb/src/include/duckdb/common/types/variant_iterator.hpp index 744d971ff..67cbd92d9 100644 --- a/src/duckdb/src/include/duckdb/common/types/variant_iterator.hpp +++ b/src/duckdb/src/include/duckdb/common/types/variant_iterator.hpp @@ -103,6 +103,9 @@ class VariantNode; class VariantIterator { public: explicit VariantIterator(const Vector &variant); + //! Build directly from the separate unshredded + shredded components (the STRUCT(unshredded, shredded) + //! intermediate produced during unshredding, which is not itself a SHREDDED_VECTOR) + VariantIterator(const Vector &unshredded, const Vector &shredded); public: //! Whether the row is a (SQL) NULL variant @@ -152,15 +155,25 @@ class VariantNode { //! The logical type of the value the cursor points at VariantLogicalType GetTypeId() const; - //! Returns the fixed-width primitive payload loaded as T (e.g. GetData()) + //! Returns the fixed-width primitive payload loaded as T (e.g. GetData()). For a DECIMAL the + //! payload follows the (width, scale) prefix, so T must match the physical type implied by the width. template T GetData() const { + if (GetTypeId() == VariantLogicalType::DECIMAL) { + return Load(GetDecimal().value_ptr); + } return Load(GetDataPointer()); } - //! Returns a pointer to the raw payload of a fixed-width primitive value - const_data_ptr_t GetDataPointer() const; //! Returns the (variable-length) string payload of a VARCHAR/BLOB/BIGNUM/GEOMETRY/BITSTRING value string_t GetString() const; + //! Returns the (width, scale) of a DECIMAL value + VariantDecimalProperties GetDecimalProperties() const { + auto decimal = GetDecimal(); + return VariantDecimalProperties(decimal.width, decimal.scale); + } + + //! Returns a pointer to the raw payload of a fixed-width primitive value + const_data_ptr_t GetDataPointer() const; //! Returns the decimal payload of a DECIMAL value VariantDecimalData GetDecimal() const; diff --git a/src/duckdb/src/include/duckdb/common/types/variant_value.hpp b/src/duckdb/src/include/duckdb/common/types/variant_value.hpp deleted file mode 100644 index e9690012c..000000000 --- a/src/duckdb/src/include/duckdb/common/types/variant_value.hpp +++ /dev/null @@ -1,80 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/variant_value.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/map.hpp" -#include "duckdb/common/vector.hpp" -#include "duckdb/common/types/value.hpp" - -namespace duckdb_yyjson { -struct yyjson_mut_doc; -struct yyjson_mut_val; -} // namespace duckdb_yyjson - -namespace duckdb { - -enum class VariantValueType : uint8_t { PRIMITIVE, OBJECT, ARRAY, MISSING }; - -struct VariantValue { -public: - VariantValue() : value_type(VariantValueType::MISSING) { - } - explicit VariantValue(VariantValueType type) : value_type(type) { - } - explicit VariantValue(Value &&val) : value_type(VariantValueType::PRIMITIVE), primitive_value(std::move(val)) { - } - // Delete copy constructor and copy assignment operator - VariantValue(const VariantValue &) = delete; - VariantValue &operator=(const VariantValue &) = delete; - - // Default move constructor and move assignment operator - VariantValue(VariantValue &&) noexcept = default; - VariantValue &operator=(VariantValue &&) noexcept = default; - -public: - bool IsNull() const { - return value_type == VariantValueType::PRIMITIVE && primitive_value.IsNull(); - } - bool IsMissing() const { - return value_type == VariantValueType::MISSING; - } - - static VariantValue NullValue() { - return VariantValue(Value(LogicalType::SQLNULL)); - } - - //! Convert a (non-null) VARIANT-typed Value back to a plain Value - static Value GetValue(const Value &variant_val); - -public: - void AddChild(const string &key, VariantValue &&val); - void AddItem(VariantValue &&val); - - void SetItems(vector &&values); - void ReserveItems(idx_t count); - void AddItems(vector::iterator begin, vector::iterator end); - map TakeObjectChildren(); - const map &ObjectChildren() const; - const vector &ArrayItems() const; - -public: - duckdb_yyjson::yyjson_mut_val *ToJSON(ClientContext &context, duckdb_yyjson::yyjson_mut_doc *doc) const; - static void ToVARIANT(vector &input, Vector &result); - -public: - VariantValueType value_type; - Value primitive_value; - -private: - //! FIXME: how can we get a deterministic child order for a partially shredded object? - map object_children; - vector array_items; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector/dictionary_vector.hpp b/src/duckdb/src/include/duckdb/common/vector/dictionary_vector.hpp index 047687a1c..5afa22ef3 100644 --- a/src/duckdb/src/include/duckdb/common/vector/dictionary_vector.hpp +++ b/src/duckdb/src/include/duckdb/common/vector/dictionary_vector.hpp @@ -22,6 +22,9 @@ class DictionaryEntry { Vector data; //! Optional id to uniquely identify re-occurring dictionaries string id; + //! True iff the producer wraps this same entry in every output chunk for its lifetime (stable id, no flat + //! fall-through), making it a global dictionary. Set only via CreateReusableGlobalDictionary. + bool global_dictionary = false; //! For caching the hashes of a child buffer (mutable: cache is logically const) mutable mutex cached_hashes_lock; mutable unique_ptr cached_hashes; @@ -151,6 +154,15 @@ struct DictionaryVector { return DictionarySize(vector).IsValid() && !DictionaryId(vector).empty() && CanCacheHashes(vector.GetType()); } static buffer_ptr CreateReusableDictionary(const LogicalType &type, const idx_t &size); + //! Mint a reusable dictionary entry whose lifetime spans the entire producing operator instance + static buffer_ptr CreateReusableGlobalDictionary(const LogicalType &type, const idx_t &size); + //! True iff vector is a DICTIONARY_VECTOR whose entry is a global dictionary + static inline bool IsGlobalDictionary(const Vector &vector) { + if (vector.GetVectorType() != VectorType::DICTIONARY_VECTOR) { + return false; + } + return vector.Buffer().Cast().GetEntry().global_dictionary; + } static const Vector &GetCachedHashes(const Vector &input); }; diff --git a/src/duckdb/src/include/duckdb/execution/index/art/art.hpp b/src/duckdb/src/include/duckdb/execution/index/art/art.hpp index 981e715f2..671f13892 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/art.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/art.hpp @@ -15,7 +15,7 @@ namespace duckdb { enum class VerifyExistenceType : uint8_t { APPEND = 0, APPEND_FK = 1, DELETE_FK = 2 }; -enum class ARTConflictType : uint8_t { NO_CONFLICT = 0, CONSTRAINT = 1, TRANSACTION = 2 }; +enum class ARTConflictType : uint8_t { NO_CONFLICT = 0, CONSTRAINT = 1 }; enum class ARTHandlingResult : uint8_t { CONTINUE = 0, SKIP = 1, YIELD = 2, NONE = 3 }; class ConflictManager; diff --git a/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp b/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp index d882e3ba7..f6d8673a2 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/art_operator.hpp @@ -152,13 +152,18 @@ class ARTOperator { status = GateStatus::GATE_SET; continue; } - // Unique indexes can have duplicates, if another transaction DELETE + INSERT - // the same key. In that case, the previous value must be kept alive until all - // other transactions do not depend on it anymore. - - // We restrict this transactionality to two-value leaves, so any subsequent - // incoming transaction must fail here. - return ARTConflictType::TRANSACTION; + // A unique ART may temporarily contain a gated two-row leaf during commit for + // DELETE + INSERT of the same key: commit appends the new row first, then + // commit-delete cleanup removes the old row ID. No other main-ART append should + // enter during that window because commit-time main-index appends are serialized + // by the WAL lock or transaction manager commit lock. + // + // Local append and delete indexes should not contain such gates either. + // Note that VerifyLeaf may still legitimately observe the temporary duplicate + // leaf state. + throw FatalException("Corrupted unique ART index \"%s\": encountered an existing gated leaf in unique " + "index while inserting", + art.name); } const auto type = active_node.GetType(); diff --git a/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp b/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp index 17aa86418..1572e87ee 100644 --- a/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp +++ b/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp @@ -59,6 +59,12 @@ struct JoinHTScanState { [POINTER] [POINTER] The pointers are either NULL + + Two-phase lifecycle: the constructor populates only layout-INDEPENDENT state; all layout-DEPENDENT state + (layout_ptr, row matchers, tuple_size/pointer_offset/entry_size, data_collection, sink_collection, dead_end, + dict_registry) is published by FinishInitWithLayout on the first build chunk, so slot widths can be chosen from + the data's actual runtime encoding. Until then the JHT is unusable except for the null-safe Count() / + SizeInBytes() accessors; the layout-dependent accessors assert IsLayoutFinalized(). */ class JoinHashTable { public: @@ -226,6 +232,17 @@ class JoinHashTable { optional_ptr predicate_ptr = nullptr, const vector &output_in_probe = {}); ~JoinHashTable(); + //! Initialize layout-dependent state from a layout shared across all per-thread JHTs (deferred ctor body) + void FinishInitWithLayout(shared_ptr published_layout, vector dict_index_width_p = {}); + //! True iff FinishInitWithLayout has populated layout-dependent state + bool IsLayoutFinalized() const { + return layout_ptr.get() != nullptr; + } + + //! Per-column index-width decision for the dict-surviving optimisation, consulted by the layout publisher on + //! the first build chunk. Returns the narrowed index byte width (1/2/4), or 0 to keep native width. + uint8_t GetDictSurvivingIndexWidth(idx_t build_col_idx, const Vector &incoming) const; + //! Add the given data to the HT void Build(PartitionedTupleDataAppendState &append_state, DataChunk &keys, DataChunk &input); //! Merge another HT into this one @@ -275,17 +292,21 @@ class JoinHashTable { } idx_t Count() const { - return data_collection->Count(); + return data_collection ? data_collection->Count() : 0; } idx_t SizeInBytes() const { - return data_collection->SizeInBytes(); + return data_collection ? data_collection->SizeInBytes() : 0; } PartitionedTupleData &GetSinkCollection() { + // Only valid after FinishInitWithLayout; assert so a premature access fails loudly, not as a null-deref. + D_ASSERT(IsLayoutFinalized()); return *sink_collection; } TupleDataCollection &GetDataCollection() { + // Only valid after FinishInitWithLayout (see GetSinkCollection). + D_ASSERT(IsLayoutFinalized()); return *data_collection; } //! Perform a full scan of a build column, filling the provided addresses vector and result vector. @@ -370,6 +391,12 @@ class JoinHashTable { bool use_dict_emission = false; //! Pre-materialized columnar data, one entry per RHS output column vector> dict_arrays; + //! Per build payload column: pinned upstream dict entry. Non-null means the row store carries a narrow dict + //! index for this column instead of the native value. + vector> dict_registry; + //! Per build payload column: byte width of the narrowed dict-index slot (0 = native, else 1/2/4). Parallel to + //! build_types; set by FinishInitWithLayout. + vector dict_index_width; //! Saved NEXT_PTR values, indexed by dict index; only allocated when chains_longer_than_one AllocatedData aux_next_ptrs; //! Typed pointer into aux_next_ptrs; set by BuildDictionaryArrays alongside the allocation @@ -597,6 +624,13 @@ class JoinHashTable { ProbeState &probe_state, DataChunk &probe_chunk, ProbeSpill &probe_spill, ProbeSpillLocalAppendState &spill_state, DataChunk &spill_chunk); +private: + //! True iff the residual predicate (if any) reads build payload column build_col_idx from its row slot + bool ColumnReferencedByResidual(idx_t build_col_idx) const; + //! Validate the incoming dict chunk and pin a self-owned copy of its dictionary into dict_registry on the first + //! chunk; on later chunks assert id continuity. Called per narrowed column from Build. + void PinDictSurvivingColumn(idx_t build_col_idx, const Vector &incoming, uint8_t index_width); + private: //! The current number of radix bits used to partition idx_t radix_bits; diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp index d8d31d5d3..5c53acb77 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_copy_to_file.hpp @@ -19,6 +19,7 @@ namespace duckdb { struct GlobalFileState; +struct FileStateHandle; struct BoundOrderByNode; struct CopyToFileInfo { @@ -38,6 +39,7 @@ class PhysicalCopyToFile : public PhysicalOperator { public: PhysicalCopyToFile(PhysicalPlan &physical_plan, vector types, CopyFunction function, unique_ptr bind_data, idx_t estimated_cardinality); + ~PhysicalCopyToFile() override; public: InsertionOrderPreservingMap ParamsToString() const override; @@ -50,16 +52,15 @@ class PhysicalCopyToFile : public PhysicalOperator { bool Rotate() const; - void PrepareAndFlushBatch(ClientContext &context, GlobalSinkState &gstate_p, - unique_ptr &file_state_ptr, - const std::function()> &create_file_state_fun, + void PrepareAndFlushBatch(ClientContext &context, GlobalSinkState &gstate_p, FileStateHandle &file_state, + const std::function &create_file_state_fun, unique_ptr batch) const; pair> - PrepareBatch(ClientContext &context, GlobalSinkState &gstate_p, unique_ptr &file_state_ptr, - const std::function()> &create_file_state_fun, + PrepareBatch(ClientContext &context, GlobalSinkState &gstate_p, FileStateHandle &file_state, + const std::function &create_file_state_fun, unique_ptr batch) const; - void FlushBatch(ClientContext &context, GlobalSinkState &gstate_p, unique_ptr &file_state_ptr, - const std::function()> &create_file_state_fun, + void FlushBatch(ClientContext &context, GlobalSinkState &gstate_p, FileStateHandle &file_state, + const std::function &create_file_state_fun, const CopyFunctionBatchAnalyzer &batch_analyzer, unique_ptr prepared_batch) const; diff --git a/src/duckdb/src/include/duckdb/execution/physical_operator.hpp b/src/duckdb/src/include/duckdb/execution/physical_operator.hpp index e7872e34a..c2f42825c 100644 --- a/src/duckdb/src/include/duckdb/execution/physical_operator.hpp +++ b/src/duckdb/src/include/duckdb/execution/physical_operator.hpp @@ -26,6 +26,7 @@ namespace duckdb { +class DictionaryEntry; class Event; class Executor; class PhysicalOperator; @@ -246,6 +247,13 @@ class PhysicalOperator { } }; +//! A cached column that arrived as a global dictionary: the pinned upstream entry is kept and +//! per-chunk selection indices concatenated, so the dictionary survives the cache instead of flattening +struct CachedDictColumn { + buffer_ptr entry; + SelectionVector accumulated_sel; +}; + //! Contains state for the CachingPhysicalOperator class CachingOperatorState : public OperatorState { public: @@ -261,6 +269,13 @@ class CachingOperatorState : public OperatorState { can_cache_chunk = OperatorCachingMode::NONE; must_return_continuation_chunk = false; cached_result = OperatorResultType::NEED_MORE_INPUT; + ResetDictCache(); + } + + //! Drop the dictionary accumulators, returning the cache to plain flat caching + void ResetDictCache() { + dict_columns.clear(); + dict_cache_active = false; } unique_ptr cached_chunk; @@ -269,6 +284,11 @@ class CachingOperatorState : public OperatorState { OperatorCachingMode can_cache_chunk = OperatorCachingMode::NONE; bool must_return_continuation_chunk = false; OperatorResultType cached_result; + + //! One slot per cached column. Invariant: entry != null iff the column is accumulating a global + //! dictionary; entry == null iff plain flat caching (the common case) + vector dict_columns; + bool dict_cache_active = false; }; //! Base class that caches output from child Operator class. Note that Operators inheriting from this class should also diff --git a/src/duckdb/src/include/duckdb/function/aggregate/list_aggregate.hpp b/src/duckdb/src/include/duckdb/function/aggregate/list_aggregate.hpp index f8f6fdfc6..97d033144 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate/list_aggregate.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate/list_aggregate.hpp @@ -55,7 +55,47 @@ inline void ListUpdateFunction(Vector inputs[], AggregateInputData &aggr_input_d } auto &state = *states[i].GetValue(); aggr_input_data.allocator.AlignNext(); - functions.AppendRow(aggr_input_data.allocator, state.linked_list, input_data, i); + functions.AppendRows(aggr_input_data.allocator, state.linked_list, input_data, i, 1); + } +} + +//! Clustered variant of ListUpdateFunction - appends the rows of each run to that run's state. +//! Contiguous runs are appended in a single batch; scattered runs are appended row by row. +template +inline void ListClusterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, + const ClusteredAggr &clustered, idx_t count) { + D_ASSERT(input_count == 1); + auto &input = inputs[0]; + RecursiveUnifiedVectorFormat input_data; + Vector::RecursiveToUnifiedFormat(input, input_data); + + ListSegmentFunctions functions; + GetSegmentDataFunctions(functions, input.GetType()); + + for (idx_t run_idx = 0; run_idx < clustered.n_group_runs; run_idx++) { + auto &run = clustered.group_runs[run_idx]; + auto &state = *reinterpret_cast(run.state); + auto run_sel = run.sel; + + if (!IGNORE_NULLS && !run_sel) { + // contiguous run covering [0, run.count) without NULL filtering - append in a single batch + aggr_input_data.allocator.AlignNext(); + functions.AppendRows(aggr_input_data.allocator, state.linked_list, input_data, 0, run.count); + continue; + } + + // scattered run and/or NULL filtering - append the rows one by one + for (idx_t k = 0; k < run.count; k++) { + idx_t entry_idx = run_sel ? run_sel[k] : k; + if (IGNORE_NULLS) { + const auto idx = input_data.unified.sel->get_index(entry_idx); + if (!input_data.unified.validity.RowIsValid(idx)) { + continue; + } + } + aggr_input_data.allocator.AlignNext(); + functions.AppendRows(aggr_input_data.allocator, state.linked_list, input_data, entry_idx, 1); + } } } diff --git a/src/duckdb/src/include/duckdb/function/cast/variant/json_to_variant.hpp b/src/duckdb/src/include/duckdb/function/cast/variant/json_to_variant.hpp index b1f10305f..e04d3b491 100644 --- a/src/duckdb/src/include/duckdb/function/cast/variant/json_to_variant.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/variant/json_to_variant.hpp @@ -12,6 +12,7 @@ #include "duckdb/common/types/decimal.hpp" #include "duckdb/common/types/time.hpp" #include "duckdb/common/types/timestamp.hpp" +#include "duckdb/common/operator/cast_operators.hpp" namespace duckdb { namespace variant { @@ -132,6 +133,14 @@ static bool ConvertJSONObject(yyjson_val *obj, ToVariantGlobalResultData &result return true; } +namespace { + +static inline string_t GetString(yyjson_val *val) { + return string_t(unsafe_yyjson_get_str(val), NumericCast(unsafe_yyjson_get_len(val))); +} + +} // namespace + template static bool ConvertJSONPrimitive(yyjson_val *val, ToVariantGlobalResultData &result, idx_t result_index, bool is_root) { auto json_tag = unsafe_yyjson_get_tag(val); @@ -143,8 +152,7 @@ static bool ConvertJSONPrimitive(yyjson_val *val, ToVariantGlobalResultData &res switch (json_tag) { case YYJSON_TYPE_STR | YYJSON_SUBTYPE_NOESC: - case YYJSON_TYPE_STR | YYJSON_SUBTYPE_NONE: - case YYJSON_TYPE_RAW | YYJSON_SUBTYPE_NONE: { + case YYJSON_TYPE_STR | YYJSON_SUBTYPE_NONE: { WriteVariantMetadata(result, result_index, values_offset_data, blob_offset_data[result_index], nullptr, 0, VariantLogicalType::VARCHAR); uint32_t length = NumericCast(unsafe_yyjson_get_len(val)); @@ -188,11 +196,20 @@ static bool ConvertJSONPrimitive(yyjson_val *val, ToVariantGlobalResultData &res blob_offset_data[result_index] += sizeof(int64_t); break; } + case YYJSON_TYPE_RAW | YYJSON_SUBTYPE_NONE: case YYJSON_TYPE_NUM | YYJSON_SUBTYPE_REAL: { WriteVariantMetadata(result, result_index, values_offset_data, blob_offset_data[result_index], nullptr, 0, VariantLogicalType::DOUBLE); + double value; if (WRITE_DATA) { - auto value = unsafe_yyjson_get_real(val); + if (json_tag == (YYJSON_TYPE_RAW | YYJSON_SUBTYPE_NONE)) { + auto success = TryCast::Operation(GetString(val), value, true); + if (!success) { + return false; + } + } else { + value = unsafe_yyjson_get_real(val); + } memcpy(blob_data + blob_offset_data[result_index], const_data_ptr_cast(&value), sizeof(double)); } blob_offset_data[result_index] += sizeof(double); diff --git a/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp b/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp index cfc948a31..c71dd269f 100644 --- a/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/variant/variant_to_variant.hpp @@ -263,6 +263,8 @@ bool ConvertVariantToVariant(ToVariantSourceData &source_data, ToVariantGlobalRe auto &result = result_data.variant; for (idx_t source_index = 0; source_index < count; source_index++) { + //! Map the loop index through the incoming selection to the actual source row. + const auto scan_index = source_data.GetMappedIndex(source_index); auto result_index = selvec ? selvec->get_index(source_index) : source_index; auto &keys_list_entry = result.keys_data[result_index]; @@ -275,7 +277,7 @@ bool ConvertVariantToVariant(ToVariantSourceData &source_data, ToVariantGlobalRe uint32_t keys_count = 0; uint32_t blob_size = 0; - if (!source.RowIsValid(source_index)) { + if (!source.RowIsValid(scan_index)) { if (!IGNORE_NULLS) { HandleVariantNull(result_data, result_index, values_offset_data, blob_offset, values_index_selvec, source_index, is_root); @@ -293,23 +295,23 @@ bool ConvertVariantToVariant(ToVariantSourceData &source_data, ToVariantGlobalRe //! First write all children //! NOTE: this has to happen first because we use 'values_offset', which is increased when we write the values - auto source_children_list_entry = source.GetChildrenListEntry(source_index); + auto source_children_list_entry = source.GetChildrenListEntry(scan_index); for (idx_t source_children_index = 0; source_children_index < source_children_list_entry.length; source_children_index++) { //! values_index if (WRITE_DATA) { auto &values_offset = values_offset_data[result_index]; - auto source_value_index = source.GetValuesIndex(source_index, source_children_index); + auto source_value_index = source.GetValuesIndex(scan_index, source_children_index); result.values_index_data[children_list_entry.offset + children_offset + source_children_index] = values_offset + source_value_index; } //! keys_index - if (source.KeysIndexIsValid(source_index, source_children_index)) { + if (source.KeysIndexIsValid(scan_index, source_children_index)) { if (WRITE_DATA) { //! Look up the existing key from 'source' - auto source_key_index = source.GetKeysIndex(source_index, source_children_index); - auto &source_key_value = source.GetKey(source_index, source_key_index); + auto source_key_index = source.GetKeysIndex(scan_index, source_children_index); + auto &source_key_value = source.GetKey(scan_index, source_key_index); //! Now write this key to the dictionary of the result auto dict_index = result_data.GetOrCreateIndex(source_key_value); @@ -326,26 +328,25 @@ bool ConvertVariantToVariant(ToVariantSourceData &source_data, ToVariantGlobalRe } } - auto source_values_list_entry = source.GetValuesListEntry(source_index); + auto source_values_list_entry = source.GetValuesListEntry(scan_index); if (WRITE_DATA) { WriteState write_state(keys_offset, children_offset, blob_offset, blob_data, blob_size); for (uint32_t source_value_index = 0; source_value_index < source_values_list_entry.length; source_value_index++) { - auto source_type_id = source.GetTypeId(source_index, source_value_index); + auto source_type_id = source.GetTypeId(scan_index, source_value_index); WriteVariantMetadata(result_data, result_index, values_offset_data, blob_offset + blob_size, nullptr, 0, source_type_id); - VariantVisitor::Visit(source, source_index, source_value_index, - write_state); + VariantVisitor::Visit(source, scan_index, source_value_index, write_state); } } else { AnalyzeState analyze_state(children_offset); for (uint32_t source_value_index = 0; source_value_index < source_values_list_entry.length; source_value_index++) { values_offset_data[result_index]++; - blob_size += VariantVisitor::Visit(source, source_index, - source_value_index, analyze_state); + blob_size += VariantVisitor::Visit(source, scan_index, source_value_index, + analyze_state); } } diff --git a/src/duckdb/src/include/duckdb/function/function_binder.hpp b/src/duckdb/src/include/duckdb/function/function_binder.hpp index c8f11b3ff..d1a1fd702 100644 --- a/src/duckdb/src/include/duckdb/function/function_binder.hpp +++ b/src/duckdb/src/include/duckdb/function/function_binder.hpp @@ -14,11 +14,19 @@ #include "duckdb/function/window_function.hpp" #include "duckdb/function/function_set.hpp" #include "duckdb/common/error_data.hpp" +#include "duckdb/common/enums/order_type.hpp" namespace duckdb { class WindowFunctionCatalogEntry; +//! One ORDER BY key of an exported ordered aggregate state: the buffered struct column it sorts on and the modifiers. +struct SortedAggregateStateOrder { + idx_t column; + OrderType order_type; + OrderByNullType null_order; +}; + //! The FunctionBinder class is responsible for binding functions class FunctionBinder { public: @@ -139,6 +147,21 @@ class FunctionBinder { optional_ptr> grouping_sets); DUCKDB_API static void BindSortedAggregate(ClientContext &context, BoundWindowExpression &expr); + //! Computes the exported buffer layout of an ordered aggregate: the struct of buffered columns (arguments first, + //! then any appended sort keys), the per-key column + modifiers, and the number of leading argument columns. + //! Mirrors the matching done by the sorted aggregate bind data so the export type matches the runtime buffer. + DUCKDB_API static void GetSortedAggregateStateLayout(const BoundAggregateExpression &expr, + LogicalType &buffer_struct, + vector &orders, + idx_t &argument_count); + //! Reconstructs a sorted aggregate wrapper from an exported buffer state so finalize/combine operate on the buffer: + //! finalize sorts by the keys and runs the (already re-bound) inner aggregate, combine concatenates buffers. + //! Returns the wrapper function and its bind data. + DUCKDB_API static pair> + BindSortedAggregateState(ClientContext &context, const BoundAggregateFunction &inner_function, + unique_ptr inner_bind_info, const LogicalType &buffer_struct, + const vector &orders, idx_t argument_count); + DUCKDB_API unique_ptr BindWindowFunction(const WindowFunction &function, vector> children, vector>> keyword_args, vector &orders, diff --git a/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_functions.hpp index f1289cc60..606a7394a 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_functions.hpp @@ -115,6 +115,26 @@ struct InternalCompressStringHugeintFun { static ScalarFunction GetFunction(); }; +struct InternalCompressGeometryPointFun { + static constexpr const char *Name = "__internal_compress_geometry_point"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = ""; + static constexpr const char *Example = ""; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + +struct InternalDecompressGeometryPointFun { + static constexpr const char *Name = "__internal_decompress_geometry_point"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = ""; + static constexpr const char *Example = ""; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + struct InternalDecompressIntegralSmallintFun { static constexpr const char *Name = "__internal_decompress_integral_smallint"; static constexpr const char *Parameters = ""; diff --git a/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_utils.hpp b/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_utils.hpp index 50d9549c5..51bd0dce2 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_utils.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_utils.hpp @@ -41,4 +41,12 @@ struct CMStringDecompressFun { static ScalarFunction GetFunction(const LogicalType &input_type); }; +struct CMGeometryPointCompressFun { + static ScalarFunction GetFunction(); +}; + +struct CMGeometryPointDecompressFun { + static ScalarFunction GetFunction(); +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp b/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp index fe5f98b86..ce1f2dc7a 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/variant_utils.hpp @@ -15,6 +15,7 @@ #include "duckdb/common/owning_string_map.hpp" namespace duckdb { +class VariantIterator; struct VariantPathBindData : public FunctionData { public: @@ -151,6 +152,10 @@ struct VariantUtils { //! Whether or not a type is natively supported in variant DUCKDB_API static bool VariantSupportsType(const LogicalType &type); + + //! Build a canonical (unshredded) VARIANT vector by traversing a variant directly through a + //! VariantIterator - avoids materializing the intermediate vector tree + DUCKDB_API static void ToVariant(const VariantIterator &state, idx_t count, Vector &result); }; struct VariantBindUtils { diff --git a/src/duckdb/src/include/duckdb/function/table/arrow/arrow_duck_schema.hpp b/src/duckdb/src/include/duckdb/function/table/arrow/arrow_duck_schema.hpp index a94b4e6ff..c3c40691b 100644 --- a/src/duckdb/src/include/duckdb/function/table/arrow/arrow_duck_schema.hpp +++ b/src/duckdb/src/include/duckdb/function/table/arrow/arrow_duck_schema.hpp @@ -23,7 +23,7 @@ struct ArrowArrayScanState; typedef void (*cast_arrow_duck_t)(ClientContext &context, Vector &source, Vector &result, idx_t count); -typedef void (*cast_duck_arrow_t)(ClientContext &context, Vector &source, Vector &result, idx_t count); +typedef void (*cast_duck_arrow_t)(ClientContext &context, const Vector &source, Vector &result, idx_t count); class ArrowTypeExtensionData { public: diff --git a/src/duckdb/src/include/duckdb/function/table_function.hpp b/src/duckdb/src/include/duckdb/function/table_function.hpp index 1ce75139a..712a1d0e8 100644 --- a/src/duckdb/src/include/duckdb/function/table_function.hpp +++ b/src/duckdb/src/include/duckdb/function/table_function.hpp @@ -23,6 +23,7 @@ #include "duckdb/function/partition_stats.hpp" #include "duckdb/common/exception/binder_exception.hpp" #include "duckdb/common/enums/order_preservation_type.hpp" +#include "duckdb/common/enums/statement_type.hpp" namespace duckdb { @@ -509,6 +510,9 @@ class TableFunction : public SimpleNamedParameterFunction { // NOLINT: work-arou //! Whether or not the table function supports late materialization bool late_materialization; TableFunctionReturnType return_type; + //! The return type used when this function is invoked through a CALL statement + //! By default a CALL returns a query result - functions that only have side effects can use NOTHING instead + StatementReturnType call_return_type = StatementReturnType::QUERY_RESULT; //! Additional function info, passed to the bind shared_ptr function_info; //! The order preservation type of the table function diff --git a/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp b/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp index f7470757d..1e8579241 100644 --- a/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp +++ b/src/duckdb/src/include/duckdb/function/variant/variant_shredding.hpp @@ -29,7 +29,7 @@ struct VariantColumnStatsData { idx_t total_count = 0; //! indices into the top-level 'columns' vector where the stats for the field/element live - case_insensitive_map_t field_stats; + unordered_map field_stats; idx_t element_stats = DConstants::INVALID_INDEX; }; diff --git a/src/duckdb/src/include/duckdb/main/client_config.hpp b/src/duckdb/src/include/duckdb/main/client_config.hpp index e05d0df27..bdd08ea7a 100644 --- a/src/duckdb/src/include/duckdb/main/client_config.hpp +++ b/src/duckdb/src/include/duckdb/main/client_config.hpp @@ -10,6 +10,7 @@ #include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/common/common.hpp" +#include "duckdb/common/identifier.hpp" #include "duckdb/common/optional_idx.hpp" #include "duckdb/common/enums/output_type.hpp" #include "duckdb/common/progress_bar/progress_bar.hpp" @@ -76,7 +77,7 @@ struct ClientConfig { LocalUserSettings user_settings; //! Variables set by the user - case_insensitive_map_t user_variables; + identifier_map_t user_variables; //! Function that is used to create the result collector for a materialized result. get_result_collector_t get_result_collector = nullptr; @@ -86,6 +87,7 @@ struct ClientConfig { static const ClientConfig &GetConfig(const ClientContext &context); void SetUserVariable(const String &name, Value value); + bool GetUserVariable(const Identifier &name, Value &result); bool GetUserVariable(const string &name, Value &result); void ResetUserVariable(const String &name); diff --git a/src/duckdb/src/include/duckdb/main/extension_entries.hpp b/src/duckdb/src/include/duckdb/main/extension_entries.hpp index 8ac05ab1f..fce288f20 100644 --- a/src/duckdb/src/include/duckdb/main/extension_entries.hpp +++ b/src/duckdb/src/include/duckdb/main/extension_entries.hpp @@ -842,6 +842,7 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"var_pop", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"var_samp", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, {"variance", "core_functions", CatalogType::AGGREGATE_FUNCTION_ENTRY}, + {"variant_bytes_to_variant", "parquet", CatalogType::SCALAR_FUNCTION_ENTRY}, {"variant_to_parquet_variant", "parquet", CatalogType::SCALAR_FUNCTION_ENTRY}, {"vector_type", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, {"version", "core_functions", CatalogType::SCALAR_FUNCTION_ENTRY}, diff --git a/src/duckdb/src/include/duckdb/main/extension_load_options.hpp b/src/duckdb/src/include/duckdb/main/extension_load_options.hpp index efc87e8da..c0a9f1ae7 100644 --- a/src/duckdb/src/include/duckdb/main/extension_load_options.hpp +++ b/src/duckdb/src/include/duckdb/main/extension_load_options.hpp @@ -14,6 +14,11 @@ namespace duckdb { struct ExtensionLoadOptions { + ExtensionLoadOptions() = default; + // NOLINTNEXTLINE: allow implicit conversion from the extension name + ExtensionLoadOptions(string extension_name) : extension_name(std::move(extension_name)) { + } + string extension_name; Identifier alias; }; diff --git a/src/duckdb/src/include/duckdb/main/prepared_statement.hpp b/src/duckdb/src/include/duckdb/main/prepared_statement.hpp index b2608f61c..b4d04d346 100644 --- a/src/duckdb/src/include/duckdb/main/prepared_statement.hpp +++ b/src/duckdb/src/include/duckdb/main/prepared_statement.hpp @@ -12,8 +12,10 @@ #include "duckdb/common/winapi.hpp" #include "duckdb/main/materialized_query_result.hpp" #include "duckdb/main/pending_query_result.hpp" +#include "duckdb/main/client_config.hpp" #include "duckdb/common/error_data.hpp" #include "duckdb/common/case_insensitive_map.hpp" +#include "duckdb/common/string_util.hpp" #include "duckdb/planner/expression/bound_parameter_data.hpp" namespace duckdb { @@ -112,14 +114,27 @@ class PreparedStatement { StringUtil::Join(excess_values, ", ")); } + static bool AllowsUserVariableFallback(const Identifier &identifier) { + auto &name = identifier.GetIdentifierName(); + if (name.empty()) { + return false; + } + return !StringUtil::CharacterIsDigit(name[0]); + } + template static string MissingValuesException(const identifier_map_t ¶meters, - const identifier_map_t &values) { + const identifier_map_t &values, ClientContext *context = nullptr) { // Missing values identifier_set_t missing_set; for (auto &pair : parameters) { auto &name = pair.first; if (!values.count(name)) { + Value variable_value; + if (context && AllowsUserVariableFallback(name) && + ClientConfig::GetConfig(*context).GetUserVariable(name, variable_value)) { + continue; + } missing_set.insert(name); } } @@ -127,28 +142,29 @@ class PreparedStatement { for (auto &val : missing_set) { missing_values.push_back(val); } - return StringUtil::Format("Values were not provided for the following prepared statement parameters: %s", + return StringUtil::Format("Values were not provided for the following parameters: %s", StringUtil::Join(missing_values, ", ")); } template - static void VerifyParameters(const identifier_map_t &provided, const identifier_map_t &expected) { - if (expected.size() == provided.size()) { - // Same amount of identifiers, if - for (auto &pair : expected) { - auto &identifier = pair.first; - if (!provided.count(identifier)) { - throw InvalidInputException(MissingValuesException(expected, provided)); - } + static void VerifyParameters(const identifier_map_t &provided, const identifier_map_t &expected, + ClientContext *context = nullptr) { + for (auto &pair : provided) { + if (!expected.count(pair.first)) { + throw InvalidInputException(ExcessValuesException(expected, provided)); } - return; } - // Mismatch in expected and provided parameters/values - if (expected.size() > provided.size()) { - throw InvalidInputException(MissingValuesException(expected, provided)); - } else { - D_ASSERT(provided.size() > expected.size()); - throw InvalidInputException(ExcessValuesException(expected, provided)); + for (auto &pair : expected) { + auto &identifier = pair.first; + if (provided.count(identifier)) { + continue; + } + Value variable_value; + if (context && AllowsUserVariableFallback(identifier) && + ClientConfig::GetConfig(*context).GetUserVariable(identifier, variable_value)) { + continue; + } + throw InvalidInputException(MissingValuesException(expected, provided, context)); } } diff --git a/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp b/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp index b09c183f3..5030408ca 100644 --- a/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp +++ b/src/duckdb/src/include/duckdb/main/prepared_statement_data.hpp @@ -53,8 +53,10 @@ class PreparedStatementData { void CheckParameterCount(idx_t parameter_count); //! Whether or not the prepared statement data requires the query to rebound for the given parameters bool RequireRebind(ClientContext &context, optional_ptr> values); + //! Fill in missing parameter values from user variables + void PopulateMissingParameterValues(ClientContext &context, identifier_map_t &values) const; //! Bind a set of values to the prepared statement data - DUCKDB_API void Bind(identifier_map_t values); + DUCKDB_API void Bind(ClientContext &context, const identifier_map_t &values); //! Get the expected SQL Type of the bound parameter DUCKDB_API LogicalType GetType(const Identifier &identifier); //! Try to get the expected SQL Type of the bound parameter diff --git a/src/duckdb/src/include/duckdb/main/settings.hpp b/src/duckdb/src/include/duckdb/main/settings.hpp index c97bde0f2..50eea65bf 100644 --- a/src/duckdb/src/include/duckdb/main/settings.hpp +++ b/src/duckdb/src/include/duckdb/main/settings.hpp @@ -585,6 +585,17 @@ struct DebugForceNoCrossProductSetting { static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); }; +struct DebugLocalFileSystemDelayMsSetting { + using RETURN_TYPE = idx_t; + static constexpr const char *Name = "debug_local_file_system_delay_ms"; + static constexpr const char *Description = + "DEBUG SETTING: time to sleep before local file system open/read/write operations"; + static constexpr const char *InputType = "UBIGINT"; + static constexpr const char *DefaultValue = "0"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::GLOBAL_ONLY; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); +}; + struct DebugOrderVerificationSetting { using RETURN_TYPE = DebugOrderVerification; static constexpr const char *Name = "debug_order_verification"; diff --git a/src/duckdb/src/include/duckdb/optimizer/compressed_materialization.hpp b/src/duckdb/src/include/duckdb/optimizer/compressed_materialization.hpp index 8757c895f..44242c165 100644 --- a/src/duckdb/src/include/duckdb/optimizer/compressed_materialization.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/compressed_materialization.hpp @@ -17,6 +17,10 @@ namespace duckdb { class Optimizer; class ClientContext; class LogicalOperator; +class Expression; +class BoundColumnRefExpression; +class BoundFunctionExpression; +struct JoinCondition; enum class CompressedMaterializationType : uint8_t { INVALID = 0, @@ -85,6 +89,12 @@ typedef column_binding_map_t> statistics_map_t; //! but only if the data enters a materializing operator class CompressedMaterialization { private: + struct VariantJoinKeyInfo { + optional_ptr child; + LogicalType shredded_type; + unique_ptr typed_stats; + }; + //! Somewhat defensive constants that try to limit when compressed materialization is triggered for joins //! We only consider compressed materialization for joins when the build cardinality is greater than this static constexpr idx_t JOIN_BUILD_CARDINALITY_THRESHOLD = 1048576; @@ -112,9 +122,12 @@ class CompressedMaterialization { //! Adds bindings referenced in expression to referenced_bindings static void GetReferencedBindings(const Expression &expression, column_binding_set_t &referenced_bindings); + static bool IsVariantWrapperFunction(const BoundFunctionExpression &expr); + static optional_ptr TryGetVariantWrapperColumnRef(const Expression &expr); //! Updates CMBindingInfo in the binding_map in info void UpdateBindingInfo(CompressedMaterializationInfo &info, const ColumnBinding &binding, CompressedMaterializationType materialization_type); + optional_ptr GetVariantWrapperStats(const Expression &expr); //! Create (de)compress projections around the operator void CreateProjections(unique_ptr &op, CompressedMaterializationInfo &info); @@ -133,6 +146,14 @@ class CompressedMaterialization { unique_ptr GetCompressExpression(unique_ptr input, const BaseStatistics &stats); unique_ptr GetIntegralCompress(unique_ptr input, const BaseStatistics &stats); unique_ptr GetStringCompress(unique_ptr input, const BaseStatistics &stats); + unique_ptr GetGeometryCompress(unique_ptr input, const BaseStatistics &stats); + unique_ptr GetVariantCompress(unique_ptr input, const BaseStatistics &stats); + bool TryGetVariantJoinKeyInfo(const Expression &expr, VariantJoinKeyInfo &result); + static unique_ptr CastVariantJoinKeyStats(const VariantJoinKeyInfo &key_info, + const LogicalType &target_type); + unique_ptr CreateVariantJoinKeyCast(unique_ptr input, const LogicalType &shredded_type, + const LogicalType &target_type); + bool TryCompressVariantComparisonJoinKey(JoinCondition &condition); //! Create an expression that applies a scalar decompression function unique_ptr GetDecompressExpression(unique_ptr input, const LogicalType &result_type, @@ -141,6 +162,10 @@ class CompressedMaterialization { const BaseStatistics &stats); unique_ptr GetStringDecompress(unique_ptr input, const LogicalType &result_type, const BaseStatistics &stats); + unique_ptr GetGeometryDecompress(unique_ptr input, const LogicalType &result_type, + const BaseStatistics &stats); + unique_ptr GetVariantDecompress(unique_ptr input, const LogicalType &result_type, + const BaseStatistics &stats); private: Optimizer &optimizer; diff --git a/src/duckdb/src/include/duckdb/optimizer/projection_pullup.hpp b/src/duckdb/src/include/duckdb/optimizer/projection_pullup.hpp index d6240f407..bee770504 100644 --- a/src/duckdb/src/include/duckdb/optimizer/projection_pullup.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/projection_pullup.hpp @@ -9,13 +9,14 @@ namespace duckdb { class Optimizer; class LogicalOperator; -class ProjectionPullup { +class ProjectionPullup : public LogicalOperatorVisitor { public: explicit ProjectionPullup(Optimizer &optimizer_p, unique_ptr &root) : optimizer(optimizer_p), root(root) { } void Optimize(unique_ptr &op); + void VisitOperator(unique_ptr &op) override; private: Optimizer &optimizer; diff --git a/src/duckdb/src/include/duckdb/parallel/async_result.hpp b/src/duckdb/src/include/duckdb/parallel/async_result.hpp index 9915a6f60..73b1c81f1 100644 --- a/src/duckdb/src/include/duckdb/parallel/async_result.hpp +++ b/src/duckdb/src/include/duckdb/parallel/async_result.hpp @@ -44,11 +44,11 @@ class AsyncResult { AsyncResult &operator=(AsyncResultType t); AsyncResult &operator=(AsyncResult &&) noexcept; //! Schedule held async_tasks into the Executor, eventually unblocking InterruptState - //! needs to be called with non-emopty async_tasks and from BLOCKED state, will empty the async_tasks and transform + //! needs to be called with non-empty async_tasks and from BLOCKED state, will empty the async_tasks and transform //! into INVALID void ScheduleTasks(InterruptState &interrupt_state, Executor &executor); //! Execute tasks synchronously at callsite - //! needs to be called with non-emopty async_tasks and from BLOCKED state, will empty the async_tasks and transform + //! needs to be called with non-empty async_tasks and from BLOCKED state, will empty the async_tasks and transform //! into HAVE_MORE_OUTPUT void ExecuteTasksSynchronously(); diff --git a/src/duckdb/src/include/duckdb/parser/peg/inlined_grammar.hpp b/src/duckdb/src/include/duckdb/parser/peg/inlined_grammar.hpp index e460a2a94..4af010369 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/inlined_grammar.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/inlined_grammar.hpp @@ -1367,8 +1367,9 @@ const char INLINED_PEG_GRAMMAR[] = { "AtUnit <- VersionAtUnit / TimestampAtUnit\n" "VersionAtUnit <- 'VERSION'\n" "TimestampAtUnit <- 'TIMESTAMP'\n" - "JoinClause <- RegularJoinClause / JoinWithoutOnClause\n" + "JoinClause <- JoinByClause / RegularJoinClause / JoinWithoutOnClause\n" "RegularJoinClause <- Asof? JoinType? 'JOIN' TableRef JoinQualifier\n" + "JoinByClause <- 'JOIN' 'BY' Parens('TYPE' ColLabel) TableRef JoinQualifier\n" "Asof <- 'ASOF'\n" "JoinWithoutOnClause <- JoinPrefix 'JOIN' TableRef\n" "JoinQualifier <- OnClause / UsingClause\n" diff --git a/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer.hpp b/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer.hpp index ad5793360..50eb27c1f 100644 --- a/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer.hpp +++ b/src/duckdb/src/include/duckdb/parser/peg/transformer/peg_transformer.hpp @@ -2,7 +2,6 @@ #include "duckdb/parser/peg/ast/unpivot_name_values.hpp" #include "duckdb/parser/peg/transformer/parse_result.hpp" -#include "duckdb/parser/peg/transformer/transform_enum_result.hpp" #include "duckdb/parser/peg/transformer/transform_result.hpp" #include "duckdb/parser/peg/ast/add_column_entry.hpp" #include "duckdb/parser/peg/ast/column_constraint_entry.hpp" @@ -92,11 +91,9 @@ class PEGTransformer { PEGTransformer(ArenaAllocator &allocator, PEGTransformerState &state, const case_insensitive_map_t &transform_functions, - const case_insensitive_map_t &grammar_rules, - const case_insensitive_map_t> &enum_mappings, - ParserOptions &options_p) + const case_insensitive_map_t &grammar_rules, ParserOptions &options_p) : allocator(allocator), state(state), grammar_rules(grammar_rules), transform_functions(transform_functions), - enum_mappings(enum_mappings), options(options_p) { + options(options_p) { } public: @@ -143,23 +140,6 @@ class PEGTransformer { return Transform(child_parse_result); } - template - T TransformEnum(ParseResult &parse_result) { - auto enum_rule_name = parse_result.name; - - auto rule_value = enum_mappings.find(enum_rule_name); - if (rule_value == enum_mappings.end()) { - throw ParserException("Enum transform failed: could not find mapping for '%s'", enum_rule_name); - } - - auto *typed_enum_ptr = dynamic_cast *>(rule_value->second.get()); - if (!typed_enum_ptr) { - throw InternalException("Enum mapping for rule '%s' has an unexpected type.", enum_rule_name); - } - - return typed_enum_ptr->value; - } - template void TransformOptional(ListParseResult &list_pr, idx_t child_idx, T &target) { auto &opt = list_pr.Child(child_idx); @@ -216,7 +196,6 @@ class PEGTransformer { PEGTransformerState &state; const case_insensitive_map_t &grammar_rules; const case_insensitive_map_t &transform_functions; - const case_insensitive_map_t> &enum_mappings; identifier_map_t named_parameter_map; idx_t prepared_statement_parameter_index = 0; PreparedParamType last_param_type = PreparedParamType::INVALID; @@ -227,6 +206,7 @@ class PEGTransformer { vector> stored_cte_map; bool in_window_definition = false; + bool has_anonymous_parameters = false; friend class StackChecker; idx_t stack_depth = 0; @@ -337,26 +317,13 @@ class PEGTransformerFactory { // Registration methods void RegisterComment(); void RegisterCommon(); - void RegisterCreateMacro(); void RegisterCreateTable(); void RegisterExpression(); - void RegisterConnect(); void RegisterPivot(); void RegisterSelect(); void RegisterKeywordsAndIdentifiers(); - void RegisterEnums(); void RegisterGenerated(); -private: - template - void RegisterEnum(const string &rule_name, T value) { - auto existing_rule = enum_mappings.find(rule_name); - if (existing_rule != enum_mappings.end()) { - throw InternalException("EnumRule %s already exists", rule_name); - } - enum_mappings[rule_name] = make_uniq>(value); - } - template void Register(const string &rule_name, FUNC function) { auto existing_rule = sql_transform_functions.find(rule_name); @@ -3382,6 +3349,10 @@ class PEGTransformerFactory { const optional &join_type, unique_ptr table_ref, JoinQualifier join_qualifier); + static unique_ptr TransformJoinByClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result); + static unique_ptr TransformJoinByClause(PEGTransformer &transformer, const string &col_label, + unique_ptr table_ref, JoinQualifier join_qualifier); static unique_ptr TransformAsofInternal(PEGTransformer &transformer, ParseResult &parse_result); static bool TransformAsof(PEGTransformer &transformer); @@ -3821,7 +3792,6 @@ class PEGTransformerFactory { private: PEGParser parser; case_insensitive_map_t sql_transform_functions; - case_insensitive_map_t> enum_mappings; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/peg/transformer/transform_enum_result.hpp b/src/duckdb/src/include/duckdb/parser/peg/transformer/transform_enum_result.hpp deleted file mode 100644 index f8782a539..000000000 --- a/src/duckdb/src/include/duckdb/parser/peg/transformer/transform_enum_result.hpp +++ /dev/null @@ -1,17 +0,0 @@ -#pragma once - -#include "duckdb/common/common.hpp" - -namespace duckdb { -struct TransformEnumValue { - virtual ~TransformEnumValue() = default; -}; - -template -struct TypedTransformEnumResult : public TransformEnumValue { - explicit TypedTransformEnumResult(T value_p) : value(std::move(value_p)) { - } - T value; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/sql_statement.hpp b/src/duckdb/src/include/duckdb/parser/sql_statement.hpp index d2cbb25e1..53c3a4ddb 100644 --- a/src/duckdb/src/include/duckdb/parser/sql_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/sql_statement.hpp @@ -35,6 +35,8 @@ class SQLStatement { idx_t stmt_length = 0; //! The map of named parameter to param index identifier_map_t named_param_map; + //! Whether the statement contains any anonymous (? or $N) parameters + bool has_anonymous_parameters = false; //! The query text that corresponds to this SQL statement string query; diff --git a/src/duckdb/src/include/duckdb/planner/binder.hpp b/src/duckdb/src/include/duckdb/planner/binder.hpp index 1cf246e28..57ac66174 100644 --- a/src/duckdb/src/include/duckdb/planner/binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/binder.hpp @@ -82,6 +82,7 @@ class IndexVector; enum class BindingMode : uint8_t { STANDARD_BINDING, + PREPARE, EXTRACT_NAMES, EXTRACT_REPLACEMENT_SCANS, EXTRACT_QUALIFIED_NAMES diff --git a/src/duckdb/src/include/duckdb/planner/collation_binding.hpp b/src/duckdb/src/include/duckdb/planner/collation_binding.hpp index aa82a71a2..867397266 100644 --- a/src/duckdb/src/include/duckdb/planner/collation_binding.hpp +++ b/src/duckdb/src/include/duckdb/planner/collation_binding.hpp @@ -16,14 +16,17 @@ struct MapCastInfo; struct MapCastNode; struct DBConfig; -typedef bool (*try_push_collation_t)(ClientContext &context, unique_ptr &source, - const LogicalType &sql_type, CollationType type); +//! Returns the (ordered) list of scalar functions that need to be applied to a value of the given type to make it +//! byte-comparable under its collation. Returns an empty list if no collation needs to be applied. +typedef vector (*get_collation_functions_t)(ClientContext &context, const LogicalType &sql_type, + CollationType type); struct CollationCallback { - explicit CollationCallback(try_push_collation_t try_push_collation_p) : try_push_collation(try_push_collation_p) { + explicit CollationCallback(get_collation_functions_t get_collation_functions_p) + : get_collation_functions(get_collation_functions_p) { } - try_push_collation_t try_push_collation; + get_collation_functions_t get_collation_functions; }; class CollationBinding { diff --git a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp index b082d5810..a20e4f598 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alp/alp_scan.hpp @@ -124,7 +124,13 @@ struct AlpScanState : public SegmentScanState { // Load the offset (metadata) indicating where the vector data starts metadata_ptr -= AlpConstants::METADATA_POINTER_SIZE; auto data_byte_offset = Load(metadata_ptr); - D_ASSERT(data_byte_offset < segment.GetBlockSize()); + const auto block_size = segment.GetBlockSize(); + + if (data_byte_offset >= block_size) { + throw IOException( + "Corrupted ALP segment: stored data_byte_offset (%d) exceeds the segments block size (%d)", + data_byte_offset, block_size); + } idx_t vector_size = MinValue((idx_t)AlpConstants::ALP_VECTOR_SIZE, (count - total_value_count)); @@ -138,7 +144,14 @@ struct AlpScanState : public SegmentScanState { if (uncompressed_mode) { if (!SKIP) { // Read uncompressed values - memcpy(value_buffer, vector_ptr, sizeof(T) * vector_size); + const idx_t value_buffer_copy_size = sizeof(T) * vector_size; + if (vector_ptr + value_buffer_copy_size > segment_data + block_size) { + const auto bytes_remaining_in_block = (segment_data + block_size) - vector_ptr; + throw IOException("Corrupted ALP segment: stored vector_size is invalid, to-copy bytes (%d) " + "would exceed bytes remaining in the block (%d)", + value_buffer_copy_size, bytes_remaining_in_block); + } + memcpy(value_buffer, vector_ptr, value_buffer_copy_size); } return; } @@ -154,21 +167,54 @@ struct AlpScanState : public SegmentScanState { vector_state.bit_width = Load(vector_ptr); vector_ptr += AlpConstants::BIT_WIDTH_SIZE; - D_ASSERT(vector_state.exceptions_count <= vector_size); - D_ASSERT(vector_state.v_factor <= vector_state.v_exponent); - D_ASSERT(vector_state.bit_width <= sizeof(uint64_t) * 8); + if (vector_state.exceptions_count > vector_size) { + throw IOException("Corrupted ALP segment: exceptions_count (%d) exceeds vector_size (%d)", + vector_state.exceptions_count, vector_size); + } + if (vector_state.v_factor > vector_state.v_exponent) { + throw IOException("Corrupted ALP segment: v_factor (%d) exceeds v_exponent (%d)", vector_state.v_factor, + vector_state.v_exponent); + } + if (vector_state.bit_width > sizeof(uint64_t) * 8) { + throw IOException("Corrupted ALP segment: Invalid bit_width encountered: %d", vector_state.bit_width); + } + idx_t read_bytes = 0; if (vector_state.bit_width > 0) { auto bp_size = BitpackingPrimitives::GetRequiredSize(vector_size, vector_state.bit_width); + + const idx_t max_encoded = sizeof(vector_state.for_encoded); + if (bp_size > max_encoded || data_byte_offset + read_bytes + bp_size > block_size) { + throw IOException("Corrupted ALP segment: encoded payload too large"); + } memcpy(vector_state.for_encoded, (void *)vector_ptr, bp_size); vector_ptr += bp_size; + read_bytes += bp_size; } if (vector_state.exceptions_count > 0) { - memcpy(vector_state.exceptions, (void *)vector_ptr, sizeof(EXACT_TYPE) * vector_state.exceptions_count); - vector_ptr += sizeof(EXACT_TYPE) * vector_state.exceptions_count; - memcpy(vector_state.exceptions_positions, (void *)vector_ptr, - AlpConstants::EXCEPTION_POSITION_SIZE * vector_state.exceptions_count); + //! Load the exceptions + const idx_t max_exceptions_size = sizeof(vector_state.exceptions); + const idx_t exceptions_copy_size = sizeof(EXACT_TYPE) * vector_state.exceptions_count; + if (exceptions_copy_size > max_exceptions_size || + data_byte_offset + read_bytes + exceptions_copy_size > block_size) { + throw IOException("Corrupted ALP segment: exceptions payload too large"); + } + memcpy(vector_state.exceptions, (void *)vector_ptr, exceptions_copy_size); + vector_ptr += exceptions_copy_size; + read_bytes += exceptions_copy_size; + + //! Load the exceptions_positions + const idx_t max_exceptions_positions_size = sizeof(vector_state.exceptions_positions); + const idx_t exceptions_positions_copy_size = + AlpConstants::EXCEPTION_POSITION_SIZE * vector_state.exceptions_count; + if (exceptions_positions_copy_size > max_exceptions_positions_size || + data_byte_offset + read_bytes + exceptions_positions_copy_size > block_size) { + throw IOException("Corrupted ALP segment: exceptions_positions payload too large"); + } + memcpy(vector_state.exceptions_positions, (void *)vector_ptr, exceptions_positions_copy_size); + vector_ptr += exceptions_positions_copy_size; + read_bytes += exceptions_positions_copy_size; } // Decode all the vector values to the specified 'value_buffer' diff --git a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp index a911e1068..e4ce4d00c 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/alprd/alprd_scan.hpp @@ -74,26 +74,48 @@ struct AlpRDScanState : public SegmentScanState { // ScanStates never exceed the boundaries of a Segment, // but are not guaranteed to start at the beginning of the Block segment_data = handle.GetDataMutable() + segment.GetBlockOffset(); + const auto block_size = segment.GetBlockSize(); + + idx_t total_segment_offset = segment.GetBlockOffset(); auto metadata_offset = Load(segment_data); + auto segment_ptr = segment_data + AlpRDConstants::METADATA_POINTER_SIZE; + total_segment_offset += AlpRDConstants::METADATA_POINTER_SIZE; + metadata_ptr = segment_data + metadata_offset; + const idx_t metadata_ptr_offset = segment.GetBlockOffset() + metadata_offset; + if (metadata_ptr_offset > block_size) { + throw IOException("Corrupted ALPRD segment: metadata_offset value is corrupted"); + } + + if (total_segment_offset + AlpRDConstants::HEADER_SIZE > block_size) { + throw IOException("Corrupted ALPRD segment: reading header bytes would exceed block space"); + } // Load the Right Bit Width which is in the segment header after the pointer to the first metadata - vector_state.right_bit_width = Load(segment_data + AlpRDConstants::METADATA_POINTER_SIZE); - vector_state.left_bit_width = - Load(segment_data + AlpRDConstants::METADATA_POINTER_SIZE + AlpRDConstants::RIGHT_BIT_WIDTH_SIZE); + vector_state.right_bit_width = Load(segment_ptr); + segment_ptr += AlpRDConstants::RIGHT_BIT_WIDTH_SIZE; + + vector_state.left_bit_width = Load(segment_ptr); + segment_ptr += AlpRDConstants::LEFT_BIT_WIDTH_SIZE; + + uint8_t actual_dictionary_size = Load(segment_ptr); + segment_ptr += AlpRDConstants::N_DICTIONARY_ELEMENTS_SIZE; + + total_segment_offset += AlpRDConstants::HEADER_SIZE; - uint8_t actual_dictionary_size = - Load(segment_data + AlpRDConstants::METADATA_POINTER_SIZE + AlpRDConstants::RIGHT_BIT_WIDTH_SIZE + - AlpRDConstants::LEFT_BIT_WIDTH_SIZE); if (actual_dictionary_size > AlpRDConstants::MAX_DICTIONARY_SIZE) { throw IOException("Corrupt database file: ALPRD dictionary size exceeds maximum"); } idx_t actual_dictionary_size_bytes = static_cast(actual_dictionary_size) * AlpRDConstants::DICTIONARY_ELEMENT_SIZE; + const idx_t left_parts_dict_max_size = sizeof(vector_state.left_parts_dict); + if (total_segment_offset + actual_dictionary_size_bytes > metadata_ptr_offset || + actual_dictionary_size_bytes > left_parts_dict_max_size) { + throw IOException("Corrupted ALPRD segment: actual_dictionary_size is corrupted"); + } // Load the left parts dictionary which is after the segment header and is of a fixed size - memcpy(vector_state.left_parts_dict, (void *)(segment_data + AlpRDConstants::HEADER_SIZE), - actual_dictionary_size_bytes); + memcpy(vector_state.left_parts_dict, segment_ptr, actual_dictionary_size_bytes); } BufferHandle handle; @@ -148,7 +170,12 @@ struct AlpRDScanState : public SegmentScanState { // Load the offset (metadata) indicating where the vector data starts metadata_ptr -= AlpRDConstants::METADATA_POINTER_SIZE; auto data_byte_offset = Load(metadata_ptr); - D_ASSERT(data_byte_offset < segment.GetBlockSize()); + const auto block_size = segment.GetBlockSize(); + if (data_byte_offset >= block_size) { + throw IOException( + "Corrupted ALPRD segment: stored data_byte_offset (%d) exceeds the segments block size (%d)", + data_byte_offset, block_size); + } idx_t vector_size = MinValue((idx_t)AlpRDConstants::ALP_VECTOR_SIZE, (count - total_value_count)); @@ -162,29 +189,61 @@ struct AlpRDScanState : public SegmentScanState { if (uncompressed_mode) { if (!SKIP) { // Read uncompressed values - memcpy(value_buffer, vector_ptr, sizeof(T) * vector_size); + const idx_t value_buffer_copy_size = sizeof(T) * vector_size; + if (vector_ptr + value_buffer_copy_size > segment_data + block_size) { + const auto bytes_remaining_in_block = (segment_data + block_size) - vector_ptr; + throw IOException("Corrupted ALPRD segment: stored vector_size is invalid, to-copy bytes " + "(%d) would exceed bytes remaining in the block (%d)", + value_buffer_copy_size, bytes_remaining_in_block); + } + memcpy(value_buffer, vector_ptr, value_buffer_copy_size); } return; } - if (vector_state.exceptions_count > vector_size) { - throw IOException("Corrupt database file: ALPRD exceptions_count exceeds vector size"); - } auto left_bp_size = BitpackingPrimitives::GetRequiredSize(vector_size, vector_state.left_bit_width); auto right_bp_size = BitpackingPrimitives::GetRequiredSize(vector_size, vector_state.right_bit_width); + idx_t read_bytes = 0; + const idx_t max_left_encoded_size = sizeof(vector_state.left_encoded); + if (left_bp_size > max_left_encoded_size || data_byte_offset + read_bytes + left_bp_size > block_size) { + throw IOException("Corrupted ALPRD segment: left_encoded payload too large"); + } memcpy(vector_state.left_encoded, (void *)vector_ptr, left_bp_size); vector_ptr += left_bp_size; + read_bytes += left_bp_size; + const idx_t max_right_encoded_size = sizeof(vector_state.right_encoded); + if (right_bp_size > max_right_encoded_size || data_byte_offset + read_bytes + right_bp_size > block_size) { + throw IOException("Corrupted ALPRD segment: left_encoded payload too large"); + } memcpy(vector_state.right_encoded, (void *)vector_ptr, right_bp_size); vector_ptr += right_bp_size; + read_bytes += right_bp_size; if (vector_state.exceptions_count > 0) { - memcpy(vector_state.exceptions, (void *)vector_ptr, - AlpRDConstants::EXCEPTION_SIZE * vector_state.exceptions_count); - vector_ptr += AlpRDConstants::EXCEPTION_SIZE * vector_state.exceptions_count; - memcpy(vector_state.exceptions_positions, (void *)vector_ptr, - AlpRDConstants::EXCEPTION_POSITION_SIZE * vector_state.exceptions_count); + //! Load the exceptions + const idx_t max_exceptions_size = sizeof(vector_state.exceptions); + const idx_t exceptions_copy_size = AlpRDConstants::EXCEPTION_SIZE * vector_state.exceptions_count; + if (exceptions_copy_size > max_exceptions_size || + data_byte_offset + read_bytes + exceptions_copy_size > block_size) { + throw IOException("Corrupted ALPRD segment: exceptions payload too large"); + } + memcpy(vector_state.exceptions, (void *)vector_ptr, exceptions_copy_size); + vector_ptr += exceptions_copy_size; + read_bytes += exceptions_copy_size; + + //! Load the exceptions_positions + const idx_t max_exceptions_positions_size = sizeof(vector_state.exceptions_positions); + const idx_t exceptions_positions_copy_size = + AlpRDConstants::EXCEPTION_POSITION_SIZE * vector_state.exceptions_count; + if (exceptions_positions_copy_size > max_exceptions_positions_size || + data_byte_offset + read_bytes + exceptions_positions_copy_size > block_size) { + throw IOException("Corrupted ALPRD segment: exceptions_positions payload too large"); + } + memcpy(vector_state.exceptions_positions, (void *)vector_ptr, exceptions_positions_copy_size); + vector_ptr += exceptions_positions_copy_size; + read_bytes += exceptions_positions_copy_size; } // Decode all the vector values to the specified 'value_buffer' diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp index e5dee8903..350b6dda7 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp @@ -67,6 +67,14 @@ struct PatasGroupState { } value_buffer[0] = (EXACT_TYPE)0; for (idx_t i = 0; i < count; i++) { + if (unpacked_data[i].index_diff > i) { + throw IOException("Corrupted Patas segment: invalid backward reference"); + } + if (unpacked_data[i].significant_bytes > sizeof(EXACT_TYPE) || + unpacked_data[i].trailing_zeros >= sizeof(EXACT_TYPE) * 8) { + throw IOException("Corrupted Patas segment: invalid packed value metadata"); + } + value_buffer[i] = patas::PatasDecompression::DecompressValue( byte_reader, unpacked_data[i].significant_bytes, unpacked_data[i].trailing_zeros, value_buffer[i - unpacked_data[i].index_diff]); @@ -95,6 +103,9 @@ struct PatasScanState : public SegmentScanState { // but are not guaranteed to start at the beginning of the Block segment_data = handle.GetDataMutable() + segment.GetBlockOffset(); auto metadata_offset = Load(segment_data); + if (segment.GetBlockOffset() + metadata_offset > segment.GetBlockSize()) { + throw IOException("Corrupted Patas segment: metadata_offset reaches outside of the blocks memory"); + } metadata_ptr = segment_data + metadata_offset; } @@ -154,7 +165,9 @@ struct PatasScanState : public SegmentScanState { // Load the offset indicating where a groups data starts metadata_ptr -= sizeof(uint32_t); auto data_byte_offset = Load(metadata_ptr); - D_ASSERT(data_byte_offset < segment.GetBlockSize()); + if (segment.GetBlockOffset() + data_byte_offset >= segment.GetBlockSize()) { + throw IOException("Corrupted Patas segment: data_byte_offset would reach outside of the blocks memory"); + } // Initialize the byte_reader with the data values for the group group_state.Init(segment_data + data_byte_offset); diff --git a/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp index 7072f9e75..f22203e0f 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/variant_stats.hpp @@ -1,3 +1,11 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/statistics/variant_stats.hpp +// +// +//===----------------------------------------------------------------------===// + #pragma once #include "duckdb/common/types/variant.hpp" diff --git a/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp b/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp index 53e43516b..fdece29a1 100644 --- a/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp +++ b/src/duckdb/src/include/duckdb/transaction/duck_transaction_manager.hpp @@ -25,7 +25,7 @@ struct DuckCleanupInfo { transaction_t lowest_start_time; vector> transactions; - void Cleanup() noexcept; + void Cleanup(); bool ScheduleCleanup() noexcept; }; diff --git a/src/duckdb/src/include/duckdb/transaction/transaction_context.hpp b/src/duckdb/src/include/duckdb/transaction/transaction_context.hpp index 8a7ecfc77..f67967bd5 100644 --- a/src/duckdb/src/include/duckdb/transaction/transaction_context.hpp +++ b/src/duckdb/src/include/duckdb/transaction/transaction_context.hpp @@ -70,8 +70,8 @@ class TransactionContext { private: ClientContext &context; bool auto_commit; - TransactionInvalidationPolicy invalidation_policy; - bool auto_rollback; + TransactionInvalidationPolicy invalidation_policy = TransactionInvalidationPolicy::STANDARD_POLICY; + bool auto_rollback = false; unique_ptr current_transaction; diff --git a/src/duckdb/src/logging/log_manager.cpp b/src/duckdb/src/logging/log_manager.cpp index b9a41e980..3ef2c4472 100644 --- a/src/duckdb/src/logging/log_manager.cpp +++ b/src/duckdb/src/logging/log_manager.cpp @@ -157,6 +157,15 @@ void LogManager::SetDisabledLogTypes(optional_ptr> disable void LogManager::SetLogStorage(DatabaseInstance &db, const string &storage_name) { unique_lock lck(lock); + // 'SET logging_storage' cannot supply the path that file storage requires, so reject the switch + // here (active storage preserved) and point users at enable_logging instead of installing a + // path-less storage that throws on every later flush. + auto storage_name_to_lower = StringUtil::Lower(storage_name); + if (storage_name_to_lower == LogConfig::FILE_STORAGE_NAME && config.storage != storage_name_to_lower) { + throw InvalidConfigurationException( + "Cannot select 'file' log storage via 'SET logging_storage' because it requires a path. " + "Use CALL enable_logging(storage='file', storage_path='...') instead."); + } SetLogStorageInternal(db, storage_name); } diff --git a/src/duckdb/src/logging/log_types.cpp b/src/duckdb/src/logging/log_types.cpp index ccc32a63d..b0a04e5b0 100644 --- a/src/duckdb/src/logging/log_types.cpp +++ b/src/duckdb/src/logging/log_types.cpp @@ -67,6 +67,7 @@ LogicalType HTTPLogType::GetLogType() { {"url", LogicalType::VARCHAR}, {"start_time", LogicalType::TIMESTAMP_TZ}, {"duration_ms", LogicalType::BIGINT}, + {"request_body_length", LogicalType::UBIGINT}, {"headers", LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR)}, }; auto request_type = LogicalType::STRUCT(request_child_list); @@ -102,13 +103,14 @@ string HTTPLogType::ConstructLogMessage(BaseRequest &request, optional_ptr response_child_list = { {"status", Value(EnumUtil::ToString(response->status))}, - {"reason", Value(response->reason)}, + {"reason", Value(response->reason.empty() ? response->GetRequestError() : response->reason)}, {"headers", CreateHTTPHeadersValue(response->headers)}, }; response_value = Value::STRUCT(response_child_list); diff --git a/src/duckdb/src/main/client_config.cpp b/src/duckdb/src/main/client_config.cpp index d543e90cb..ebb26e2b9 100644 --- a/src/duckdb/src/main/client_config.cpp +++ b/src/duckdb/src/main/client_config.cpp @@ -4,11 +4,15 @@ namespace duckdb { +static Identifier StringToIdentifier(const String &name) { + return Identifier(string(name.data(), name.size())); +} + void ClientConfig::SetUserVariable(const String &name, Value value) { - user_variables[name.ToStdString()] = std::move(value); + user_variables[StringToIdentifier(name)] = std::move(value); } -bool ClientConfig::GetUserVariable(const string &name, Value &result) { +bool ClientConfig::GetUserVariable(const Identifier &name, Value &result) { auto entry = user_variables.find(name); if (entry == user_variables.end()) { return false; @@ -17,8 +21,12 @@ bool ClientConfig::GetUserVariable(const string &name, Value &result) { return true; } +bool ClientConfig::GetUserVariable(const string &name, Value &result) { + return GetUserVariable(Identifier(name), result); +} + void ClientConfig::ResetUserVariable(const String &name) { - user_variables.erase(name.ToStdString()); + user_variables.erase(StringToIdentifier(name)); } void ClientConfig::SetDefaultStreamingBufferSize() { diff --git a/src/duckdb/src/main/client_context.cpp b/src/duckdb/src/main/client_context.cpp index 8ac6dfd7c..082cb569c 100644 --- a/src/duckdb/src/main/client_context.cpp +++ b/src/duckdb/src/main/client_context.cpp @@ -371,7 +371,7 @@ ErrorData ClientContext::EndQueryInternal(ClientContextLock &lock, bool success, // Refresh the logger logger->Flush(); LoggingContext context(LogContextScope::CONNECTION); - context.connection_id = reinterpret_cast(this); + context.connection_id = connection_id; logger = db->GetLogManager().CreateLogger(context, true); // Notify any registered state of query end @@ -497,7 +497,7 @@ shared_ptr ClientContext::CreatePreparedStatementInternal #ifdef DEBUG logical_plan->Verify(*this); #endif - if (result->properties.parameter_count > 0 && !parameters.parameters) { + if (!result->value_map.empty() && !parameters.parameters) { // if this is a prepared statement we can choose not to fully plan // if we have parameters, we might want to re-bind when they are available as we can then do more optimizations // in this situation we check if we want to cache the plan at all @@ -589,7 +589,8 @@ QueryProgress ClientContext::GetQueryProgress() { return query_progress; } -void BindPreparedStatementParameters(PreparedStatementData &statement, const PendingQueryParameters ¶meters) { +void BindPreparedStatementParameters(ClientContext &context, PreparedStatementData &statement, + const PendingQueryParameters ¶meters) { identifier_map_t owned_values; if (parameters.parameters) { auto ¶ms = *parameters.parameters; @@ -597,7 +598,7 @@ void BindPreparedStatementParameters(PreparedStatementData &statement, const Pen owned_values.emplace(val); } } - statement.Bind(std::move(owned_values)); + statement.Bind(context, owned_values); } void ClientContext::RebindPreparedStatement(ClientContextLock &lock, const string &query, @@ -644,7 +645,7 @@ ClientContext::PendingPreparedStatementInternal(ClientContextLock &lock, const PendingQueryParameters ¶meters) { D_ASSERT(active_query); auto &statement_data = *statement_data_p; - BindPreparedStatementParameters(statement_data, parameters); + BindPreparedStatementParameters(*this, statement_data, parameters); // Create the query executor. active_query->executor = make_uniq(*this); @@ -736,6 +737,11 @@ PendingExecutionResult ClientContext::ExecuteTaskInternal(ClientContextLock &loc D_ASSERT(active_query->IsOpenResult(result)); bool invalidate_transaction = true; try { + // Surface a pending interrupt even when this thread runs no task that reaches InterruptCheck. + // IsInterrupted() rather than InterruptCheck(): we must not enforce query_deadline here. + if (!dry_run && IsInterrupted()) { + throw InterruptException(); + } auto query_result = active_query->executor->ExecuteTask(dry_run); if (active_query->progress_bar) { auto is_finished = PendingQueryResult::IsResultReady(query_result); @@ -936,19 +942,16 @@ unique_ptr ClientContext::PendingStatementInternal(ClientCon unique_ptr statement, const PendingQueryParameters ¶meters) { // prepare the query for execution - if (parameters.parameters) { - PreparedStatement::VerifyParameters(*parameters.parameters, statement->named_param_map); + if (!statement->named_param_map.empty() && parameters.parameters) { + PreparedStatement::VerifyParameters(*parameters.parameters, statement->named_param_map, this); + } else if (!statement->named_param_map.empty()) { + identifier_map_t empty_parameters; + PreparedStatement::VerifyParameters(empty_parameters, statement->named_param_map, this); } auto prepared = CreatePreparedStatement(lock, query, std::move(statement), parameters, PreparedStatementMode::PREPARE_AND_EXECUTE); - idx_t parameter_count = !parameters.parameters ? 0 : parameters.parameters->size(); - if (prepared->properties.parameter_count > 0 && parameter_count == 0) { - string error_message = StringUtil::Format("Expected %lld parameters, but none were supplied", - prepared->properties.parameter_count); - return ErrorResult(InvalidInputException(error_message), query); - } if (!prepared->properties.bound_all_parameters) { return ErrorResult(InvalidInputException("Not all parameters were bound"), query); } diff --git a/src/duckdb/src/main/config.cpp b/src/duckdb/src/main/config.cpp index d1fab8b1a..879a313a5 100644 --- a/src/duckdb/src/main/config.cpp +++ b/src/duckdb/src/main/config.cpp @@ -109,6 +109,7 @@ static const ConfigurationOption internal_options[] = { DUCKDB_SETTING(DebugForceExternalSetting), DUCKDB_SETTING(DebugForceFetchRowSetting), DUCKDB_SETTING(DebugForceNoCrossProductSetting), + DUCKDB_SETTING(DebugLocalFileSystemDelayMsSetting), DUCKDB_GLOBAL(DebugOrderVerificationSetting), DUCKDB_SETTING_CALLBACK(DebugPhysicalTableScanExecutionStrategySetting), DUCKDB_SETTING(DebugSkipCheckpointOnCommitSetting), @@ -245,12 +246,12 @@ static const ConfigurationOption internal_options[] = { static const ConfigurationAlias setting_aliases[] = {DUCKDB_SETTING_ALIAS("configure_metrics", 29), DUCKDB_SETTING_ALIAS("custom_profiling_settings", 29), - DUCKDB_SETTING_ALIAS("memory_limit", 126), - DUCKDB_SETTING_ALIAS("null_order", 59), - DUCKDB_SETTING_ALIAS("profile_output", 149), - DUCKDB_SETTING_ALIAS("user", 166), + DUCKDB_SETTING_ALIAS("memory_limit", 127), + DUCKDB_SETTING_ALIAS("null_order", 60), + DUCKDB_SETTING_ALIAS("profile_output", 150), + DUCKDB_SETTING_ALIAS("user", 167), DUCKDB_SETTING_ALIAS("wal_autocheckpoint", 28), - DUCKDB_SETTING_ALIAS("worker_threads", 164), + DUCKDB_SETTING_ALIAS("worker_threads", 165), FINAL_ALIAS}; vector DBConfig::GetOptions() { diff --git a/src/duckdb/src/main/extension/extension_helper.cpp b/src/duckdb/src/main/extension/extension_helper.cpp index 1ec4a3de0..944a8f950 100644 --- a/src/duckdb/src/main/extension/extension_helper.cpp +++ b/src/duckdb/src/main/extension/extension_helper.cpp @@ -2,7 +2,6 @@ #include "duckdb/common/file_system.hpp" #include "duckdb/common/local_file_system.hpp" -#include "duckdb/main/database_file_opener.hpp" #include "duckdb/common/serializer/binary_deserializer.hpp" #include "duckdb/common/serializer/buffered_file_reader.hpp" #include "duckdb/common/string_util.hpp" @@ -10,6 +9,7 @@ #include "duckdb/logging/logger.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/main/database.hpp" +#include "duckdb/main/database_file_opener.hpp" #include "duckdb/main/extension.hpp" #include "duckdb/main/extension_install_info.hpp" #include "duckdb/main/settings.hpp" diff --git a/src/duckdb/src/main/pending_query_result.cpp b/src/duckdb/src/main/pending_query_result.cpp index 612413244..5726e2f86 100644 --- a/src/duckdb/src/main/pending_query_result.cpp +++ b/src/duckdb/src/main/pending_query_result.cpp @@ -86,7 +86,8 @@ unique_ptr PendingQueryResult::ExecuteInternal(ClientContextLock &l } } auto result = context->FetchResultInternal(lock, *this); - Close(); + // release our context reference (cannot use Close(): the context lock is already held here) + context.reset(); return result; } @@ -96,6 +97,14 @@ unique_ptr PendingQueryResult::Execute() { } void PendingQueryResult::Close() { + if (context) { + auto lock = LockContext(); + if (context->IsActiveResult(*lock, *this)) { + // Abandoned before execution finished: release the active-query state now (matching + // InitialCleanup) instead of leaking it until the next query or context teardown. + context->CleanupInternal(*lock, this, false); + } + } context.reset(); } diff --git a/src/duckdb/src/main/prepared_statement.cpp b/src/duckdb/src/main/prepared_statement.cpp index fd33a3c08..e269d4443 100644 --- a/src/duckdb/src/main/prepared_statement.cpp +++ b/src/duckdb/src/main/prepared_statement.cpp @@ -76,7 +76,9 @@ unique_ptr PreparedStatement::Execute(identifier_map_t(ErrorData(ex)); } @@ -119,7 +121,9 @@ unique_ptr PreparedStatement::PendingQuery(identifier_map_t< parameters.parameters = &named_values; try { - VerifyParameters(named_values, named_param_map); + if (!named_param_map.empty()) { + VerifyParameters(named_values, named_param_map, context.get()); + } } catch (const std::exception &ex) { return make_uniq(ErrorData(ex)); } diff --git a/src/duckdb/src/main/prepared_statement_data.cpp b/src/duckdb/src/main/prepared_statement_data.cpp index da01d0f0a..d050523db 100644 --- a/src/duckdb/src/main/prepared_statement_data.cpp +++ b/src/duckdb/src/main/prepared_statement_data.cpp @@ -3,7 +3,9 @@ #include "duckdb/catalog/catalog.hpp" #include "duckdb/common/exception/binder_exception.hpp" #include "duckdb/main/attached_database.hpp" +#include "duckdb/main/client_config.hpp" #include "duckdb/main/database_manager.hpp" +#include "duckdb/main/prepared_statement.hpp" #include "duckdb/transaction/transaction.hpp" namespace duckdb { @@ -39,13 +41,68 @@ bool CheckCatalogIdentity(ClientContext &context, const Identifier &catalog_name return StatementProperties::CatalogIdentity {current_catalog_oid, current_catalog_version} == catalog_identity; } +static BoundParameterData GetParameterValue(ClientContext &context, const identifier_map_t &values, + const Identifier &identifier, bool allow_user_variables) { + auto lookup = values.find(identifier); + if (lookup != values.end()) { + return lookup->second; + } + Value variable_value; + if (allow_user_variables && ClientConfig::GetConfig(context).GetUserVariable(identifier, variable_value)) { + return BoundParameterData(std::move(variable_value)); + } + throw BinderException("Could not find parameter with identifier %s", identifier); +} + +static identifier_map_t GetExpectedParameters(const bound_parameter_map_t &value_map) { + identifier_map_t result; + for (auto &entry : value_map) { + result[entry.first] = result.size(); + } + return result; +} + +static bool HasNamedParameters(const PreparedStatementData &data) { + return data.unbound_statement && !data.unbound_statement->named_param_map.empty(); +} + +static identifier_map_t GetExpectedParameters(const PreparedStatementData &data) { + if (HasNamedParameters(data)) { + return data.unbound_statement->named_param_map; + } + return GetExpectedParameters(data.value_map); +} + +void PreparedStatementData::PopulateMissingParameterValues(ClientContext &context, + identifier_map_t &values) const { + const auto expected_parameters = GetExpectedParameters(*this); + const bool allow_user_variables = HasNamedParameters(*this); + auto verification_context = allow_user_variables ? &context : nullptr; + PreparedStatement::VerifyParameters(values, expected_parameters, verification_context); + for (auto &entry : expected_parameters) { + if (values.count(entry.first)) { + continue; + } + Value variable_value; + const bool can_read_user_variable = + allow_user_variables && PreparedStatement::AllowsUserVariableFallback(entry.first); + if (can_read_user_variable && ClientConfig::GetConfig(context).GetUserVariable(entry.first, variable_value)) { + values[entry.first] = BoundParameterData(std::move(variable_value)); + } + } +} + bool PreparedStatementData::RequireRebind(ClientContext &context, optional_ptr> values) { - idx_t count = values ? values->size() : 0; - CheckParameterCount(count); + identifier_map_t empty_values; + auto ¶meter_values = values ? *values : empty_values; if (!unbound_statement) { throw InternalException("Prepared statement without unbound statement"); } + const auto expected_parameters = GetExpectedParameters(*this); + const bool allow_user_variables = HasNamedParameters(*this); + auto verification_context = allow_user_variables ? &context : nullptr; + PreparedStatement::VerifyParameters(parameter_values, expected_parameters, verification_context); if (properties.always_require_rebind) { // this statement must always be re-bound return true; @@ -56,11 +113,10 @@ bool PreparedStatementData::RequireRebind(ClientContext &context, } for (auto &it : value_map) { auto &identifier = it.first; - auto lookup = values->find(identifier); - if (lookup == values->end()) { - break; - } - if (lookup->second.GetValue().type() != it.second->return_type) { + const bool can_read_user_variable = + allow_user_variables && PreparedStatement::AllowsUserVariableFallback(identifier); + auto parameter_value = GetParameterValue(context, parameter_values, identifier, can_read_user_variable); + if (parameter_value.GetValue().type() != it.second->return_type) { return true; } } @@ -78,20 +134,27 @@ bool PreparedStatementData::RequireRebind(ClientContext &context, return false; } -void PreparedStatementData::Bind(identifier_map_t values) { +void PreparedStatementData::Bind(ClientContext &context, const identifier_map_t &values) { // set parameters D_ASSERT(!unbound_statement || unbound_statement->named_param_map.size() == properties.parameter_count); - CheckParameterCount(values.size()); + if (unbound_statement || !value_map.empty()) { + const auto expected_parameters = GetExpectedParameters(*this); + const bool allow_user_variables = HasNamedParameters(*this); + auto verification_context = allow_user_variables ? &context : nullptr; + PreparedStatement::VerifyParameters(values, expected_parameters, verification_context); + } else if (!values.empty()) { + CheckParameterCount(values.size()); + } // bind the required values + const bool allow_user_variables = HasNamedParameters(*this); for (auto &it : value_map) { const string &identifier = it.first.GetIdentifierName(); - auto lookup = values.find(it.first); - if (lookup == values.end()) { - throw BinderException("Could not find parameter with identifier %s", identifier); - } + const bool can_read_user_variable = + allow_user_variables && PreparedStatement::AllowsUserVariableFallback(it.first); + auto parameter_value = GetParameterValue(context, values, it.first, can_read_user_variable); D_ASSERT(it.second); - auto value = lookup->second.GetValue(); + auto value = parameter_value.GetValue(); if (!value.DefaultTryCastAs(it.second->return_type)) { throw BinderException( "Type mismatch for binding parameter with identifier %s, expected type %s but got type %s", identifier, diff --git a/src/duckdb/src/main/settings/autogenerated_settings.cpp b/src/duckdb/src/main/settings/autogenerated_settings.cpp index 8c650a89a..b2090c098 100644 --- a/src/duckdb/src/main/settings/autogenerated_settings.cpp +++ b/src/duckdb/src/main/settings/autogenerated_settings.cpp @@ -38,6 +38,9 @@ Value AccessModeSetting::GetSetting(const ClientContext &context) { // Allow Parser Override Extension //===----------------------------------------------------------------------===// void AllowParserOverrideExtensionSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("allow_parser_override_extension setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -45,6 +48,9 @@ void AllowParserOverrideExtensionSetting::OnSet(SettingCallbackInfo &info, Value // Arrow Output Version //===----------------------------------------------------------------------===// void ArrowOutputVersionSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("arrow_output_version setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -52,6 +58,9 @@ void ArrowOutputVersionSetting::OnSet(SettingCallbackInfo &info, Value ¶mete // Checkpoint On Detach //===----------------------------------------------------------------------===// void CheckpointOnDetachSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("checkpoint_on_detach setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -74,6 +83,9 @@ Value CustomUserAgentSetting::GetSetting(const ClientContext &context) { // Debug Checkpoint Abort //===----------------------------------------------------------------------===// void DebugCheckpointAbortSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("debug_checkpoint_abort setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -98,6 +110,9 @@ Value DebugOrderVerificationSetting::GetSetting(const ClientContext &context) { // Debug Physical Table Scan Execution Strategy //===----------------------------------------------------------------------===// void DebugPhysicalTableScanExecutionStrategySetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("debug_physical_table_scan_execution_strategy setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -105,6 +120,9 @@ void DebugPhysicalTableScanExecutionStrategySetting::OnSet(SettingCallbackInfo & // Debug Verify Statement //===----------------------------------------------------------------------===// void DebugVerifyStatementSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("debug_verify_statement setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -112,6 +130,9 @@ void DebugVerifyStatementSetting::OnSet(SettingCallbackInfo &info, Value ¶me // Debug Verify Vector //===----------------------------------------------------------------------===// void DebugVerifyVectorSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("debug_verify_vector setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -119,6 +140,9 @@ void DebugVerifyVectorSetting::OnSet(SettingCallbackInfo &info, Value ¶meter // Debug Window Mode //===----------------------------------------------------------------------===// void DebugWindowModeSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("debug_window_mode setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -126,6 +150,9 @@ void DebugWindowModeSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) // Default Io Mode //===----------------------------------------------------------------------===// void DefaultIoModeSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("default_io_mode setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -133,6 +160,9 @@ void DefaultIoModeSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { // Default Transaction Invalidation Policy //===----------------------------------------------------------------------===// void DefaultTransactionInvalidationPolicySetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("default_transaction_invalidation_policy setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -140,6 +170,9 @@ void DefaultTransactionInvalidationPolicySetting::OnSet(SettingCallbackInfo &inf // Deprecated Using Key Syntax //===----------------------------------------------------------------------===// void DeprecatedUsingKeySyntaxSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("deprecated_using_key_syntax setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -147,6 +180,9 @@ void DeprecatedUsingKeySyntaxSetting::OnSet(SettingCallbackInfo &info, Value &pa // Dialect Compatibility Mode //===----------------------------------------------------------------------===// void DialectCompatibilityModeSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("dialect_compatibility_mode setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -177,6 +213,9 @@ Value EnableProgressBarSetting::GetSetting(const ClientContext &context) { // Explain Output //===----------------------------------------------------------------------===// void ExplainOutputSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("explain_output setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -184,6 +223,9 @@ void ExplainOutputSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { // Force Bitpacking Mode //===----------------------------------------------------------------------===// void ForceBitpackingModeSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("force_bitpacking_mode setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -199,6 +241,9 @@ Value HTTPProxySetting::GetSetting(const ClientContext &context) { // Lambda Syntax //===----------------------------------------------------------------------===// void LambdaSyntaxSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("lambda_syntax setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -206,6 +251,9 @@ void LambdaSyntaxSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { // Pin Threads //===----------------------------------------------------------------------===// void PinThreadsSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("pin_threads setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -213,6 +261,9 @@ void PinThreadsSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { // Storage Block Prefetch //===----------------------------------------------------------------------===// void StorageBlockPrefetchSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("storage_block_prefetch setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -220,6 +271,9 @@ void StorageBlockPrefetchSetting::OnSet(SettingCallbackInfo &info, Value ¶me // Table Function Identifier Conversion //===----------------------------------------------------------------------===// void TableFunctionIdentifierConversionSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("table_function_identifier_conversion setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } @@ -227,6 +281,9 @@ void TableFunctionIdentifierConversionSetting::OnSet(SettingCallbackInfo &info, // Validate External File Cache //===----------------------------------------------------------------------===// void ValidateExternalFileCacheSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + if (parameter.IsNull()) { + throw InvalidInputException("validate_external_file_cache setting cannot be NULL"); + } EnumUtil::FromString(StringValue::Get(parameter)); } diff --git a/src/duckdb/src/main/stream_query_result.cpp b/src/duckdb/src/main/stream_query_result.cpp index 03b7e8863..f0baa62b9 100644 --- a/src/duckdb/src/main/stream_query_result.cpp +++ b/src/duckdb/src/main/stream_query_result.cpp @@ -193,6 +193,14 @@ bool StreamQueryResult::IsOpen() { void StreamQueryResult::Close() { buffered_data->Close(); + if (context) { + auto lock = LockContext(); + if (context->IsActiveResult(*lock, *this)) { + // Abandoned before the stream was fully drained: release the active-query state now + // (matching InitialCleanup) instead of leaking it until the next query or context teardown. + context->CleanupInternal(*lock, this, false); + } + } context.reset(); } diff --git a/src/duckdb/src/optimizer/aggregate_function_rewriter.cpp b/src/duckdb/src/optimizer/aggregate_function_rewriter.cpp index 536494c66..8aab455cd 100644 --- a/src/duckdb/src/optimizer/aggregate_function_rewriter.cpp +++ b/src/duckdb/src/optimizer/aggregate_function_rewriter.cpp @@ -182,6 +182,11 @@ class OrderedAggregateMatcher : public ExpressionMatcher { return false; } auto &expr = expr_p.Cast(); + // don't rewrite state-export aggregates - list(x ORDER BY x) EXPORT_STATE would become + // list_sort(list(x) EXPORT_STATE, ...), which cannot bind list_sort on the AGGREGATE_STATE result + if (expr.StateExportMode() == AggregateStateExportMode::STATE_EXPORT) { + return false; + } if (!FunctionMatcher::Match(function, expr.Function().GetName())) { return false; } diff --git a/src/duckdb/src/optimizer/common_aggregate_optimizer.cpp b/src/duckdb/src/optimizer/common_aggregate_optimizer.cpp index d351ccb6a..037ae908d 100644 --- a/src/duckdb/src/optimizer/common_aggregate_optimizer.cpp +++ b/src/duckdb/src/optimizer/common_aggregate_optimizer.cpp @@ -50,6 +50,15 @@ void CommonAggregateOptimizer::ExtractCommonAggregates(LogicalAggregate &aggr) { for (idx_t i = 0; i < aggr.expressions.size(); i++) { ProjectionIndex original_index(i + total_erased); ProjectionIndex new_index(i); + // volatile aggregates must not be deduplicated: each call is independent + if (aggr.expressions[i]->IsVolatile()) { + if (new_index != original_index) { + ColumnBinding original_binding(aggr.aggregate_index, original_index); + ColumnBinding new_binding(aggr.aggregate_index, new_index); + aggregate_map[original_binding] = new_binding; + } + continue; + } auto entry = aggregate_remap.find(*aggr.expressions[i]); if (entry == aggregate_remap.end()) { // aggregate does not exist yet: add it to the map diff --git a/src/duckdb/src/optimizer/compressed_materialization.cpp b/src/duckdb/src/optimizer/compressed_materialization.cpp index c8be97512..430add9cb 100644 --- a/src/duckdb/src/optimizer/compressed_materialization.cpp +++ b/src/duckdb/src/optimizer/compressed_materialization.cpp @@ -5,15 +5,19 @@ #include "duckdb/function/scalar/compressed_materialization_utils.hpp" #include "duckdb/function/scalar/nested_functions.hpp" #include "duckdb/function/scalar/operators.hpp" +#include "duckdb/function/scalar/variant_functions.hpp" #include "duckdb/optimizer/column_binding_replacer.hpp" #include "duckdb/optimizer/optimizer.hpp" #include "duckdb/optimizer/topn_optimizer.hpp" #include "duckdb/planner/binder.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/storage/statistics/geometry_stats.hpp" +#include "duckdb/storage/statistics/variant_stats.hpp" namespace duckdb { @@ -93,6 +97,11 @@ struct CMHelper { static unique_ptr CreateStringFunctionCompress(unique_ptr input, const LogicalType &target_type, unique_ptr compress_stats); + static bool GetVariantCompressInfo(const BaseStatistics &stats, LogicalType &shredded_type, + unique_ptr &typed_stats); + + //! Whether all (non-null) values are non-empty POINTs with XY vertices (so they fit in a UHUGEINT) + static bool GeometryIsAllPointXY(const BaseStatistics &stats); }; //===--------------------------------------------------------------------===// @@ -327,6 +336,40 @@ void CompressedMaterialization::GetReferencedBindings(const Expression &root_exp root_expr, [&](const BoundColumnRefExpression &col_ref) { referenced_bindings.insert(col_ref.Binding()); }); } +bool CompressedMaterialization::IsVariantWrapperFunction(const BoundFunctionExpression &expr) { + const auto &function_name = expr.Function().GetName(); + return function_name == "variant_comparator" || function_name == "variant_normalize"; +} + +optional_ptr +CompressedMaterialization::TryGetVariantWrapperColumnRef(const Expression &expr) { + if (expr.GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { + return nullptr; + } + auto &function_expr = expr.Cast(); + if (!IsVariantWrapperFunction(function_expr) || function_expr.GetChildren().size() != 1) { + return nullptr; + } + auto &child = *function_expr.GetChildren()[0]; + if (child.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF || + child.GetReturnType().id() != LogicalTypeId::VARIANT) { + return nullptr; + } + return child.Cast(); +} + +optional_ptr CompressedMaterialization::GetVariantWrapperStats(const Expression &expr) { + auto colref = TryGetVariantWrapperColumnRef(expr); + if (!colref) { + return nullptr; + } + auto stats_it = statistics_map.find(colref->Binding()); + if (stats_it == statistics_map.end()) { + return nullptr; + } + return stats_it->second.get(); +} + void CompressedMaterialization::UpdateBindingInfo(CompressedMaterializationInfo &info, const ColumnBinding &binding, CompressedMaterializationType materialization_type) { auto &binding_map = info.binding_map; @@ -576,16 +619,29 @@ unique_ptr CompressedMaterialization::GetCompressExpression( if (type.IsAggregateState()) { return nullptr; } + if (stats.GetType().id() == LogicalTypeId::VARIANT && + input->GetExpressionClass() == ExpressionClass::BOUND_FUNCTION) { + auto &function_expr = input->Cast(); + if (IsVariantWrapperFunction(function_expr) && function_expr.GetChildren().size() == 1) { + return GetVariantCompress(std::move(function_expr.GetChildrenMutable()[0]), stats); + } + } if (type != stats.GetType()) { // LCOV_EXCL_START return nullptr; } // LCOV_EXCL_STOP if (type.IsIntegral()) { return GetIntegralCompress(std::move(input), stats); } - if (type.id() == LogicalTypeId::VARCHAR) { + switch (type.id()) { + case LogicalTypeId::VARCHAR: return GetStringCompress(std::move(input), stats); + case LogicalTypeId::GEOMETRY: + return GetGeometryCompress(std::move(input), stats); + case LogicalTypeId::VARIANT: + return GetVariantCompress(std::move(input), stats); + default: + return nullptr; } - return nullptr; } unique_ptr CompressedMaterialization::GetIntegralCompress(unique_ptr input, @@ -701,17 +757,194 @@ unique_ptr CompressedMaterialization::GetStringCompress(uniq return CMHelper::CreateStringFunctionCompress(std::move(input), cast_type, std::move(compress_stats)); } +bool CMHelper::GeometryIsAllPointXY(const BaseStatistics &stats) { + if (stats.GetType().id() != LogicalTypeId::GEOMETRY) { + return false; + } + if (stats.GetStatsType() != StatisticsType::GEOMETRY_STATS) { + return false; + } + // Only POINT-XY geometries are present (and at least one is). Empty points are fine: they are stored as a + // single XY vertex with NaN coordinates, so the WKB blob is always exactly 21 bytes. + if (!GeometryStats::GetTypes(stats).HasOnly(GeometryType::POINT, VertexType::XY)) { + return false; + } + return true; +} + +unique_ptr CompressedMaterialization::GetGeometryCompress(unique_ptr input, + const BaseStatistics &stats) { + if (!CMHelper::GeometryIsAllPointXY(stats)) { + // We can only pack POINT-XY geometries into a UHUGEINT + return nullptr; + } + + const auto target_type = LogicalType::UHUGEINT; + auto compress_function = CMGeometryPointCompressFun::GetFunction(); + vector> arguments; + arguments.emplace_back(std::move(input)); + + BoundScalarFunction bound_function(compress_function); + bound_function.SetReturnType(target_type); + auto compress_expr = make_uniq(std::move(bound_function), std::move(arguments), nullptr); + + auto compress_stats = BaseStatistics::CreateEmpty(target_type); + compress_stats.CopyBase(stats); + return make_uniq(std::move(compress_expr), compress_stats.ToUnique(), + CompressedMaterializationType::FUNCTION); +} + +bool CMHelper::GetVariantCompressInfo(const BaseStatistics &stats, LogicalType &shredded_type, + unique_ptr &typed_stats) { + if (stats.GetType().id() != LogicalTypeId::VARIANT) { + return false; + } + if (!VariantStats::IsShredded(stats)) { + return false; + } + auto structured_type = VariantStats::GetShreddedStructuredType(stats); + if (structured_type.IsNested()) { + // We can only compress VARIANT columns that are shredded on a primitive type + return false; + } + auto &shredded_stats = VariantStats::GetShreddedStats(stats); + if (!VariantShreddedStats::IsFullyShredded(shredded_stats)) { + // Partially shredded - some values do not fit the shredded type, the cast would fail for those + return false; + } + switch (structured_type.id()) { + case LogicalTypeId::BOOLEAN: + case LogicalTypeId::TINYINT: + case LogicalTypeId::SMALLINT: + case LogicalTypeId::INTEGER: + case LogicalTypeId::BIGINT: + case LogicalTypeId::HUGEINT: + case LogicalTypeId::UTINYINT: + case LogicalTypeId::USMALLINT: + case LogicalTypeId::UINTEGER: + case LogicalTypeId::UBIGINT: + case LogicalTypeId::UHUGEINT: + case LogicalTypeId::DECIMAL: + case LogicalTypeId::VARCHAR: + case LogicalTypeId::BLOB: + case LogicalTypeId::DATE: + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_NS: + case LogicalTypeId::TIMESTAMP_SEC: + case LogicalTypeId::TIMESTAMP_MS: + case LogicalTypeId::TIMESTAMP: + case LogicalTypeId::TIMESTAMP_NS: + case LogicalTypeId::TIMESTAMP_TZ: + case LogicalTypeId::TIMESTAMP_TZ_NS: + case LogicalTypeId::UUID: + break; + default: + // We require that equal values of the shredded type always have identical VARIANT binary representations, + // e.g., FLOAT/DOUBLE ("-0.0" == "0.0") and INTERVAL ('1 month' == '30 days') do not qualify + return false; + } + auto &typed = VariantStats::GetTypedStats(shredded_stats); + if (typed.GetType() != structured_type) { // LCOV_EXCL_START + return false; + } // LCOV_EXCL_STOP + typed_stats = typed.ToUnique(); + if (stats.CanHaveNull()) { + // Both SQL NULL and the VARIANT null value become SQL NULL when casting to the shredded type + // (and they are indistinguishable at the top level of a VARIANT column, so this is lossless) + typed_stats->Set(StatsInfo::CAN_HAVE_NULL_VALUES); + } + shredded_type = structured_type; + return true; +} + +unique_ptr CompressedMaterialization::GetVariantCompress(unique_ptr input, + const BaseStatistics &stats) { + LogicalType shredded_type; + unique_ptr typed_stats; + if (!CMHelper::GetVariantCompressInfo(stats, shredded_type, typed_stats)) { + return nullptr; + } + + // VARIANT comparison keys are wrapped by the binder. For fully shredded primitive variants, comparing the + // shredded value gives the same equality/order semantics without materializing comparator blobs. + if (input->GetExpressionClass() == ExpressionClass::BOUND_FUNCTION) { + auto &function_expr = input->Cast(); + if (IsVariantWrapperFunction(function_expr) && function_expr.GetChildren().size() == 1) { + input = std::move(function_expr.GetChildrenMutable()[0]); + } + } + + auto cast_expr = BoundCastExpression::AddCastToType(context, std::move(input), shredded_type); + + // Try to compress the shredded type further using the typed statistics + if (shredded_type.IsIntegral() && GetTypeIdSize(shredded_type.InternalType()) > 1) { + LogicalType offset_type; + Value range_value; + Value min; + if (CMHelper::GetIntegralOffsetCompressInfo(context, shredded_type, *typed_stats, offset_type, min, + range_value) && + GetTypeIdSize(offset_type.InternalType()) < GetTypeIdSize(shredded_type.InternalType())) { + // We always use the offset compress function (not a value-preserving cast) so that decompression + // can unambiguously derive how to restore the shredded type from the statistics alone + return CMHelper::CreateIntegralFunctionCompress(std::move(cast_expr), shredded_type, offset_type, min, + range_value, *typed_stats); + } + } else if (shredded_type.id() == LogicalTypeId::VARCHAR) { + LogicalType string_type = LogicalType::INVALID; + uint32_t max_string_length = 0; + if (CMHelper::GetStringCompressInfo(*typed_stats, string_type, max_string_length)) { + auto compress_stats = CMHelper::CreateStringCompressStats(*typed_stats, string_type, max_string_length); + return CMHelper::CreateStringFunctionCompress(std::move(cast_expr), string_type, std::move(compress_stats)); + } + } + + // Just the cast to the shredded type, this is still a lot cheaper to materialize than VARIANT. + // We mark it as FUNCTION (rather than CAST) so that decompression goes through GetVariantDecompress + return make_uniq(std::move(cast_expr), std::move(typed_stats), + CompressedMaterializationType::FUNCTION); +} + unique_ptr CompressedMaterialization::GetDecompressExpression(unique_ptr input, const LogicalType &result_type, const BaseStatistics &stats) { const auto &type = result_type; + if (type.id() == LogicalTypeId::VARIANT) { + return GetVariantDecompress(std::move(input), result_type, stats); + } + if (type.id() == LogicalTypeId::BLOB && stats.GetType().id() == LogicalTypeId::VARIANT) { + auto variant = GetVariantDecompress(std::move(input), LogicalType::VARIANT(), stats); + auto comparator_function = VariantComparatorFun::GetFunction(); + BoundScalarFunction bound_function(comparator_function); + bound_function.SetReturnType(LogicalType::BLOB); + vector> arguments; + arguments.push_back(std::move(variant)); + return make_uniq(std::move(bound_function), std::move(arguments), nullptr); + } + if (type.id() == LogicalTypeId::GEOMETRY) { + return GetGeometryDecompress(std::move(input), result_type, stats); + } if (TypeIsIntegral(type.InternalType())) { return GetIntegralDecompress(std::move(input), result_type, stats); } - if (type.id() == LogicalTypeId::VARCHAR) { + switch (type.id()) { + case LogicalTypeId::VARCHAR: return GetStringDecompress(std::move(input), result_type, stats); + default: + throw InternalException("Type other than integral/string/variant marked for decompression!"); } - throw InternalException("Type other than integral/string marked for decompression!"); +} + +unique_ptr CompressedMaterialization::GetGeometryDecompress(unique_ptr input, + const LogicalType &result_type, + const BaseStatistics &stats) { + D_ASSERT(result_type.id() == LogicalTypeId::GEOMETRY); + auto decompress_function = CMGeometryPointDecompressFun::GetFunction(); + vector> arguments; + arguments.emplace_back(std::move(input)); + + BoundScalarFunction bound_function(decompress_function); + bound_function.SetReturnType(result_type); + return make_uniq(std::move(bound_function), std::move(arguments), nullptr); } unique_ptr CompressedMaterialization::GetIntegralDecompress(unique_ptr input, @@ -744,4 +977,26 @@ unique_ptr CompressedMaterialization::GetStringDecompress(unique_ptr return make_uniq(std::move(bound_function), std::move(arguments), nullptr); } +unique_ptr CompressedMaterialization::GetVariantDecompress(unique_ptr input, + const LogicalType &result_type, + const BaseStatistics &stats) { + D_ASSERT(result_type.id() == LogicalTypeId::VARIANT); + LogicalType shredded_type; + unique_ptr typed_stats; + if (!CMHelper::GetVariantCompressInfo(stats, shredded_type, typed_stats)) { + throw InternalException("Could not obtain compress info for VARIANT decompression!"); + } + if (input->GetReturnType() != shredded_type) { + // The cast to the shredded type was compressed further, decompress to the shredded type first + if (shredded_type.IsIntegral()) { + input = GetIntegralDecompress(std::move(input), shredded_type, *typed_stats); + } else if (shredded_type.id() == LogicalTypeId::VARCHAR) { + input = GetStringDecompress(std::move(input), shredded_type, *typed_stats); + } else { // LCOV_EXCL_START + throw InternalException("Cannot decompress to the shredded type of a VARIANT!"); + } // LCOV_EXCL_STOP + } + return BoundCastExpression::AddCastToType(context, std::move(input), result_type); +} + } // namespace duckdb diff --git a/src/duckdb/src/optimizer/compressed_materialization/compress_aggregate.cpp b/src/duckdb/src/optimizer/compressed_materialization/compress_aggregate.cpp index 4de2f4af8..6671dbda8 100644 --- a/src/duckdb/src/optimizer/compressed_materialization/compress_aggregate.cpp +++ b/src/duckdb/src/optimizer/compressed_materialization/compress_aggregate.cpp @@ -35,6 +35,20 @@ void CompressedMaterialization::CompressAggregate(unique_ptr &o vector materialization_types(groups.size(), CompressedMaterializationType::INVALID); vector> stored_group_stats; stored_group_stats.resize(groups.size()); + auto try_compress_group = [&](idx_t group_idx, Expression &group_expr, optional_ptr stats) { + if (!stats) { + return false; + } + auto compress_expr = GetCompressExpression(group_expr.Copy(), *stats); + if (!compress_expr) { + return false; + } + materialization_types[group_idx] = compress_expr->materialization_type; + stored_group_stats[group_idx] = stats->ToUnique(); + groups[group_idx] = std::move(compress_expr->expression); + group_stats[group_idx] = std::move(compress_expr->stats); + return true; + }; for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { auto &group_expr = *groups[group_idx]; if (group_expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { @@ -47,18 +61,10 @@ void CompressedMaterialization::CompressAggregate(unique_ptr &o GetReferencedBindings(group_expr, referenced_bindings); // The non-colref expression won't be compressed generically, so try to compress it here - if (!group_stats[group_idx]) { - continue; // Can't compress without stats - } - - // Try to compress, if successful, replace the expression - auto compress_expr = GetCompressExpression(group_expr.Copy(), *group_stats[group_idx]); - if (compress_expr) { - materialization_types[group_idx] = compress_expr->materialization_type; - stored_group_stats[group_idx] = std::move(group_stats[group_idx]); - groups[group_idx] = std::move(compress_expr->expression); - group_stats[group_idx] = std::move(compress_expr->stats); + if (try_compress_group(group_idx, group_expr, GetVariantWrapperStats(group_expr))) { + continue; } + try_compress_group(group_idx, group_expr, group_stats[group_idx].get()); } // Anything referenced in the aggregate functions is also excluded diff --git a/src/duckdb/src/optimizer/compressed_materialization/compress_comparison_join.cpp b/src/duckdb/src/optimizer/compressed_materialization/compress_comparison_join.cpp index 9570bd716..41e69d8f9 100644 --- a/src/duckdb/src/optimizer/compressed_materialization/compress_comparison_join.cpp +++ b/src/duckdb/src/optimizer/compressed_materialization/compress_comparison_join.cpp @@ -1,6 +1,9 @@ #include "duckdb/optimizer/compressed_materialization.hpp" +#include "duckdb/optimizer/statistics_propagator.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/storage/statistics/variant_stats.hpp" namespace duckdb { @@ -19,6 +22,98 @@ static void PopulateBindingMap(CompressedMaterializationInfo &info, const vector } } +#ifndef DEBUG +static bool HasVariantType(LogicalOperator &op) { + for (const auto &type : op.types) { + if (type.id() == LogicalTypeId::VARIANT) { + return true; + } + } + return false; +} +#endif + +bool CompressedMaterialization::TryGetVariantJoinKeyInfo(const Expression &expr, VariantJoinKeyInfo &result) { + auto child = TryGetVariantWrapperColumnRef(expr); + if (!child) { + return false; + } + auto stats_it = statistics_map.find(child->Binding()); + if (stats_it == statistics_map.end() || !stats_it->second) { + return false; + } + + auto &variant_stats = *stats_it->second; + if (variant_stats.GetType().id() != LogicalTypeId::VARIANT || !VariantStats::IsShredded(variant_stats)) { + return false; + } + auto &shredded_stats = VariantStats::GetShreddedStats(variant_stats); + if (!VariantShreddedStats::IsFullyShredded(shredded_stats)) { + return false; + } + + auto shredded_type = VariantStats::GetShreddedStructuredType(variant_stats); + if (!shredded_type.IsIntegral()) { + return false; + } + auto &typed_stats = VariantStats::GetTypedStats(shredded_stats); + if (typed_stats.GetType() != shredded_type) { // LCOV_EXCL_START + return false; + } // LCOV_EXCL_STOP + + result.child = child.get(); + result.shredded_type = std::move(shredded_type); + result.typed_stats = typed_stats.ToUnique(); + if (variant_stats.CanHaveNull()) { + result.typed_stats->Set(StatsInfo::CAN_HAVE_NULL_VALUES); + } + return true; +} + +unique_ptr CompressedMaterialization::CastVariantJoinKeyStats(const VariantJoinKeyInfo &key_info, + const LogicalType &target_type) { + if (key_info.shredded_type == target_type) { + return key_info.typed_stats->ToUnique(); + } + return StatisticsPropagator::TryPropagateCast(*key_info.typed_stats, key_info.shredded_type, target_type); +} + +unique_ptr CompressedMaterialization::CreateVariantJoinKeyCast(unique_ptr input, + const LogicalType &shredded_type, + const LogicalType &target_type) { + input = BoundCastExpression::AddCastToType(context, std::move(input), shredded_type); + if (shredded_type == target_type) { + return input; + } + return BoundCastExpression::AddCastToType(context, std::move(input), target_type); +} + +bool CompressedMaterialization::TryCompressVariantComparisonJoinKey(JoinCondition &condition) { + VariantJoinKeyInfo left; + VariantJoinKeyInfo right; + if (!TryGetVariantJoinKeyInfo(condition.GetLHS(), left) || !TryGetVariantJoinKeyInfo(condition.GetRHS(), right)) { + return false; + } + + LogicalType target_type; + if (!LogicalType::TryGetMaxLogicalType(context, left.shredded_type, right.shredded_type, target_type) || + !target_type.IsIntegral()) { + return false; + } + + auto left_stats = CastVariantJoinKeyStats(left, target_type); + auto right_stats = CastVariantJoinKeyStats(right, target_type); + if (!left_stats || !right_stats) { + return false; + } + + condition.LeftReference() = CreateVariantJoinKeyCast(left.child->Copy(), left.shredded_type, target_type); + condition.RightReference() = CreateVariantJoinKeyCast(right.child->Copy(), right.shredded_type, target_type); + condition.SetLeftStats(std::move(left_stats)); + condition.SetRightStats(std::move(right_stats)); + return true; +} + void CompressedMaterialization::CompressComparisonJoin(unique_ptr &op) { auto &join = op->Cast(); if (join.join_type == JoinType::MARK) { @@ -32,19 +127,22 @@ void CompressedMaterialization::CompressComparisonJoin(unique_ptr(join_cardinality) / static_cast(build_cardinality); - if (ratio > JOIN_CARDINALITY_RATIO_THRESHOLD) { + // If any of the inputs has VARIANT, we skip these checks: compressing is assumed to always be better + if (!HasVariantType(left_child) && !HasVariantType(right_child)) { + const auto build_cardinality = right_child.has_estimated_cardinality ? right_child.estimated_cardinality + : right_child.EstimateCardinality(context); + if (build_cardinality < JOIN_BUILD_CARDINALITY_THRESHOLD) { return; } + + if (right_child.types.size() < JOIN_BUILD_COLUMN_COUNT_THRESHOLD) { + const auto join_cardinality = + join.has_estimated_cardinality ? join.estimated_cardinality : join.EstimateCardinality(context); + const double ratio = static_cast(join_cardinality) / static_cast(build_cardinality); + if (ratio > JOIN_CARDINALITY_RATIO_THRESHOLD) { + return; + } + } } #endif @@ -65,13 +163,18 @@ void CompressedMaterialization::CompressComparisonJoin(unique_ptr(); auto &rhs_colref = condition.GetRHS().Cast(); diff --git a/src/duckdb/src/optimizer/join_filter_pushdown_optimizer.cpp b/src/duckdb/src/optimizer/join_filter_pushdown_optimizer.cpp index 820fb35e7..ac069dd7c 100644 --- a/src/duckdb/src/optimizer/join_filter_pushdown_optimizer.cpp +++ b/src/duckdb/src/optimizer/join_filter_pushdown_optimizer.cpp @@ -8,6 +8,7 @@ #include "duckdb/planner/expression/bound_aggregate_expression.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/operator/logical_aggregate.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" #include "duckdb/planner/operator/logical_get.hpp" @@ -20,17 +21,36 @@ namespace duckdb { JoinFilterPushdownOptimizer::JoinFilterPushdownOptimizer(Optimizer &optimizer) : optimizer(optimizer) { } -bool JoinFilterPushdownUtil::PushdownJoinFilterExpression(const Expression &expr, JoinFilterPushdownColumn &filter) { - if (expr.GetReturnType().IsNested()) { +static bool IsJoinFilterPushdownIntegralType(const LogicalType &type) { + return type.IsIntegral() && GetTypeIdSize(type.InternalType()) <= GetTypeIdSize(PhysicalType::INT64); +} + +static bool IsJoinFilterPushdownIntegralCast(const LogicalType &src, const LogicalType &tgt) { + return IsJoinFilterPushdownIntegralType(src) && IsJoinFilterPushdownIntegralType(tgt); +} + +static bool IsJoinFilterPushdownVariantIntegralCast(const LogicalType &src, const LogicalType &tgt) { + if (src.id() == LogicalTypeId::VARIANT) { + return IsJoinFilterPushdownIntegralType(tgt); + } + if (tgt.id() == LogicalTypeId::VARIANT) { + return IsJoinFilterPushdownIntegralType(src); + } + return false; +} + +static bool PushdownJoinFilterExpressionInternal(const Expression &expr, JoinFilterPushdownColumn &filter) { + const auto &return_type = expr.GetReturnType(); + if (return_type.IsNested() && return_type.id() != LogicalTypeId::VARIANT) { // nested columns are not supported for pushdown return false; } - if (expr.GetReturnType().id() == LogicalTypeId::INTERVAL) { + if (return_type.id() == LogicalTypeId::INTERVAL) { // interval is not supported for pushdown return false; } if (filter.mode == JoinFilterPushdownMode::RECONSTRUCT_EXPRESSION && !filter.runtime_filter_type.IsValid()) { - filter.runtime_filter_type = expr.GetReturnType(); + filter.runtime_filter_type = return_type; } switch (expr.GetExpressionClass()) { case ExpressionClass::BOUND_COLUMN_REF: { @@ -40,20 +60,25 @@ bool JoinFilterPushdownUtil::PushdownJoinFilterExpression(const Expression &expr return true; } case ExpressionClass::BOUND_CAST: { - // We allow pushing through integral down/upcasts, as long as source/target are (u)bigint or smaller + // We allow pushing through integral casts and integral/VARIANT casts. const auto &bound_cast = expr.Cast(); const auto &src = bound_cast.Child().GetReturnType(); const auto &tgt = bound_cast.GetReturnType(); - if (!src.IsIntegral() || !tgt.IsIntegral()) { + const bool integral_cast = IsJoinFilterPushdownIntegralCast(src, tgt); + const bool variant_integral_cast = IsJoinFilterPushdownVariantIntegralCast(src, tgt); + if (!integral_cast && !variant_integral_cast) { return false; } - if (GetTypeIdSize(src.InternalType()) > GetTypeIdSize(PhysicalType::INT64) || - GetTypeIdSize(tgt.InternalType()) > GetTypeIdSize(PhysicalType::INT64)) { - return false; // Only do this for (u)bigint and smaller - } - if (!JoinFilterPushdownUtil::PushdownJoinFilterExpression(bound_cast.Child(), filter)) { + if (!PushdownJoinFilterExpressionInternal(bound_cast.Child(), filter)) { return false; } + if (variant_integral_cast) { + if (tgt.id() == LogicalTypeId::VARIANT) { + filter.mode = JoinFilterPushdownMode::STORAGE_ONLY; + filter.runtime_filter_type = LogicalType::INVALID; + } + return true; + } const bool widening_signed_cast = src.IsSigned() == tgt.IsSigned() && GetTypeIdSize(tgt.InternalType()) >= GetTypeIdSize(src.InternalType()); const bool widening_unsigned_to_signed_cast = @@ -66,11 +91,22 @@ bool JoinFilterPushdownUtil::PushdownJoinFilterExpression(const Expression &expr } return true; } + case ExpressionClass::BOUND_FUNCTION: { + auto &function_expr = expr.Cast(); + if (function_expr.Function().GetName() != "variant_normalize" || function_expr.GetChildren().size() != 1) { + return false; + } + return PushdownJoinFilterExpressionInternal(*function_expr.GetChildren()[0], filter); + } default: return false; } } +bool JoinFilterPushdownUtil::PushdownJoinFilterExpression(const Expression &expr, JoinFilterPushdownColumn &filter) { + return PushdownJoinFilterExpressionInternal(expr, filter); +} + bool JoinFilterPushdownUtil::JoinTypeIsSupported(JoinType join_type) { switch (join_type) { case JoinType::MARK: @@ -174,6 +210,11 @@ void JoinFilterPushdownOptimizer::GetPushdownFilterTargets(LogicalOperator &op, } } D_ASSERT(filter.storage_type != LogicalType::INVALID); + if (filter.storage_type.id() == LogicalTypeId::VARIANT && + filter.mode == JoinFilterPushdownMode::RECONSTRUCT_EXPRESSION && + filter.runtime_filter_type.id() == LogicalTypeId::VARIANT) { + return; + } } targets.emplace_back(get, std::move(columns)); break; diff --git a/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp b/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp index 675a4184f..a30501cdd 100644 --- a/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp +++ b/src/duckdb/src/optimizer/join_order/relation_statistics_helper.cpp @@ -88,7 +88,7 @@ static idx_t GetUnsignedMinMaxDistinctCount(const BaseStatistics &base_stats, id static idx_t GetBooleanMinMaxDistinctCount(const BaseStatistics &base_stats, idx_t base_table_cardinality) { auto min_value = NumericStats::Min(base_stats).GetValueUnsafe(); auto max_value = NumericStats::Max(base_stats).GetValueUnsafe(); - auto distinct_count = min_value == max_value ? 1 : 2; + idx_t distinct_count = min_value == max_value ? idx_t(1) : idx_t(2); return CapMinMaxDistinctCount(distinct_count, base_table_cardinality); } diff --git a/src/duckdb/src/optimizer/projection_pullup.cpp b/src/duckdb/src/optimizer/projection_pullup.cpp index fdea0831a..fd56b4c68 100644 --- a/src/duckdb/src/optimizer/projection_pullup.cpp +++ b/src/duckdb/src/optimizer/projection_pullup.cpp @@ -1,6 +1,7 @@ #include "duckdb/optimizer/projection_pullup.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" #include "duckdb/planner/expression/bound_columnref_expression.hpp" #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/optimizer/optimizer.hpp" @@ -221,6 +222,10 @@ void ProjectionPullup::CanPullThrough(column_binding_map_t &op) { + VisitOperator(op); +} + +void ProjectionPullup::VisitOperator(unique_ptr &op) { switch (op->type) { // These operators depend on column order. // If their immediate child is a projection, keep it and recurse into the projection’s child. @@ -255,15 +260,14 @@ void ProjectionPullup::Optimize(unique_ptr &op) { parents.push_back(*op); if (comp_join.join_type == JoinType::SEMI || comp_join.join_type == JoinType::ANTI) { // LHS: can pull through - Optimize(comp_join.children[0]); + VisitChildOfOperatorWithProjectionMap(comp_join.children[0], comp_join.left_projection_map); // RHS: Cannot pull through. Add a projection "barrier" InsertProjectionBelowOp(op, comp_join.children[1], false); } else { // All other joins: recurse normally on both sides - for (auto &child : op->children) { - Optimize(child); - } + VisitChildOfOperatorWithProjectionMap(comp_join.children[0], comp_join.left_projection_map); + VisitChildOfOperatorWithProjectionMap(comp_join.children[1], comp_join.right_projection_map); } PopParents(*op); @@ -274,7 +278,8 @@ void ProjectionPullup::Optimize(unique_ptr &op) { parents.push_back(*op); // Recurse - Optimize(op->children[0]); + auto &filter = op->Cast(); + VisitChildOfOperatorWithProjectionMap(op->children[0], filter.projection_map); PopParents(*op); return; @@ -347,9 +352,7 @@ void ProjectionPullup::Optimize(unique_ptr &op) { } // Create new optimizer for child (start fresh without any state) - for (auto &child : op->children) { - ProjectionPullup next(optimizer, root); - next.Optimize(child); - } + ProjectionPullup next(optimizer, root); + next.VisitOperatorChildren(*op); } } // namespace duckdb diff --git a/src/duckdb/src/parallel/async_result.cpp b/src/duckdb/src/parallel/async_result.cpp index a956fe75f..90c7899be 100644 --- a/src/duckdb/src/parallel/async_result.cpp +++ b/src/duckdb/src/parallel/async_result.cpp @@ -128,6 +128,9 @@ void AsyncResult::ScheduleTasks(InterruptState &interrupt_state, Executor &execu auto task = make_uniq(executor, std::move(async_task), interrupt_state, completion); TaskScheduler::GetScheduler(executor.context).ScheduleTask(executor.GetToken(), std::move(task), pool_type); } + + async_tasks.clear(); + result_type = AsyncResultType::INVALID; } void AsyncResult::ExecuteTasksSynchronously() { diff --git a/src/duckdb/src/parser/peg/transformer/peg_transformer.cpp b/src/duckdb/src/parser/peg/transformer/peg_transformer.cpp index b97ae11d8..44808f671 100644 --- a/src/duckdb/src/parser/peg/transformer/peg_transformer.cpp +++ b/src/duckdb/src/parser/peg/transformer/peg_transformer.cpp @@ -43,7 +43,9 @@ void PEGTransformer::SetParam(const Identifier &identifier, idx_t index, Prepare void PEGTransformer::ClearParameters() { prepared_statement_parameter_index = 0; + last_param_type = PreparedParamType::INVALID; named_parameter_map.clear(); + has_anonymous_parameters = false; } void PEGTransformer::Clear() { diff --git a/src/duckdb/src/parser/peg/transformer/peg_transformer_factory.cpp b/src/duckdb/src/parser/peg/transformer/peg_transformer_factory.cpp index 31dce43e1..e39cad1da 100644 --- a/src/duckdb/src/parser/peg/transformer/peg_transformer_factory.cpp +++ b/src/duckdb/src/parser/peg/transformer/peg_transformer_factory.cpp @@ -24,6 +24,7 @@ unique_ptr PEGTransformerFactory::TransformStatement(PEGTransforme // Avoid overriding a previous move with nothing result->named_param_map = transformer.named_parameter_map; } + result->has_anonymous_parameters = transformer.has_anonymous_parameters; return result; } @@ -110,7 +111,7 @@ unique_ptr PEGTransformerFactory::TransformTopLevelStatement(vecto ArenaAllocator transformer_allocator(Allocator::DefaultAllocator()); PEGTransformerState transformer_state(tokens); PEGTransformer transformer(transformer_allocator, transformer_state, sql_transform_functions, parser.rules, - enum_mappings, options); + options); return ExtractAndTransformStatement(transformer, tokens, stmt_opt.GetResult(), terminator_offset); } @@ -129,10 +130,6 @@ void PEGTransformerFactory::RegisterCommon() { REGISTER_TRANSFORM(TransformIntervalToIntervalAsType); } -void PEGTransformerFactory::RegisterCreateMacro() { - // create_macro.gram -} - void PEGTransformerFactory::RegisterCreateTable() { // create_table.gram REGISTER_TRANSFORM(TransformColLabelOrString); @@ -146,9 +143,6 @@ void PEGTransformerFactory::RegisterExpression() { REGISTER_TRANSFORM(TransformOverClause); } -void PEGTransformerFactory::RegisterConnect() { -} - void PEGTransformerFactory::RegisterPivot() { // PivotStatement and UnpivotStatement measure parameter usage while transforming // the source table, so their top-level wrappers remain manual. @@ -180,38 +174,16 @@ void PEGTransformerFactory::RegisterKeywordsAndIdentifiers() { Register("ExplainOptionName", &TransformIdentifierOrKeyword); } -void PEGTransformerFactory::RegisterEnums() { - RegisterEnum("MaterializedViewEntry", CatalogType::VIEW_ENTRY); - - RegisterEnum("MinusPrefixOperator", "-"); - RegisterEnum("PlusPrefixOperator", "+"); - RegisterEnum("TildePrefixOperator", "~"); - - RegisterEnum("ExcludeCurrentRow", WindowExcludeMode::CURRENT_ROW); - RegisterEnum("ExcludeGroup", WindowExcludeMode::GROUP); - RegisterEnum("ExcludeTies", WindowExcludeMode::TIES); - RegisterEnum("ExcludeNoOthers", WindowExcludeMode::NO_OTHER); - - RegisterEnum("SubqueryAny", true); - RegisterEnum("SubqueryAll", false); - - RegisterEnum("IncludeNulls", true); - RegisterEnum("ExcludeNulls", false); -} - PEGTransformerFactory::PEGTransformerFactory() { RegisterGenerated(); REGISTER_TRANSFORM(TransformStatement); RegisterComment(); RegisterCommon(); - RegisterCreateMacro(); RegisterCreateTable(); RegisterExpression(); - RegisterConnect(); RegisterPivot(); RegisterSelect(); RegisterKeywordsAndIdentifiers(); - RegisterEnums(); } vector> PEGTransformerFactory::ExtractParseResultsFromList(ParseResult &parse_result) { diff --git a/src/duckdb/src/parser/peg/transformer/transform_expression.cpp b/src/duckdb/src/parser/peg/transformer/transform_expression.cpp index e0c6a23c1..50bf4a901 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_expression.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_expression.cpp @@ -1197,7 +1197,7 @@ unique_ptr PEGTransformerFactory::TransformExponentiationExpre BinaryExpressionTail PEGTransformerFactory::TransformBitwiseExpressionTail(PEGTransformer &transformer, const string &bit_operator, unique_ptr additive_expression) { - return {bit_operator, std::move(additive_expression)}; + return {bit_operator, std::move(additive_expression), optional_idx()}; } BinaryExpressionTail @@ -1210,12 +1210,12 @@ PEGTransformerFactory::TransformAdditiveExpressionTail(PEGTransformer &transform BinaryExpressionTail PEGTransformerFactory::TransformMultiplicativeExpressionTail(PEGTransformer &transformer, const string &factor, unique_ptr exponentiation_expression) { - return {factor, std::move(exponentiation_expression)}; + return {factor, std::move(exponentiation_expression), optional_idx()}; } BinaryExpressionTail PEGTransformerFactory::TransformExponentiationExpressionTail( PEGTransformer &transformer, const string &exponent_operator, unique_ptr collate_expression) { - return {exponent_operator, std::move(collate_expression)}; + return {exponent_operator, std::move(collate_expression), optional_idx()}; } unique_ptr PEGTransformerFactory::TransformCollateExpression( @@ -1355,6 +1355,7 @@ unique_ptr PEGTransformerFactory::TransformAnonymousParameter( // Register it transformer.SetParam(Identifier(identifier), known_param_index, PreparedParamType::AUTO_INCREMENT); transformer.SetParamCount(MaxValue(transformer.ParamCount(), known_param_index)); + transformer.has_anonymous_parameters = true; expr->IdentifierMutable() = Identifier(identifier); return std::move(expr); @@ -1411,6 +1412,7 @@ PEGTransformerFactory::TransformNumberedParameter(PEGTransformer &transformer, expr->IdentifierMutable() = Identifier(identifier); transformer.SetParamCount(MaxValue(transformer.ParamCount(), known_param_index)); + transformer.has_anonymous_parameters = true; return std::move(expr); } diff --git a/src/duckdb/src/parser/peg/transformer/transform_generated.cpp b/src/duckdb/src/parser/peg/transformer/transform_generated.cpp index ddd7bc131..2c2ba3a58 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_generated.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_generated.cpp @@ -8720,6 +8720,17 @@ unique_ptr PEGTransformerFactory::TransformRegularJoinClau return make_uniq>>(std::move(result)); } +unique_ptr PEGTransformerFactory::TransformJoinByClauseInternal(PEGTransformer &transformer, + ParseResult &parse_result) { + auto &list_pr = parse_result.Cast(); + auto col_label = + transformer.Transform(ExtractResultFromParens(list_pr.GetChild(2)).Cast().GetChild(1)); + auto table_ref = transformer.Transform>(list_pr.GetChild(3)); + auto join_qualifier = transformer.Transform(list_pr.GetChild(4)); + auto result = TransformJoinByClause(transformer, col_label, std::move(table_ref), std::move(join_qualifier)); + return make_uniq>>(std::move(result)); +} + unique_ptr PEGTransformerFactory::TransformAsofInternal(PEGTransformer &transformer, ParseResult &parse_result) { auto result = TransformAsof(transformer); @@ -10801,6 +10812,7 @@ void PEGTransformerFactory::RegisterGenerated() { {"TimestampAtUnit", &PEGTransformerFactory::TransformTimestampAtUnitInternal}, {"JoinClause", &PEGTransformerFactory::TransformJoinClauseInternal}, {"RegularJoinClause", &PEGTransformerFactory::TransformRegularJoinClauseInternal}, + {"JoinByClause", &PEGTransformerFactory::TransformJoinByClauseInternal}, {"Asof", &PEGTransformerFactory::TransformAsofInternal}, {"JoinWithoutOnClause", &PEGTransformerFactory::TransformJoinWithoutOnClauseInternal}, {"JoinQualifier", &PEGTransformerFactory::TransformJoinQualifierInternal}, diff --git a/src/duckdb/src/parser/peg/transformer/transform_select.cpp b/src/duckdb/src/parser/peg/transformer/transform_select.cpp index f49eceea2..7bac4239e 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_select.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_select.cpp @@ -1,3 +1,5 @@ +#include "duckdb/common/enum_util.hpp" +#include "duckdb/common/string_util.hpp" #include "duckdb/parser/peg/ast/distinct_clause.hpp" #include "duckdb/parser/peg/ast/join_prefix.hpp" #include "duckdb/parser/peg/ast/join_qualifier.hpp" @@ -523,6 +525,31 @@ unique_ptr PEGTransformerFactory::TransformRegularJoinClause(PEGTransf return std::move(result); } +unique_ptr PEGTransformerFactory::TransformJoinByClause(PEGTransformer &transformer, const string &col_label, + unique_ptr table_ref, + JoinQualifier join_qualifier) { + auto result = make_uniq(); + // resolve the join type name against the JoinType enum (case-insensitive); accept an optional `_join` suffix, + // so e.g. `mark` and `mark_join` are equivalent. EnumUtil::FromString throws on an unknown name. + auto type_name = col_label; + if (StringUtil::EndsWith(StringUtil::Lower(type_name), "_join")) { + type_name = type_name.substr(0, type_name.size() - 5); + } + result->type = EnumUtil::FromString(type_name); + if (result->type == JoinType::INVALID) { + throw ParserException("\"%s\" is not a valid join type for JOIN BY", col_label); + } + result->right = std::move(table_ref); + if (join_qualifier.on_clause) { + result->condition = std::move(join_qualifier.on_clause); + } else if (!join_qualifier.using_columns.empty()) { + result->using_columns = std::move(join_qualifier.using_columns); + } else { + throw InternalException("Invalid join qualifier found."); + } + return std::move(result); +} + bool PEGTransformerFactory::TransformAsof(PEGTransformer &transformer) { return true; } diff --git a/src/duckdb/src/parser/tableref/joinref.cpp b/src/duckdb/src/parser/tableref/joinref.cpp index 233bbc9e5..02a32dfd3 100644 --- a/src/duckdb/src/parser/tableref/joinref.cpp +++ b/src/duckdb/src/parser/tableref/joinref.cpp @@ -5,6 +5,22 @@ namespace duckdb { +//! True if the join type is expressible with standard SQL JOIN syntax (e.g. INNER JOIN, LEFT JOIN). +//! Internal join types (MARK, SINGLE, ...) instead use the JOIN BY (TYPE ) form. +static bool JoinTypeUsesStandardSyntax(JoinType type) { + switch (type) { + case JoinType::INNER: + case JoinType::LEFT: + case JoinType::RIGHT: + case JoinType::OUTER: + case JoinType::SEMI: + case JoinType::ANTI: + return true; + default: + return false; + } +} + string JoinRef::ToString() const { string result; if (!is_implicit) { @@ -13,7 +29,11 @@ string JoinRef::ToString() const { result += left->ToString() + " "; switch (ref_type) { case JoinRefType::REGULAR: - result += EnumUtil::ToString(type) + " JOIN "; + if (JoinTypeUsesStandardSyntax(type)) { + result += EnumUtil::ToString(type) + " JOIN "; + } else { + result += "JOIN BY (TYPE " + EnumUtil::ToString(type) + ") "; + } break; case JoinRefType::NATURAL: result += "NATURAL "; diff --git a/src/duckdb/src/parser/tableref/pivotref.cpp b/src/duckdb/src/parser/tableref/pivotref.cpp index 2174ed719..2617b9b75 100644 --- a/src/duckdb/src/parser/tableref/pivotref.cpp +++ b/src/duckdb/src/parser/tableref/pivotref.cpp @@ -371,7 +371,7 @@ string PivotRef::ToString() const { if (i > 0) { result += ", "; } - result += groups[i]; + result += SQLIdentifier(groups[i].GetIdentifierName()); } } result += ")"; diff --git a/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp index 47128cd9a..55c92b7d0 100644 --- a/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp @@ -323,10 +323,11 @@ BindResult BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFu error.Throw(); } + // attach the ORDER BY before the state export: an ordered aggregate's exported type depends on the ORDER BY keys + aggregate->GetOrderBysMutable() = std::move(order_bys); if (aggr.ExportState()) { aggregate = ExportAggregateFunction::Bind(std::move(aggregate)); } - aggregate->GetOrderBysMutable() = std::move(order_bys); // check for all the aggregates if this aggregate already exists ProjectionIndex aggr_index; diff --git a/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp index 9ee9ba320..01ba2a795 100644 --- a/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_parameter_expression.cpp @@ -4,32 +4,46 @@ #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_parameter_expression.hpp" #include "duckdb/planner/expression_binder.hpp" +#include "duckdb/main/client_config.hpp" +#include "duckdb/main/prepared_statement.hpp" namespace duckdb { BindResult ExpressionBinder::BindExpression(ParameterExpression &expr, idx_t depth) { auto parameters = binder.GetParameters(); - if (!parameters) { - throw BinderException("Unexpected prepared parameter. This type of statement can't be prepared!"); - } auto parameter_id = expr.Identifier(); - // Check if a parameter value has already been supplied - auto ¶meter_data = parameters->GetParameterData(); - auto param_data_it = parameter_data.find(parameter_id); - if (param_data_it != parameter_data.end()) { - // it has! emit a constant directly - auto &data = param_data_it->second; - auto return_type = parameters->GetReturnType(parameter_id); - bool is_literal = - return_type.id() == LogicalTypeId::INTEGER_LITERAL || return_type.id() == LogicalTypeId::STRING_LITERAL; - auto constant = make_uniq(data.GetValue()); - constant->SetAlias(expr.GetAlias()); - if (is_literal) { + if (parameters) { + // Check if a parameter value has already been supplied (named params take precedence) + auto ¶meter_data = parameters->GetParameterData(); + auto param_data_it = parameter_data.find(parameter_id); + if (param_data_it != parameter_data.end()) { + // it has! emit a constant directly + auto &data = param_data_it->second; + auto return_type = parameters->GetReturnType(parameter_id); + bool is_literal = + return_type.id() == LogicalTypeId::INTEGER_LITERAL || return_type.id() == LogicalTypeId::STRING_LITERAL; + auto constant = make_uniq(data.GetValue()); + constant->SetAlias(expr.GetAlias()); + if (is_literal) { + return BindResult(std::move(constant)); + } + auto cast = BoundCastExpression::AddCastToType(context, std::move(constant), return_type); + return BindResult(std::move(cast)); + } + } + // No explicit parameter value supplied; fall back to a user variable of the same name. + if (binder.GetBindingMode() != BindingMode::PREPARE && + PreparedStatement::AllowsUserVariableFallback(parameter_id)) { + Value variable_value; + if (ClientConfig::GetConfig(context).GetUserVariable(parameter_id, variable_value)) { + auto constant = make_uniq(variable_value); + constant->SetAlias(expr.GetAlias()); return BindResult(std::move(constant)); } - auto cast = BoundCastExpression::AddCastToType(context, std::move(constant), return_type); - return BindResult(std::move(cast)); + } + if (!parameters) { + throw BinderException("Unexpected prepared parameter. This type of statement can't be prepared!"); } auto bound_parameter = parameters->BindParameterExpression(expr); diff --git a/src/duckdb/src/planner/binder/statement/bind_call.cpp b/src/duckdb/src/planner/binder/statement/bind_call.cpp index a746e1689..6a7d07b0b 100644 --- a/src/duckdb/src/planner/binder/statement/bind_call.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_call.cpp @@ -3,9 +3,23 @@ #include "duckdb/planner/binder.hpp" #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/expression/star_expression.hpp" +#include "duckdb/planner/operator/logical_get.hpp" namespace duckdb { +static optional_ptr FindTableFunctionGet(LogicalOperator &op) { + if (op.type == LogicalOperatorType::LOGICAL_GET) { + return op.Cast(); + } + for (auto &child : op.children) { + auto get = FindTableFunctionGet(*child); + if (get) { + return get; + } + } + return nullptr; +} + BoundStatement Binder::Bind(CallStatement &stmt) { SelectStatement select_statement; auto select_node = make_uniq(); @@ -18,6 +32,13 @@ BoundStatement Binder::Bind(CallStatement &stmt) { auto result = Bind(select_statement); auto &properties = GetStatementProperties(); properties.output_type = QueryResultOutputType::FORCE_MATERIALIZED; + // use the return type of the table function (if any) instead of the default query result + if (result.plan) { + auto get = FindTableFunctionGet(*result.plan); + if (get) { + properties.return_type = get->function.call_return_type; + } + } return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_execute.cpp b/src/duckdb/src/planner/binder/statement/bind_execute.cpp index 95001cf84..a8e0c3ff4 100644 --- a/src/duckdb/src/planner/binder/statement/bind_execute.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_execute.cpp @@ -3,11 +3,10 @@ #include "duckdb/planner/planner.hpp" #include "duckdb/planner/operator/logical_execute.hpp" #include "duckdb/planner/expression_binder/constant_binder.hpp" -#include "duckdb/main/client_config.hpp" #include "duckdb/main/client_data.hpp" #include "duckdb/execution/expression_executor.hpp" -#include "duckdb/common/case_insensitive_map.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/parser/expression/constant_expression.hpp" namespace duckdb { @@ -27,7 +26,7 @@ BoundStatement Binder::Bind(ExecuteStatement &stmt) { auto prepared = entry->second; auto &named_param_map = prepared->unbound_statement->named_param_map; - PreparedStatement::VerifyParameters(stmt.named_values, named_param_map); + PreparedStatement::VerifyParameters(stmt.named_values, named_param_map, &context); auto &mapped_named_values = stmt.named_values; // bind any supplied parameters @@ -59,6 +58,7 @@ BoundStatement Binder::Bind(ExecuteStatement &stmt) { } bind_values[pair.first] = std::move(parameter_data); } + prepared->PopulateMissingParameterValues(context, bind_values); unique_ptr rebound_plan; RebindQueryInfo rebind = RebindQueryInfo::DO_NOT_REBIND; @@ -90,7 +90,7 @@ BoundStatement Binder::Bind(ExecuteStatement &stmt) { result.names = prepared->names; result.types = prepared->types; - prepared->Bind(std::move(bind_values)); + prepared->Bind(context, bind_values); if (rebound_plan) { auto execute_plan = make_uniq(std::move(prepared)); execute_plan->children.push_back(std::move(rebound_plan)); diff --git a/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp b/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp index 29d66f8d2..26e9f2028 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_joinref.cpp @@ -379,7 +379,7 @@ BoundStatement Binder::Bind(JoinRef &ref) { if (result->type == JoinType::MARK) { auto mark_join_idx = GenerateTableIndex(); string mark_join_alias = "__internal_mark_join_ref" + to_string(mark_join_idx.index); - bind_context.AddGenericBinding(mark_join_idx, Identifier(mark_join_alias), {"__mark_index_column"}, + bind_context.AddGenericBinding(mark_join_idx, Identifier(mark_join_alias), {"__mark_join_marker"}, {LogicalType::BOOLEAN}); result->mark_index = mark_join_idx; } diff --git a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp index ec7b06f87..994fb5737 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp @@ -742,6 +742,24 @@ struct UnpivotEntry { void Binder::ExtractUnpivotEntries(Binder &child_binder, PivotColumnEntry &entry, vector &unpivot_entries) { + // Try to bind the entry expression as values - but only when it is composed purely of column + // references and constants (a column list). Entries that contain other expressions are + // unpivoted as expressions instead of being folded into column names. + if (entry.expr) { + vector column_list; + if (TryExtractUnpivotList(*entry.expr, column_list)) { + try { + auto expr_copy = entry.expr->Copy(); + BindPivotInList(expr_copy, entry.values, child_binder); + // successfully bound as values - clear the expression + entry.expr = nullptr; + } catch (...) { + // ignore binder exceptions here - we fall back to expression mode + entry.values.clear(); + } + } + } + if (!entry.expr) { // pivot entry without an expression - generate one UnpivotEntry unpivot_entry; diff --git a/src/duckdb/src/planner/collation_binding.cpp b/src/duckdb/src/planner/collation_binding.cpp index cedc1764d..c979c9ffe 100644 --- a/src/duckdb/src/planner/collation_binding.cpp +++ b/src/duckdb/src/planner/collation_binding.cpp @@ -2,6 +2,8 @@ #include "duckdb/catalog/catalog_entry/collate_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" +#include "duckdb/planner/expression/bound_lambda_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/main/config.hpp" #include "duckdb/main/settings.hpp" #include "duckdb/catalog/catalog.hpp" @@ -10,11 +12,12 @@ namespace duckdb { constexpr const char *CollateCatalogEntry::Name; -bool PushVarcharCollation(ClientContext &context, unique_ptr &source, const LogicalType &sql_type, - CollationType type) { +static vector GetVarcharCollationFunctions(ClientContext &context, const LogicalType &sql_type, + CollationType type) { + vector result; if (sql_type.id() != LogicalTypeId::VARCHAR) { // only VARCHAR columns require collation - return false; + return result; } // replace default collation with system collation auto str_collation = StringType::GetCollation(sql_type); @@ -28,7 +31,7 @@ bool PushVarcharCollation(ClientContext &context, unique_ptr &source // bind the collation if (collation.empty() || collation == "binary" || collation == "c" || collation == "posix") { // no collation or binary collation: skip - return false; + return result; } auto &catalog = Catalog::GetSystemCatalog(context); auto splits = StringUtil::Split(StringUtil::Lower(collation), "."); @@ -55,128 +58,131 @@ bool PushVarcharCollation(ClientContext &context, unique_ptr &source for (auto &entry : entries) { auto &collation_entry = entry.get(); if (!collation_entry.combinable && type == CollationType::COMBINABLE_COLLATIONS) { - // not a combinable collation - ignore - return false; + // not a combinable collation - only apply the (preceding) combinable collations + break; } - vector> children; - children.push_back(std::move(source)); - - FunctionBinder function_binder(context); - auto function = function_binder.BindScalarFunction(collation_entry.function, std::move(children)); - source = std::move(function); + result.push_back(collation_entry.function.GetName().GetIdentifierName()); } - return true; + return result; } -bool PushTimeTZCollation(ClientContext &context, unique_ptr &source, const LogicalType &sql_type, - CollationType) { +static vector GetTimeTZCollationFunctions(ClientContext &, const LogicalType &sql_type, CollationType) { if (sql_type.id() != LogicalTypeId::TIME_TZ) { - return false; + return vector(); } + return {"timetz_byte_comparable"}; +} - auto &catalog = Catalog::GetSystemCatalog(context); - auto &function_entry = - catalog.GetEntry(context, Identifier::DefaultSchema(), "timetz_byte_comparable"); - if (function_entry.functions.Size() != 1) { - throw InternalException("timetz_byte_comparable should only have a single overload"); +static vector GetBitStringCollationFunctions(ClientContext &, const LogicalType &sql_type, CollationType) { + if (sql_type.id() != LogicalTypeId::BIT) { + return vector(); } - const auto &scalar_function = function_entry.functions.GetFunctionByOffset(0); - vector> children; - children.push_back(std::move(source)); - - FunctionBinder function_binder(context); - auto function = function_binder.BindScalarFunction(scalar_function, std::move(children)); - source = std::move(function); - return true; + return {"bitstring_byte_comparable"}; } -bool PushBitStringCollation(ClientContext &context, unique_ptr &source, const LogicalType &sql_type, - CollationType) { - if (sql_type.id() != LogicalTypeId::BIT) { - return false; +static vector GetIntervalCollationFunctions(ClientContext &, const LogicalType &sql_type, CollationType) { + if (sql_type.id() != LogicalTypeId::INTERVAL) { + return vector(); } + return {"normalized_interval"}; +} - auto &catalog = Catalog::GetSystemCatalog(context); - auto &function_entry = - catalog.GetEntry(context, Identifier::DefaultSchema(), "bitstring_byte_comparable"); - if (function_entry.functions.Size() != 1) { - throw InternalException("bitstring_byte_comparable should only have a single overload"); +static vector GetVariantCollationFunctions(ClientContext &, const LogicalType &sql_type, CollationType) { + if (sql_type.id() != LogicalTypeId::VARIANT) { + return vector(); } - const auto &scalar_function = function_entry.functions.GetFunctionByOffset(0); - vector> children; - children.push_back(std::move(source)); + return {"variant_comparator"}; +} - FunctionBinder function_binder(context); - auto function = function_binder.BindScalarFunction(scalar_function, std::move(children)); - source = std::move(function); - return true; +CollationBinding::CollationBinding() { + RegisterCollation(CollationCallback(GetVarcharCollationFunctions)); + RegisterCollation(CollationCallback(GetTimeTZCollationFunctions)); + RegisterCollation(CollationCallback(GetBitStringCollationFunctions)); + RegisterCollation(CollationCallback(GetIntervalCollationFunctions)); + RegisterCollation(CollationCallback(GetVariantCollationFunctions)); } -bool PushIntervalCollation(ClientContext &context, unique_ptr &source, const LogicalType &sql_type, - CollationType) { - if (sql_type.id() != LogicalTypeId::INTERVAL) { - return false; - } +void CollationBinding::RegisterCollation(CollationCallback callback) { + collations.push_back(callback); +} +//! Binds the scalar function with the given name (looked up from the system catalog) around "source". +static unique_ptr ApplyCollationFunction(ClientContext &context, const string &function_name, + unique_ptr source) { auto &catalog = Catalog::GetSystemCatalog(context); auto &function_entry = - catalog.GetEntry(context, Identifier::DefaultSchema(), "normalized_interval"); - if (function_entry.functions.Size() != 1) { - throw InternalException("normalized_interval should only have a single overload"); - } - const auto &scalar_function = function_entry.functions.GetFunctionByOffset(0); + catalog.GetEntry(context, Identifier::DefaultSchema(), Identifier(function_name)); + auto source_alias = source->GetAlias(); vector> children; children.push_back(std::move(source)); FunctionBinder function_binder(context); - auto function = function_binder.BindScalarFunction(scalar_function, std::move(children)); - source = std::move(function); - return true; + ErrorData error; + auto function = function_binder.BindScalarFunction(function_entry, std::move(children), error); + if (!function) { + error.Throw(); + } + function->SetAlias(source_alias); + return function; } -bool PushVariantCollation(ClientContext &context, unique_ptr &source, const LogicalType &sql_type, - CollationType) { - if (sql_type.id() != LogicalTypeId::VARIANT) { +//! Pushes a collation into a LIST/ARRAY type by wrapping the source in a list_transform that applies the collation to +//! every element via a lambda. Returns false if the elements do not require collation, or if list_transform is not +//! available (i.e. the core_functions extension is not loaded). +static bool PushNestedCollation(ClientContext &context, unique_ptr &source, const LogicalType &child_type, + CollationType type, const CollationBinding &binding) { + // build a lambda body that applies the collation to the lambda parameter (a reference to the child element) + unique_ptr lambda_body = make_uniq(child_type, idx_t(0)); + if (!binding.PushCollation(context, lambda_body, child_type, type)) { + // the child type does not require collation - nothing to push return false; } + auto &catalog = Catalog::GetSystemCatalog(context); - auto &function_entry = - catalog.GetEntry(context, Identifier::DefaultSchema(), "variant_comparator"); - if (function_entry.functions.Size() != 1) { - throw InternalException("variant_comparator should only have a single overload"); + auto list_transform = catalog.GetEntry(context, Identifier::DefaultSchema(), + "list_transform", OnEntryNotFound::RETURN_NULL); + if (!list_transform) { + // list_transform is not available - cannot push the collation into the list + return false; } - auto source_alias = source->GetAlias(); - const auto &scalar_function = function_entry.functions.GetFunctionByOffset(0); + + auto bound_lambda = + make_uniq(ExpressionType::LAMBDA, LogicalType::LAMBDA, std::move(lambda_body), 1); vector> children; children.push_back(std::move(source)); + children.push_back(std::move(bound_lambda)); FunctionBinder function_binder(context); - auto function = function_binder.BindScalarFunction(scalar_function, std::move(children)); - function->SetAlias(source_alias); + ErrorData error; + auto function = function_binder.BindScalarFunction(*list_transform, std::move(children), error); + if (!function) { + error.Throw(); + } + // the lambda expression is consumed by the bind - remove it from the children + auto &bound_function = function->Cast(); + bound_function.GetChildrenMutable().erase_at(1); source = std::move(function); return true; } -// timetz_byte_comparable -CollationBinding::CollationBinding() { - RegisterCollation(CollationCallback(PushVarcharCollation)); - RegisterCollation(CollationCallback(PushTimeTZCollation)); - RegisterCollation(CollationCallback(PushBitStringCollation)); - RegisterCollation(CollationCallback(PushIntervalCollation)); - RegisterCollation(CollationCallback(PushVariantCollation)); -} - -void CollationBinding::RegisterCollation(CollationCallback callback) { - collations.push_back(callback); -} - bool CollationBinding::PushCollation(ClientContext &context, unique_ptr &source, const LogicalType &sql_type, CollationType type) const { + if (sql_type.id() == LogicalTypeId::LIST) { + return PushNestedCollation(context, source, ListType::GetChildType(sql_type), type, *this); + } + if (sql_type.id() == LogicalTypeId::ARRAY) { + return PushNestedCollation(context, source, ArrayType::GetChildType(sql_type), type, *this); + } for (auto &collation : collations) { - if (collation.try_push_collation(context, source, sql_type, type)) { - // successfully pushed a collation - return true; + auto functions = collation.get_collation_functions(context, sql_type, type); + if (functions.empty()) { + continue; + } + // successfully retrieved the collation functions - apply them to the source expression + for (auto &function_name : functions) { + source = ApplyCollationFunction(context, function_name, std::move(source)); } + return true; } return false; } diff --git a/src/duckdb/src/planner/filter/expression_filter.cpp b/src/duckdb/src/planner/filter/expression_filter.cpp index 5ceafeed3..e6a89658e 100644 --- a/src/duckdb/src/planner/filter/expression_filter.cpp +++ b/src/duckdb/src/planner/filter/expression_filter.cpp @@ -24,6 +24,7 @@ #include "duckdb/storage/statistics/numeric_stats.hpp" #include "duckdb/storage/statistics/struct_stats.hpp" #include "duckdb/storage/statistics/string_stats.hpp" +#include "duckdb/storage/statistics/variant_stats.hpp" #include "duckdb/function/scalar/struct_utils.hpp" namespace duckdb { @@ -275,6 +276,66 @@ static optional_ptr TryGetFilterStats(optional_ptr +TryPrepareVariantComparisonStats(const BaseStatistics &stats, Value &constant, + vector> &owned_stats) { + if (stats.GetType().id() != LogicalTypeId::VARIANT) { + return constant.type().id() == LogicalTypeId::VARIANT ? nullptr : &stats; + } + if (!VariantStats::IsShredded(stats)) { + return nullptr; + } + auto &shredded_stats = VariantStats::GetShreddedStats(stats); + if (!VariantShreddedStats::IsFullyShredded(shredded_stats)) { + return nullptr; + } + auto &typed_stats = VariantStats::GetTypedStats(shredded_stats); + auto &typed_type = typed_stats.GetType(); + if (typed_type.IsNested() || typed_type.id() == LogicalTypeId::VARIANT) { + return nullptr; + } + + if (constant.type().id() == LogicalTypeId::VARIANT) { + constant = VariantValue::GetValue(constant); + } + if (constant.IsNull()) { + return nullptr; + } + + LogicalType comparison_type; + if (!TryGetVariantComparisonStatsType(typed_type, constant.type(), comparison_type)) { + return nullptr; + } + if (!constant.DefaultTryCastAs(comparison_type)) { + return nullptr; + } + if (typed_type == comparison_type) { + return &typed_stats; + } + + auto cast_stats = StatisticsPropagator::TryPropagateCast(typed_stats, typed_type, comparison_type); + if (!cast_stats) { + return nullptr; + } + owned_stats.push_back(std::move(cast_stats)); + return owned_stats.back().get(); +} + static FilterPropagateResult CheckComparisonStatistics(optional_ptr context_p, const BoundFunctionExpression &comp_expr, const BaseStatistics &stats) { @@ -301,11 +362,24 @@ static FilterPropagateResult CheckComparisonStatistics(optional_ptrGetType().id() == LogicalTypeId::VARIANT || + comparison_constant.type().id() == LogicalTypeId::VARIANT; + if (variant_comparison) { + filter_stats = TryPrepareVariantComparisonStats(*filter_stats, comparison_constant, owned_stats); + if (!filter_stats) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } + } if (!filter_stats->CanHaveNoNull()) { + if (variant_comparison) { + return FilterPropagateResult::NO_PRUNING_POSSIBLE; + } return comparison_type == ExpressionType::COMPARE_DISTINCT_FROM ? FilterPropagateResult::FILTER_ALWAYS_TRUE : FilterPropagateResult::FILTER_ALWAYS_FALSE; } - auto result = CheckZonemapAgainstConstants(*filter_stats, comparison_type, array_ptr(&constant, 1)); + auto result = + CheckZonemapAgainstConstants(*filter_stats, comparison_type, array_ptr(&comparison_constant, 1)); if (filter_stats->CanHaveNull()) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } diff --git a/src/duckdb/src/planner/planner.cpp b/src/duckdb/src/planner/planner.cpp index e57f97d0f..427b80794 100644 --- a/src/duckdb/src/planner/planner.cpp +++ b/src/duckdb/src/planner/planner.cpp @@ -153,6 +153,9 @@ void Planner::CreatePlan(SQLStatement &statement) { shared_ptr Planner::PrepareSQLStatement(unique_ptr statement) { auto copied_statement = statement->Copy(); // create a plan of the underlying statement + // set PREPARE binding mode so that $params without supplied values create placeholder slots + // instead of falling back to user variables (user variables serve as defaults at EXECUTE time) + binder->SetBindingMode(BindingMode::PREPARE); CreatePlan(std::move(statement)); // now create the logical prepare auto prepared_data = make_shared_ptr(copied_statement->type); diff --git a/src/duckdb/src/storage/compression/rle.cpp b/src/duckdb/src/storage/compression/rle.cpp index fbaa93e2f..7245dcaca 100644 --- a/src/duckdb/src/storage/compression/rle.cpp +++ b/src/duckdb/src/storage/compression/rle.cpp @@ -245,53 +245,74 @@ void RLEFinalizeCompress(CompressionState &state_p) { //===--------------------------------------------------------------------===// template struct RLEScanState : public SegmentScanState { - explicit RLEScanState(ColumnSegment &segment) { - auto &buffer_manager = BufferManager::GetBufferManager(segment.GetDatabase()); - handle = buffer_manager.Pin(segment.GetBlockHandle()); - entry_pos = 0; - position_in_entry = 0; - rle_count_offset = - UnsafeNumericCast(Load(handle.GetDataMutable() + segment.GetBlockOffset())); - D_ASSERT(rle_count_offset <= segment.GetBlockSize()); + explicit RLEScanState(ColumnSegment &segment) + : handle(BufferManager::GetBufferManager(segment.GetDatabase()).Pin(segment.GetBlockHandle())), entry_pos(0), + position_in_entry(0), + rle_count_offset(UnsafeNumericCast(Load(handle.Ptr() + segment.GetBlockOffset()))), + data_pointer( + reinterpret_cast(handle.Ptr() + segment.GetBlockOffset() + RLEConstants::RLE_HEADER_SIZE)), + index_pointer( + reinterpret_cast(handle.Ptr() + segment.GetBlockOffset() + rle_count_offset)), + max_entry_pos(static_cast(reinterpret_cast(handle.Ptr() + segment.GetBlockSize()) - + reinterpret_cast(index_pointer)) / + static_cast(sizeof(rle_count_t))) { + if (rle_count_offset < RLEConstants::RLE_HEADER_SIZE) { + //! This would make the index_pointer point into a region reserved for the header data + throw IOException("Corrupted RLE segment: rle_count_offset is corrupted"); + } + if (segment.GetBlockOffset() + rle_count_offset > segment.GetBlockSize()) { + //! This would make the index_pointer start outside of the segment + throw IOException("Corrupted RLE segment: rle_count_offset is corrupted"); + } + if ((rle_count_offset - RLEConstants::RLE_HEADER_SIZE) / sizeof(T) > max_entry_pos) { + //! This would make the indexing of the index_pointer[entry_pos] reach outside of the segment + throw IOException("Corrupted RLE segment: rle_count_offset is corrupted"); + } } - inline void SkipInternal(rle_count_t *index_pointer, idx_t skip_count) { + inline void SkipInternal(idx_t skip_count) { while (skip_count > 0) { rle_count_t run_end = index_pointer[entry_pos]; idx_t skip_amount = MinValue(skip_count, run_end - position_in_entry); skip_count -= skip_amount; position_in_entry += skip_amount; - if (ExhaustedRun(index_pointer)) { + if (ExhaustedRun()) { ForwardToNextRun(); } } } void Skip(ColumnSegment &segment, idx_t skip_count) { - auto data = handle.GetDataMutable() + segment.GetBlockOffset(); - auto index_pointer = reinterpret_cast(data + rle_count_offset); - SkipInternal(index_pointer, skip_count); + SkipInternal(skip_count); } inline void ForwardToNextRun() { // handled all entries in this RLE value // move to the next entry entry_pos++; + if (entry_pos > max_entry_pos) { + throw IOException( + "Corrupted RLE segment: index_pointer[entry_pos] would reach outside of the blocks memory"); + } position_in_entry = 0; } - inline bool ExhaustedRun(const rle_count_t *const index_pointer) const { + inline bool ExhaustedRun() { return position_in_entry >= index_pointer[entry_pos]; } BufferHandle handle; idx_t entry_pos; idx_t position_in_entry; - uint32_t rle_count_offset; + const uint32_t rle_count_offset; //! If we are running a filter over the column - the runs that match the filter unsafe_unique_array matching_runs; idx_t matching_run_count = 0; + + const T *data_pointer; + const rle_count_t *index_pointer; + const idx_t max_entry_pos; }; template @@ -326,14 +347,13 @@ static bool CanEmitConstantVector(idx_t position, idx_t run_length, idx_t scan_c } template -static void RLEScanConstant(RLEScanState &scan_state, const rle_count_t *const index_pointer, - const T *const data_pointer, idx_t scan_count, Vector &result) { +static void RLEScanConstant(RLEScanState &scan_state, idx_t scan_count, Vector &result) { result.SetVectorType(VectorType::CONSTANT_VECTOR); FlatVector::SetSize(result, count_t(scan_count)); auto result_data = ConstantVector::GetData(result); - result_data[0] = data_pointer[scan_state.entry_pos]; + result_data[0] = scan_state.data_pointer[scan_state.entry_pos]; scan_state.position_in_entry += scan_count; - if (scan_state.ExhaustedRun(index_pointer)) { + if (scan_state.ExhaustedRun()) { scan_state.ForwardToNextRun(); } } @@ -343,14 +363,10 @@ void RLEScanPartialInternal(ColumnSegment &segment, ColumnScanState &state, idx_ idx_t result_offset) { auto &scan_state = state.scan_state->Cast>(); - const auto data = scan_state.handle.GetDataMutable() + segment.GetBlockOffset(); - const auto data_pointer = reinterpret_cast(data + RLEConstants::RLE_HEADER_SIZE); - const auto index_pointer = reinterpret_cast(data + scan_state.rle_count_offset); - // If we are scanning an entire Vector and it contains only a single run - if (CanEmitConstantVector(scan_state.position_in_entry, index_pointer[scan_state.entry_pos], - scan_count)) { - RLEScanConstant(scan_state, index_pointer, data_pointer, scan_count, result); + if (CanEmitConstantVector(scan_state.position_in_entry, + scan_state.index_pointer[scan_state.entry_pos], scan_count)) { + RLEScanConstant(scan_state, scan_count, result); return; } @@ -358,20 +374,23 @@ void RLEScanPartialInternal(ColumnSegment &segment, ColumnScanState &state, idx_ const idx_t result_end = result_offset + scan_count; while (result_offset < result_end) { - const rle_count_t &run_end = index_pointer[scan_state.entry_pos]; - const idx_t run_count = run_end - scan_state.position_in_entry; - const idx_t remaining = result_end - result_offset; - const idx_t to_write = run_count < remaining ? run_count : remaining; - - const T &element = data_pointer[scan_state.entry_pos]; - std::fill_n(result_data + result_offset, to_write, element); - - result_offset += to_write; - scan_state.position_in_entry += to_write; - - if (to_write != run_count) { + rle_count_t run_end = scan_state.index_pointer[scan_state.entry_pos]; + idx_t run_count = run_end - scan_state.position_in_entry; + idx_t remaining_scan_count = result_end - result_offset; + T element = scan_state.data_pointer[scan_state.entry_pos]; + if (DUCKDB_UNLIKELY(run_count > remaining_scan_count)) { + for (idx_t i = 0; i < remaining_scan_count; i++) { + result_data[result_offset + i] = element; + } + scan_state.position_in_entry += remaining_scan_count; break; } + + for (idx_t i = 0; i < run_count; i++) { + result_data[result_offset + i] = element; + } + + result_offset += run_count; scan_state.ForwardToNextRun(); } } @@ -395,13 +414,10 @@ void RLESelect(ColumnSegment &segment, ColumnScanState &state, idx_t vector_coun const SelectionVector &sel, idx_t sel_count) { auto &scan_state = state.scan_state->Cast>(); - auto data = scan_state.handle.GetDataMutable() + segment.GetBlockOffset(); - auto data_pointer = reinterpret_cast(data + RLEConstants::RLE_HEADER_SIZE); - auto index_pointer = reinterpret_cast(data + scan_state.rle_count_offset); - // If we are scanning an entire Vector and it contains only a single run we don't need to select at all - if (CanEmitConstantVector(scan_state.position_in_entry, index_pointer[scan_state.entry_pos], vector_count)) { - RLEScanConstant(scan_state, index_pointer, data_pointer, vector_count, result); + if (CanEmitConstantVector(scan_state.position_in_entry, scan_state.index_pointer[scan_state.entry_pos], + vector_count)) { + RLEScanConstant(scan_state, vector_count, result); return; } @@ -414,14 +430,14 @@ void RLESelect(ColumnSegment &segment, ColumnScanState &state, idx_t vector_coun throw InternalException("Error in RLESelect - selection vector indices are not ordered"); } // skip forward to the next index - scan_state.SkipInternal(index_pointer, next_idx - prev_idx); + scan_state.SkipInternal(next_idx - prev_idx); // read the element - result_data.WriteValue(data_pointer[scan_state.entry_pos]); + result_data.WriteValue(scan_state.data_pointer[scan_state.entry_pos]); // move the next to the prev prev_idx = next_idx; } // skip the tail - scan_state.SkipInternal(index_pointer, vector_count - prev_idx); + scan_state.SkipInternal(vector_count - prev_idx); } //===--------------------------------------------------------------------===// @@ -432,9 +448,8 @@ void RLEFilter(ColumnSegment &segment, ColumnScanState &state, idx_t vector_coun idx_t &sel_count, const TableFilter &filter, TableFilterState &filter_state) { auto &scan_state = state.scan_state->Cast>(); - auto data = scan_state.handle.GetDataMutable() + segment.GetBlockOffset(); - auto data_pointer = reinterpret_cast(data + RLEConstants::RLE_HEADER_SIZE); - auto index_pointer = reinterpret_cast(data + scan_state.rle_count_offset); + auto data_pointer = const_cast(scan_state.data_pointer); + auto index_pointer = const_cast(scan_state.index_pointer); auto total_run_count = (scan_state.rle_count_offset - RLEConstants::RLE_HEADER_SIZE) / sizeof(T); if (!scan_state.matching_runs) { @@ -512,7 +527,7 @@ void RLEFilter(ColumnSegment &segment, ColumnScanState &state, idx_t vector_coun throw InternalException("Error in RLEFilter - selection vector indices are not ordered"); } // skip forward to the next index - scan_state.SkipInternal(index_pointer, read_idx - prev_idx); + scan_state.SkipInternal(read_idx - prev_idx); prev_idx = read_idx; if (!scan_state.matching_runs[scan_state.entry_pos]) { // this run is filtered out - we don't need to scan it @@ -523,7 +538,7 @@ void RLEFilter(ColumnSegment &segment, ColumnScanState &state, idx_t vector_coun matching_sel.set_index(matching_count++, read_idx); } // skip the tail - scan_state.SkipInternal(index_pointer, vector_count - prev_idx); + scan_state.SkipInternal(vector_count - prev_idx); } // set up the filter result diff --git a/src/duckdb/src/storage/external_file_cache/caching_file_system.cpp b/src/duckdb/src/storage/external_file_cache/caching_file_system.cpp index d5dc2f627..a621b52d9 100644 --- a/src/duckdb/src/storage/external_file_cache/caching_file_system.cpp +++ b/src/duckdb/src/storage/external_file_cache/caching_file_system.cpp @@ -257,8 +257,11 @@ FileBufferHandleGroup CachingFileHandle::Read(const idx_t nr_bytes, const idx_t return FileBufferHandleGroup(); } - // Uncached files are read directly into one contiguous buffer, skipping the per block syscalls and copies - if (!external_file_cache.IsEnabled() || !external_file_cache.ShouldCacheFile(path.path)) { + // Only cache when file metadata is available. + const bool no_validation_metadata = + Validate() && version_tag.empty() && (!last_modified.IsFinite() || last_modified == timestamp_t(0)); + + if (!external_file_cache.IsEnabled() || !external_file_cache.ShouldCacheFile(path.path) || no_validation_metadata) { auto buf = external_file_cache.GetBufferManager().Allocate(MemoryTag::EXTERNAL_FILE_CACHE, nr_bytes); ReadAndRecord(context, buf.GetDataMutable(), nr_bytes, location); vector mem_handles; diff --git a/src/duckdb/src/storage/metadata/metadata_reader.cpp b/src/duckdb/src/storage/metadata/metadata_reader.cpp index b47ca7805..50c5ca836 100644 --- a/src/duckdb/src/storage/metadata/metadata_reader.cpp +++ b/src/duckdb/src/storage/metadata/metadata_reader.cpp @@ -36,7 +36,7 @@ void MetadataReader::ReadData(QueryContext context, data_ptr_t buffer, idx_t rea memcpy(buffer, Ptr(), to_read); read_size -= to_read; buffer += to_read; - offset += read_size; + offset += to_read; } // then move to the next block ReadNextBlock(context); diff --git a/src/duckdb/src/storage/statistics/base_statistics.cpp b/src/duckdb/src/storage/statistics/base_statistics.cpp index 2cdd92fda..1c2f99b30 100644 --- a/src/duckdb/src/storage/statistics/base_statistics.cpp +++ b/src/duckdb/src/storage/statistics/base_statistics.cpp @@ -12,7 +12,8 @@ namespace duckdb { -BaseStatistics::BaseStatistics() : type(LogicalType::INVALID) { +BaseStatistics::BaseStatistics() : type(LogicalType::INVALID), has_null(false), has_no_null(false), distinct_count(0) { + memset(&stats_union, 0, sizeof(stats_union)); } BaseStatistics::BaseStatistics(LogicalType type) { @@ -20,7 +21,10 @@ BaseStatistics::BaseStatistics(LogicalType type) { } void BaseStatistics::Construct(BaseStatistics &stats, LogicalType type) { + stats.has_null = false; + stats.has_no_null = false; stats.distinct_count = 0; + memset(&stats.stats_union, 0, sizeof(stats.stats_union)); stats.type = std::move(type); switch (GetStatsType(stats.type)) { case StatisticsType::LIST_STATS: diff --git a/src/duckdb/src/storage/statistics/geometry_stats.cpp b/src/duckdb/src/storage/statistics/geometry_stats.cpp index f4143c2e8..783f92520 100644 --- a/src/duckdb/src/storage/statistics/geometry_stats.cpp +++ b/src/duckdb/src/storage/statistics/geometry_stats.cpp @@ -4,6 +4,7 @@ #include "duckdb/common/types/vector.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" @@ -240,8 +241,10 @@ static FilterPropagateResult CheckIntersectionFilter(const GeometryStatsData &da return FilterPropagateResult::NO_PRUNING_POSSIBLE; } - // This has been checked before and needs to be true for the checks below to be valid - D_ASSERT(data.extent.HasXY()); + // This has been checked before and needs to be true for the checks below to be valid. + // Note: only one axis needs to be set; an unknown axis is an infinite range that + // intersects everything, so the IntersectsXY/ContainsXY math below stays valid. + D_ASSERT(data.extent.CanPruneXY()); const auto &geom = StringValue::Get(constant); auto extent = GeometryExtent::Empty(); @@ -297,8 +300,23 @@ FilterPropagateResult GeometryStats::CheckZonemap(const BaseStatistics &stats, c return FilterPropagateResult::NO_PRUNING_POSSIBLE; } - const auto lhs_kind = func.GetChildren()[0]->GetExpressionType(); - const auto rhs_kind = func.GetChildren()[1]->GetExpressionType(); + // The column reference may be wrapped in a GEOMETRY -> GEOMETRY cast (e.g. a CRS-erasing cast inserted to match + // the predicate's argument type). Such casts only change CRS metadata, not coordinates, so the bounding box + // remains valid. Look through them when classifying the operands. + auto strip_geometry_cast = [](const Expression &child) -> const Expression * { + if (child.GetExpressionType() == ExpressionType::OPERATOR_CAST) { + auto &cast = child.Cast(); + if (cast.Child().GetReturnType().id() == LogicalTypeId::GEOMETRY) { + return &cast.Child(); + } + } + return &child; + }; + + const auto &lhs = *strip_geometry_cast(*func.GetChildren()[0]); + const auto &rhs = *strip_geometry_cast(*func.GetChildren()[1]); + const auto lhs_kind = lhs.GetExpressionType(); + const auto rhs_kind = rhs.GetExpressionType(); const auto lhs_is_const = lhs_kind == ExpressionType::VALUE_CONSTANT && rhs_kind == ExpressionType::BOUND_REF; const auto rhs_is_const = rhs_kind == ExpressionType::VALUE_CONSTANT && lhs_kind == ExpressionType::BOUND_REF; @@ -309,16 +327,18 @@ FilterPropagateResult GeometryStats::CheckZonemap(const BaseStatistics &stats, c auto &data = GetDataUnsafe(stats); - if (!data.extent.HasXY()) { - // If the extent is empty or unknown, we cannot prune + if (!data.extent.CanPruneXY()) { + // If neither axis is set (the extent is empty or fully unknown), we cannot prune. + // A single known axis is enough: the unknown axis is an infinite range that + // intersects everything, so pruning degrades to the known axis. return FilterPropagateResult::NO_PRUNING_POSSIBLE; } if (lhs_is_const) { - return CheckIntersectionFilter(data, func.GetChildren()[0]->Cast().GetValue()); + return CheckIntersectionFilter(data, lhs.Cast().GetValue()); } if (rhs_is_const) { - return CheckIntersectionFilter(data, func.GetChildren()[1]->Cast().GetValue()); + return CheckIntersectionFilter(data, rhs.Cast().GetValue()); } // Else, no constant argument return FilterPropagateResult::NO_PRUNING_POSSIBLE; diff --git a/src/duckdb/src/storage/statistics/variant_stats.cpp b/src/duckdb/src/storage/statistics/variant_stats.cpp index 0716532a5..51934e22d 100644 --- a/src/duckdb/src/storage/statistics/variant_stats.cpp +++ b/src/duckdb/src/storage/statistics/variant_stats.cpp @@ -786,6 +786,7 @@ unique_ptr VariantStats::PushdownExtract(const BaseStatistics &s if (!index_iter.get().HasChildren()) { break; } + index_iter = index_iter.get().GetChildIndex(0); } auto &shredded_child_stats = *res; diff --git a/src/duckdb/src/storage/table/geo_column_data.cpp b/src/duckdb/src/storage/table/geo_column_data.cpp index 2c67a9e7f..6c5b91d7b 100644 --- a/src/duckdb/src/storage/table/geo_column_data.cpp +++ b/src/duckdb/src/storage/table/geo_column_data.cpp @@ -346,8 +346,15 @@ unique_ptr GeoColumnData::Checkpoint(const RowGroup &row_ auto &partial_block_manager = info.GetPartialBlockManager(); auto checkpoint_state = make_uniq(row_group, *this, partial_block_manager); - auto &old_column_stats = - base_column->GetType().id() == LogicalTypeId::GEOMETRY ? old_stats : base_column->GetStatisticsRef(); + // When the inner column is unshredded, the geometry old_stats are already correct. + // When the inner column is shredded, the base_column has no stats of its own (it is parented to us). + // Shredded columns are always re-written from scratch, and the stats are recomputes, do the empty stats of the + // inner layout type is a correct default in these cases. + unique_ptr shredded_stats; + if (base_column->GetType().id() != LogicalTypeId::GEOMETRY) { + shredded_stats = BaseStatistics::CreateEmpty(base_column->GetType()).ToUnique(); + } + auto &old_column_stats = shredded_stats ? *shredded_stats : old_stats; // Are there any changes? if (!HasAnyChanges()) { @@ -356,17 +363,22 @@ unique_ptr GeoColumnData::Checkpoint(const RowGroup &row_ checkpoint_state->inner_column_state = checkpoint_state->inner_column->Checkpoint(row_group, info, old_column_stats); - if (base_column->GetType().id() == LogicalTypeId::GEOMETRY) { - // Get the stats from the base column. + // Only the specialized (shredded) layouts need to be reinterpreted via GetSpecializedType. + // Both WKB and the legacy SPATIAL format store the full, unshredded geometry, so their stats come + // directly from the column rather than from a specialized layout. + + const auto storage_type = checkpoint_state->storage_type; + if (storage_type == GeometryStorageType::WKB) { + // WKB: the base column carries the geometry stats directly. checkpoint_state->global_stats = checkpoint_state->inner_column_state->GetStatistics(); - } else if (checkpoint_state->storage_type == GeometryStorageType::SPATIAL) { + } else if (storage_type == GeometryStorageType::SPATIAL) { // Legacy spatial storage - we cannot interpret the stats of the old format auto new_stats = checkpoint_state->inner_column_state->GetStatistics(); checkpoint_state->global_stats = GeometryStats::CreateUnknown(type).ToUnique(); checkpoint_state->global_stats->CopyBase(*new_stats); } else { - // Otherwise interpret stats from shredded column - const auto types = Geometry::GetSpecializedType(checkpoint_state->storage_type); + // Shredded storage, interpret stats from shredded column + const auto types = Geometry::GetSpecializedType(storage_type); const auto gtype = types.first; const auto vtype = types.second; diff --git a/src/duckdb/src/storage/table/row_group.cpp b/src/duckdb/src/storage/table/row_group.cpp index c7b6b9dc3..b658a50e4 100644 --- a/src/duckdb/src/storage/table/row_group.cpp +++ b/src/duckdb/src/storage/table/row_group.cpp @@ -1592,8 +1592,7 @@ RowGroupWriteData RowGroup::WriteToDisk(RowGroupWriter &writer) { } } - if (!reused_columns.empty()) { - D_ASSERT(partial_reuse); + if (partial_reuse) { // carry forward the extras for reused columns onto the new row group, so RowGroup::Checkpoint // can look them up via this->per_column_metadata_blocks auto extras = per_column_metadata_blocks.GetBlocksForColumns(reused_columns); diff --git a/src/duckdb/src/storage/table/row_group_reorderer.cpp b/src/duckdb/src/storage/table/row_group_reorderer.cpp index 6ab37e900..9a6ed0f8b 100644 --- a/src/duckdb/src/storage/table/row_group_reorderer.cpp +++ b/src/duckdb/src/storage/table/row_group_reorderer.cpp @@ -267,6 +267,9 @@ OffsetPruningResult RowGroupReorderer::GetOffsetAfterPruning(const OrderByStatis } auto column_stats = partition_stats.partition_row_group->GetColumnStatistics(storage_index); + if (!column_stats) { + return {new_row_offset, 0, leading_null_group_offset}; + } if (null_order == OrderByNullType::NULLS_FIRST && IsNullOnly(*column_stats)) { if (new_row_offset < partition_stats.count) { return {new_row_offset, 0, leading_null_group_offset}; @@ -335,6 +338,10 @@ optional_ptr> RowGroupReorderer::GetRootSegment(RowGroupSe multimap row_group_map; for (auto &row_group : row_groups.SegmentNodes()) { auto stats = row_group.GetNode().GetStatistics(options.column_idx); + if (!stats) { + ambiguous_groups.push_back(row_group); + continue; + } if (IsNullOnly(*stats)) { null_only_groups.push_back(row_group); continue; diff --git a/src/duckdb/src/storage/table/variant/variant_shredding.cpp b/src/duckdb/src/storage/table/variant/variant_shredding.cpp index 25545aadc..c03c3882c 100644 --- a/src/duckdb/src/storage/table/variant/variant_shredding.cpp +++ b/src/duckdb/src/storage/table/variant/variant_shredding.cpp @@ -441,6 +441,10 @@ bool VariantShreddingStats::GetShreddedTypeInternal(const VariantColumnStatsData child_list_t child_types; for (auto &entry : column.field_stats) { auto &child_column = GetColumnStats(entry.second); + if (entry.first.empty()) { + //! Do not include empty field names in the shredded type! + continue; + } LogicalType child_type; if (GetShreddedTypeInternal(child_column, child_type, total_value_count, force_partial)) { child_types.emplace_back(entry.first, child_type); @@ -796,7 +800,18 @@ void VariantColumnData::ShredVariantData(const Vector &input, Vector &output, id for (idx_t i = 0; i < count; i++) { auto input_val = input.GetValue(i); auto roundtripped_val = roundtrip_result.GetValue(i); - if (!ValueOperations::NotDistinctFrom(input_val, roundtripped_val)) { + + Vector input_vec(input_val, count_t(1)); + Vector roundtripped_vec(roundtripped_val, count_t(1)); + + Vector normalized_input(LogicalType::VARIANT(), 1); + Vector normalized_roundtrip(LogicalType::VARIANT(), 1); + VariantNormalizer::Normalize(input_vec, normalized_input); + VariantNormalizer::Normalize(roundtripped_vec, normalized_roundtrip); + + auto normalized_input_value = normalized_input.GetValue(0); + auto normalized_roundtrip_value = normalized_roundtrip.GetValue(0); + if (!ValueOperations::NotDistinctFrom(normalized_input_value, normalized_roundtrip_value)) { throw InternalException("Shredding roundtrip verification failed for row: %d, expected: %s, actual: %s", i, input_val.ToString(), roundtripped_val.ToString()); } diff --git a/src/duckdb/src/storage/table/variant/variant_unshredding.cpp b/src/duckdb/src/storage/table/variant/variant_unshredding.cpp index cfbccf964..307fb9199 100644 --- a/src/duckdb/src/storage/table/variant/variant_unshredding.cpp +++ b/src/duckdb/src/storage/table/variant/variant_unshredding.cpp @@ -1,246 +1,10 @@ -#include "duckdb/common/vector/list_vector.hpp" -#include "duckdb/common/vector/map_vector.hpp" -#include "duckdb/common/vector/variant_vector.hpp" #include "duckdb/common/vector/struct_vector.hpp" -#include "duckdb/storage/table/variant_column_data.hpp" #include "duckdb/common/types/variant.hpp" -#include "duckdb/function/cast/variant/to_variant_fwd.hpp" -#include "duckdb/common/types/variant_value.hpp" -#include "duckdb/common/types/variant_visitor.hpp" -#include "duckdb/function/variant/variant_value_convert.hpp" +#include "duckdb/common/types/variant_iterator.hpp" +#include "duckdb/function/scalar/variant_utils.hpp" namespace duckdb { -template -static VariantValue UnshreddedVariantValue(UnifiedVariantVectorData &input, uint32_t row, uint32_t values_index) { - if (!input.RowIsValid(row)) { - return VariantValue(Value(LogicalTypeId::SQLNULL)); - } - - if (values_index == 0) { - //! 0 is reserved to indicate a missing value - return VariantValue(VariantValueType::MISSING); - } - values_index--; - - auto type_id = input.GetTypeId(row, values_index); - if (!ALLOW_NULL) { - //! We don't expect NULLs at the root, those should have a NULL 'untyped_value_index' - D_ASSERT(type_id != VariantLogicalType::VARIANT_NULL); - } - - if (type_id == VariantLogicalType::OBJECT) { - VariantValue res(VariantValueType::OBJECT); - - auto object_data = VariantUtils::DecodeNestedData(input, row, values_index); - for (idx_t i = 0; i < object_data.child_count; i++) { - auto child_values_index = input.GetValuesIndex(row, object_data.children_idx + i); - auto val = UnshreddedVariantValue(input, row, child_values_index + 1); - - auto keys_index = input.GetKeysIndex(row, object_data.children_idx + i); - auto &key = input.GetKey(row, keys_index); - - res.AddChild(key.GetString(), std::move(val)); - } - return res; - } - if (type_id == VariantLogicalType::ARRAY) { - VariantValue res(VariantValueType::ARRAY); - - auto array_data = VariantUtils::DecodeNestedData(input, row, values_index); - for (idx_t i = 0; i < array_data.child_count; i++) { - auto child_values_index = input.GetValuesIndex(row, array_data.children_idx + i); - auto val = UnshreddedVariantValue(input, row, child_values_index + 1); - - res.AddItem(std::move(val)); - } - return res; - } - auto val = VariantVisitor::Visit(input, row, values_index); - return VariantValue(std::move(val)); -} - -static vector Unshred(UnifiedVariantVectorData &variant, Vector &shredded, idx_t count, - optional_ptr row_sel); - -static vector UnshredTypedLeaf(Vector &typed_value, idx_t count) { - vector res(count); - - for (idx_t i = 0; i < count; i++) { - auto val = typed_value.GetValue(i); - if (val.IsNull()) { - res[i] = VariantValue(Value(LogicalTypeId::SQLNULL)); - } else { - res[i] = VariantValue(std::move(val)); - } - } - return res; -} - -static vector UnshredTypedObject(UnifiedVariantVectorData &variant, Vector &typed_value, idx_t count, - optional_ptr row_sel) { - vector res(count); - - auto &child_types = StructType::GetChildTypes(typed_value.GetType()); - auto &child_entries = StructVector::GetEntries(typed_value); - - D_ASSERT(child_types.size() == child_entries.size()); - - //! First unshred all children - vector> child_values(child_entries.size()); - for (idx_t child_idx = 0; child_idx < child_entries.size(); child_idx++) { - auto &child_entry = child_entries[child_idx]; - child_values[child_idx] = Unshred(variant, child_entry, count, row_sel); - } - - //! Then compose the OBJECT value by combining all the children - auto validity = typed_value.Validity(); - for (idx_t child_idx = 0; child_idx < child_entries.size(); child_idx++) { - auto &child_name = child_types[child_idx].first; - auto &values = child_values[child_idx]; - - for (idx_t i = 0; i < count; i++) { - if (values[i].IsMissing()) { - // struct field is missing - continue; - } - if (!validity.IsValid(i)) { - res[i] = VariantValue(Value(LogicalTypeId::SQLNULL)); - } else if (res[i].IsMissing()) { - res[i] = VariantValue(VariantValueType::OBJECT); - } - if (res[i].IsNull()) { - // struct itself is NULL - continue; - } - auto &obj_value = res[i]; - obj_value.AddChild(child_name.GetIdentifierName(), std::move(values[i])); - } - } - return res; -} - -static vector UnshredTypedArray(UnifiedVariantVectorData &variant, Vector &typed_value, idx_t count, - optional_ptr row_sel) { - auto &child_vector = ListVector::GetChildMutable(typed_value); - - D_ASSERT(typed_value.GetType().id() == LogicalTypeId::LIST); - - auto list_data = typed_value.Values(); - idx_t child_size = 0; - for (uint32_t i = 0; i < count; i++) { - auto entry = list_data[i]; - if (!entry.IsValid()) { - continue; - } - auto &list_entry = entry.GetValue(); - child_size += list_entry.length; - } - idx_t current_offset = 0; - SelectionVector child_sel(child_size); - vector res(count); - for (uint32_t i = 0; i < count; i++) { - auto entry = list_data[i]; - if (!entry.IsValid()) { - res[i] = VariantValue(Value(LogicalType::SQLNULL)); - continue; - } - auto row = row_sel ? static_cast(row_sel->get_index(i)) : i; - auto &list_entry = entry.GetValue(); - for (idx_t j = 0; j < list_entry.length; j++) { - child_sel[current_offset++] = row; - } - } - auto child_values = Unshred(variant, child_vector, child_size, child_sel); - - current_offset = 0; - for (idx_t i = 0; i < count; i++) { - auto entry = list_data[i]; - if (!entry.IsValid()) { - continue; - } - auto &list_entry = entry.GetValue(); - - auto &list_val = res[i]; - list_val = VariantValue(VariantValueType::ARRAY); - list_val.ReserveItems(list_entry.length); - list_val.AddItems(child_values.begin() + static_cast(current_offset), - child_values.begin() + static_cast(current_offset + list_entry.length)); - current_offset += list_entry.length; - } - return res; -} - -static vector UnshredTypedValue(UnifiedVariantVectorData &variant, Vector &typed_value, idx_t count, - optional_ptr row_sel) { - auto &type = typed_value.GetType(); - if (type.id() == LogicalTypeId::STRUCT) { - return UnshredTypedObject(variant, typed_value, count, row_sel); - } else if (type.id() == LogicalTypeId::LIST) { - return UnshredTypedArray(variant, typed_value, count, row_sel); - } else { - D_ASSERT(!type.IsNested()); - return UnshredTypedLeaf(typed_value, count); - } -} - -static vector Unshred(UnifiedVariantVectorData &variant, Vector &shredded, idx_t count, - optional_ptr row_sel) { - reference typed_value_ref(shredded); - optional_ptr untyped_value_index; - if (shredded.GetType().id() == LogicalTypeId::STRUCT) { - // "typed_value", "untyped_value" - auto &child_vectors = StructVector::GetEntries(shredded); - D_ASSERT(shredded.GetType().id() == LogicalTypeId::STRUCT); - D_ASSERT(child_vectors.size() <= 2); - typed_value_ref = child_vectors[VariantColumnData::TYPED_VALUE_INDEX]; - if (child_vectors.size() > 1) { - D_ASSERT(child_vectors.size() == 2); - untyped_value_index = child_vectors[VariantColumnData::UNTYPED_VALUE_INDEX]; - } - } - auto &typed_value = typed_value_ref.get(); - - // unshred the typed variant - auto res = UnshredTypedValue(variant, typed_value, count, row_sel); - - if (!untyped_value_index) { - return res; - } - // if we have any untyped values - unshred them - auto untyped_data = untyped_value_index->Values(); - for (uint32_t i = 0; i < count; i++) { - auto entry = untyped_data[i]; - if (!entry.IsValid()) { - //! NULL untyped_value_index indicates a fully shredded variant - continue; - } - auto value_index = entry.GetValue(); - if (value_index == 0) { - // untyped value index of 0 indicates missing - res[i] = VariantValue(VariantValueType::MISSING); - continue; - } - auto row = row_sel ? static_cast(row_sel->get_index(i)) : i; - auto unshredded = UnshreddedVariantValue(variant, row, value_index); - - if (res[i].IsNull() || res[i].IsMissing()) { - //! No shredded value was produced for this row - either the value was not shredded at all, or it is a - //! shredded object none of whose fields are present in the typed schema. Take the overlay value as-is. - res[i] = std::move(unshredded); - } else { - //! Partial shredding, already has a shredded value that this has to be combined into - D_ASSERT(res[i].value_type == VariantValueType::OBJECT); - D_ASSERT(unshredded.value_type == VariantValueType::OBJECT); - auto object_children = unshredded.TakeObjectChildren(); - for (auto &entry : object_children) { - res[i].AddChild(entry.first, std::move(entry.second)); - } - } - } - return res; -} - void VariantUtils::UnshredVariantData(Vector &input, Vector &output, idx_t count) { D_ASSERT(input.GetType().id() == LogicalTypeId::STRUCT); auto &child_vectors = StructVector::GetEntries(input); @@ -249,12 +13,10 @@ void VariantUtils::UnshredVariantData(Vector &input, Vector &output, idx_t count auto &unshredded = child_vectors[0]; auto &shredded = child_vectors[1]; - RecursiveUnifiedVectorFormat recursive_format; - Vector::RecursiveToUnifiedFormat(unshredded, recursive_format); - UnifiedVariantVectorData variant(recursive_format); - - auto variant_values = Unshred(variant, shredded, count, nullptr); - VariantValue::ToVARIANT(variant_values, output); + //! Traverse the (shredded) variant directly through the iterator and encode it into the canonical + //! unshredded layout - no intermediate vector materialization is required. + VariantIterator state(unshredded, shredded); + ToVariant(state, count, output); } } // namespace duckdb diff --git a/src/duckdb/src/storage/temporary_memory_manager.cpp b/src/duckdb/src/storage/temporary_memory_manager.cpp index 0b7a3841b..a0221c453 100644 --- a/src/duckdb/src/storage/temporary_memory_manager.cpp +++ b/src/duckdb/src/storage/temporary_memory_manager.cpp @@ -194,26 +194,28 @@ idx_t TemporaryMemoryManager::ComputeInitialReservation(const TemporaryMemorySta static void ComputeDerivatives(const vector> &states, const vector &res, vector &der, const idx_t n) { // Cost function takes "throughput" (reservation / size) of each operator as its principal input - double prod_siz = 1; - double prod_res = 1; + double log_throughput_prod = 0; double mat_cost = 0; for (idx_t i = 0; i < n; i++) { auto &state = states[i].get(); const auto resd = static_cast(res[i]); const auto sizd = static_cast(MaxValue(state.GetRemainingSize(), 1)); const auto pend = static_cast(state.GetMaterializationPenalty()); - prod_res *= resd; - prod_siz *= sizd; - mat_cost += pend * (1 - resd / sizd); // Materialization cost: sum of (1 - throughput) + D_ASSERT(resd > 0); + D_ASSERT(resd <= sizd); + const auto throughput = resd / sizd; + log_throughput_prod += std::log(throughput); + mat_cost += pend * (1 - throughput); // Materialization cost: sum of (1 - throughput) } - const double nd = static_cast(n); // n as double for convenience - const double tp_mult = 1 - pow(prod_res / prod_siz, 1 / nd); // Throughput multiplier: 1 - geomean throughputs + const double nd = static_cast(n); // n as double for convenience + const double throughput_geomean = std::exp(log_throughput_prod / nd); + const double tp_mult = 1 - throughput_geomean; // Throughput multiplier: 1 - geomean throughputs // Cost function: materialization cost * (1 - throughput multiplier), but we don't actually need to compute it // here. We need to compute the derivative with respect to every reservation, stored in "der" // Just use https://www.derivative-calculator.net with this (n = 3) to see what's going on // (3 - (a_1/s_1)-(a_2/s_2)-(a_3/s_3))*(1-((a_1/s_1)*(a_2/s_2)*(a_3/s_3))^(1/3)) - const double intermediate = -(pow(prod_res, 1 / nd) * mat_cost) / (nd * pow(prod_siz, 1 / nd)); + const double intermediate = -(throughput_geomean * mat_cost) / nd; for (idx_t i = 0; i < n; i++) { auto &state = states[i].get(); const auto resd = static_cast(res[i]); @@ -257,23 +259,27 @@ idx_t TemporaryMemoryManager::ComputeReservation(const TemporaryMemoryState &tem // Distribute memory in OPTIMIZATION_ITERATIONS idx_t remaining_memory = free_memory; const idx_t optimization_iterations = OPTIMIZATION_ITERATIONS_MULTIPLIER * n; - for (idx_t opt_idx = 0; opt_idx < optimization_iterations; opt_idx++) { - D_ASSERT(remaining_memory != 0); + for (idx_t opt_idx = 0; opt_idx < optimization_iterations && remaining_memory > 0; opt_idx++) { ComputeDerivatives(states, res, der, n); // Find the index of the state with the lowest derivative idx_t min_idx = 0; double min_der = NumericLimits::Maximum(); + bool found_candidate = false; for (i = 0; i < n; i++) { auto &state = states[i].get(); if (res[i] >= state.GetRemainingSize()) { continue; // We can't increase the reservation of "maxed" states, so we skip these } - if (der[i] < min_der) { + if (!found_candidate || der[i] < min_der) { min_idx = i; min_der = der[i]; + found_candidate = true; } } + if (!found_candidate) { + break; + } auto &min_state = states[min_idx].get(); // This is how much memory we will distribute in this round @@ -295,7 +301,6 @@ idx_t TemporaryMemoryManager::ComputeReservation(const TemporaryMemoryState &tem opt_idx--; } } - D_ASSERT(remaining_memory == 0); // We computed how the memory should be assigned to the states, // but we did not yet take into account the upper bound of MAXIMUM_FREE_MEMORY_RATIO * free_memory. diff --git a/src/duckdb/src/transaction/duck_transaction_manager.cpp b/src/duckdb/src/transaction/duck_transaction_manager.cpp index 6edea85e4..4c8906ad9 100644 --- a/src/duckdb/src/transaction/duck_transaction_manager.cpp +++ b/src/duckdb/src/transaction/duck_transaction_manager.cpp @@ -21,7 +21,7 @@ namespace duckdb { -void DuckCleanupInfo::Cleanup() noexcept { +void DuckCleanupInfo::Cleanup() { for (auto &transaction : transactions) { if (transaction->awaiting_cleanup) { transaction->Cleanup(lowest_start_time); diff --git a/src/duckdb/third_party/thrift/thrift/protocol/TProtocol.h b/src/duckdb/third_party/thrift/thrift/protocol/TProtocol.h index 517aceb75..f1bccd43c 100644 --- a/src/duckdb/third_party/thrift/thrift/protocol/TProtocol.h +++ b/src/duckdb/third_party/thrift/thrift/protocol/TProtocol.h @@ -89,6 +89,18 @@ static inline To bitwise_cast(From from) { # define __THRIFT_BYTE_ORDER BYTE_ORDER # define __THRIFT_LITTLE_ENDIAN LITTLE_ENDIAN # define __THRIFT_BIG_ENDIAN BIG_ENDIAN +# elif defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN__) && defined(__ORDER_BIG_ENDIAN__) + // GCC / Clang builtin (macOS, Linux, MinGW, ...). Reliable without relying on system headers happening to have + // defined BYTE_ORDER already. +# define __THRIFT_BYTE_ORDER __BYTE_ORDER__ +# define __THRIFT_LITTLE_ENDIAN __ORDER_LITTLE_ENDIAN__ +# define __THRIFT_BIG_ENDIAN __ORDER_BIG_ENDIAN__ +# elif defined(_WIN32) + // All Windows targets (x86, x64, ARM, ARM64) are little-endian. MSVC does not define BYTE_ORDER, so without this + // we would fall through to the broken default below and byteswap every double on the wire. +# define __THRIFT_BYTE_ORDER 1234 +# define __THRIFT_LITTLE_ENDIAN __THRIFT_BYTE_ORDER +# define __THRIFT_BIG_ENDIAN 0 # else //# include # if BOOST_ENDIAN_BIG_BYTE @@ -106,6 +118,13 @@ static inline To bitwise_cast(From from) { # endif #endif +// Guard against silently falling into the big-endian byteswap path. +// if detection failed above, __THRIFT_BYTE_ORDER and __THRIFT_BIG_ENDIAN both expand to 0 and the comparison below +// would be (0 == 0) -> true, byte-swapping every double. +#if !defined(__THRIFT_BYTE_ORDER) || !defined(__THRIFT_LITTLE_ENDIAN) || !defined(__THRIFT_BIG_ENDIAN) +# error "Could not detect endianness for Thrift; define __THRIFT_BYTE_ORDER explicitly." +#endif + #if __THRIFT_BYTE_ORDER == __THRIFT_BIG_ENDIAN # if !defined(THRIFT_ntohll) # define THRIFT_ntohll(n) (n) diff --git a/src/duckdb/ub_extension_parquet_reader_variant.cpp b/src/duckdb/ub_extension_parquet_reader_variant.cpp index 4e55f0b16..8a5fc94ad 100644 --- a/src/duckdb/ub_extension_parquet_reader_variant.cpp +++ b/src/duckdb/ub_extension_parquet_reader_variant.cpp @@ -1,4 +1,4 @@ -#include "extension/parquet/reader/variant/variant_binary_decoder.cpp" +#include "extension/parquet/reader/variant/parquet_variant_iterator.cpp" -#include "extension/parquet/reader/variant/variant_shredded_conversion.cpp" +#include "extension/parquet/reader/variant/variant_binary_decoder.cpp" diff --git a/src/duckdb/ub_src_common_serializer.cpp b/src/duckdb/ub_src_common_serializer.cpp index 99b04f84b..c72a84f41 100644 --- a/src/duckdb/ub_src_common_serializer.cpp +++ b/src/duckdb/ub_src_common_serializer.cpp @@ -1,5 +1,9 @@ #include "src/common/serializer/async_file_writer.cpp" +#include "src/common/serializer/async_memory_governor.cpp" + +#include "src/common/serializer/async_task_queue.cpp" + #include "src/common/serializer/async_write_queue.cpp" #include "src/common/serializer/binary_deserializer.cpp" diff --git a/src/duckdb/ub_src_common_types_variant.cpp b/src/duckdb/ub_src_common_types_variant.cpp index 3ea81f77f..cb843e56b 100644 --- a/src/duckdb/ub_src_common_types_variant.cpp +++ b/src/duckdb/ub_src_common_types_variant.cpp @@ -2,7 +2,5 @@ #include "src/common/types/variant/variant_iterator.cpp" -#include "src/common/types/variant/variant_value.cpp" - #include "src/common/types/variant/variant_value_convert.cpp" diff --git a/src/duckdb/ub_src_function_scalar_compressed_materialization.cpp b/src/duckdb/ub_src_function_scalar_compressed_materialization.cpp index a160a1843..864d4b392 100644 --- a/src/duckdb/ub_src_function_scalar_compressed_materialization.cpp +++ b/src/duckdb/ub_src_function_scalar_compressed_materialization.cpp @@ -1,3 +1,5 @@ +#include "src/function/scalar/compressed_materialization/compress_geometry.cpp" + #include "src/function/scalar/compressed_materialization/compress_integral.cpp" #include "src/function/scalar/compressed_materialization/compress_string.cpp" From 6d37018738df6615f088294c7f199894170e3e1e Mon Sep 17 00:00:00 2001 From: Alex Kasko Date: Thu, 25 Jun 2026 10:30:05 +0100 Subject: [PATCH 3/3] Tests fixes after types update --- src/test/java/org/duckdb/TestDuckDBJDBC.java | 2 +- src/test/java/org/duckdb/TestMetadata.java | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/test/java/org/duckdb/TestDuckDBJDBC.java b/src/test/java/org/duckdb/TestDuckDBJDBC.java index b8d6277c6..599e743a5 100644 --- a/src/test/java/org/duckdb/TestDuckDBJDBC.java +++ b/src/test/java/org/duckdb/TestDuckDBJDBC.java @@ -1636,7 +1636,7 @@ public static void test_all_types() throws Exception { try { String sql = // TODO: remove excludes - "select * EXCLUDE(time, time_ns, time_tz, timestamp_tz_ns)" + "select * EXCLUDE(time, time_ns, time_tz, timestamp_tz_ns, empty_struct)" + "\n , CASE WHEN time = '24:00:00'::TIME THEN '23:59:59.999999'::TIME ELSE time END AS time" + "\n , CASE WHEN time_ns = '24:00:00'::TIME_NS THEN '23:59:59.999999'::TIME_NS ELSE time_ns END AS time_ns" diff --git a/src/test/java/org/duckdb/TestMetadata.java b/src/test/java/org/duckdb/TestMetadata.java index 7783a6ca1..60202933a 100644 --- a/src/test/java/org/duckdb/TestMetadata.java +++ b/src/test/java/org/duckdb/TestMetadata.java @@ -347,6 +347,7 @@ public static void test_column_metadata() throws Exception { expectedTypes.put("dec_18_6", JDBCType.DECIMAL); expectedTypes.put("dec38_10", JDBCType.DECIMAL); expectedTypes.put("geometry", JDBCType.BLOB); + expectedTypes.put("empty_struct", JDBCType.STRUCT); try (Connection conn = DriverManager.getConnection(JDBC_URL)) { try (Statement stmt = conn.createStatement()) {