From 50636d0a93f97b4fb4af7269109b94517d77fbdf Mon Sep 17 00:00:00 2001 From: DuckDB Labs GitHub Bot Date: Tue, 16 Jun 2026 14:34:39 +0000 Subject: [PATCH] Update vendored DuckDB sources to 68d73b7c3a --- .../core_functions/aggregate/holistic/mad.cpp | 1 + .../aggregate/holistic/quantile.cpp | 6 + .../aggregate/quantile_state.hpp | 80 + .../core_functions/lambda_functions.cpp | 13 +- .../parquet/include/column_writer.hpp | 1 + .../parquet/include/parquet_writer.hpp | 9 +- .../include/writer/list_column_writer.hpp | 1 + .../writer/primitive_column_writer.hpp | 6 + .../include/writer/struct_column_writer.hpp | 1 + .../writer/templated_column_writer.hpp | 3 +- .../extension/parquet/parquet_writer.cpp | 132 +- .../parquet/writer/list_column_writer.cpp | 5 + .../writer/primitive_column_writer.cpp | 117 +- .../parquet/writer/struct_column_writer.cpp | 9 +- .../catalog_entry/table_catalog_entry.cpp | 3 +- src/duckdb/src/common/enum_util.cpp | 36 +- .../src/common/enums/optimizer_type.cpp | 1 + src/duckdb/src/common/file_system.cpp | 8 + src/duckdb/src/common/local_file_system.cpp | 4 + .../common/serializer/async_file_writer.cpp | 366 ++++ .../common/serializer/async_write_queue.cpp | 1485 +++++++++++++++++ .../src/common/vector/shredded_vector.cpp | 4 + src/duckdb/src/execution/join_hashtable.cpp | 18 + .../helper/physical_streaming_sample.cpp | 80 +- .../operator/join/physical_hash_join.cpp | 231 ++- .../operator/scan/physical_table_scan.cpp | 3 +- .../execution/physical_plan/plan_sample.cpp | 55 +- .../aggregate/distributive/minmax.cpp | 13 +- src/duckdb/src/function/cast/struct_cast.cpp | 8 - src/duckdb/src/function/function_list.cpp | 1 - .../scalar/system/aggregate_export.cpp | 74 +- src/duckdb/src/function/table/table_scan.cpp | 3 +- .../function/table/version/pragma_version.cpp | 6 +- .../src/include/duckdb/common/enum_util.hpp | 8 + .../enums/dialect_compatibility_mode.hpp | 17 + .../duckdb/common/enums/optimizer_type.hpp | 1 + 
.../src/include/duckdb/common/file_system.hpp | 3 + .../duckdb/common/local_file_system.hpp | 1 + .../duckdb/common/primitive_dictionary.hpp | 16 +- .../common/serializer/async_file_writer.hpp | 141 ++ .../common/serializer/async_write_queue.hpp | 468 ++++++ .../duckdb/common/vector/shredded_vector.hpp | 1 + .../duckdb/execution/join_hashtable.hpp | 1 + .../helper/physical_streaming_sample.hpp | 5 +- .../operator/join/join_filter_pushdown.hpp | 29 +- .../function/aggregate/list_aggregate.hpp | 4 +- .../duckdb/function/aggregate_function.hpp | 8 +- .../function/aggregate_state_layout.hpp | 41 +- .../function/scalar/tablefilter_functions.hpp | 10 - .../src/include/duckdb/main/settings.hpp | 23 + .../optimizer/grouping_sets_optimizer.hpp | 35 + .../duckdb/optimizer/sampling_pushdown.hpp | 8 +- .../duckdb/parallel/task_scheduler.hpp | 2 + .../parser/parsed_data/sample_options.hpp | 1 + .../planner/filter/expression_filter.hpp | 2 + .../filter/perfect_hash_join_filter.hpp | 40 - .../planner/filter/table_filter_functions.hpp | 30 +- .../subquery/delim_join_cte_rewriter.hpp | 4 +- .../include/duckdb/planner/table_filter.hpp | 25 +- .../duckdb/storage/buffer/block_handle.hpp | 22 +- .../duckdb/storage/buffer/buffer_pool.hpp | 6 +- .../duckdb/storage/table/scan_state.hpp | 10 +- .../storage/temporary_memory_manager.hpp | 6 + src/duckdb/src/main/config.cpp | 10 +- .../main/settings/autogenerated_settings.cpp | 7 + .../optimizer/build_probe_side_optimizer.cpp | 24 + .../src/optimizer/grouping_sets_optimizer.cpp | 326 ++++ .../src/optimizer/late_materialization.cpp | 5 + src/duckdb/src/optimizer/optimizer.cpp | 9 +- .../optimizer/partial_aggregate_pushdown.cpp | 5 + .../optimizer/pushdown/pushdown_mark_join.cpp | 141 ++ .../src/optimizer/sampling_pushdown.cpp | 52 +- src/duckdb/src/optimizer/unnest_rewriter.cpp | 64 +- src/duckdb/src/parallel/task_executor.cpp | 7 +- src/duckdb/src/parallel/task_scheduler.cpp | 4 + .../src/parser/parsed_data/sample_options.cpp | 3 +- 
.../binder/query_node/bind_select_node.cpp | 39 + .../src/planner/filter/expression_filter.cpp | 26 +- .../filter/perfect_hash_join_filter.cpp | 36 - .../filter/table_filter_bloom_function.cpp | 8 +- .../planner/filter/table_filter_functions.cpp | 15 +- ...able_filter_perfect_hash_join_function.cpp | 196 --- .../table_filter_prefix_range_function.cpp | 19 +- .../src/planner/operator/logical_get.cpp | 6 +- .../subquery/delim_join_cte_rewriter.cpp | 516 +++++- src/duckdb/src/planner/table_filter_set.cpp | 3 +- .../src/storage/buffer/block_handle.cpp | 18 +- .../src/storage/buffer/block_manager.cpp | 7 +- src/duckdb/src/storage/buffer/buffer_pool.cpp | 57 +- .../storage/compression/numeric_constant.cpp | 11 - .../storage/serialization/serialize_nodes.cpp | 2 + .../src/storage/standard_buffer_manager.cpp | 5 +- src/duckdb/src/storage/table/row_group.cpp | 96 +- src/duckdb/src/storage/table/scan_state.cpp | 30 +- .../src/storage/temporary_memory_manager.cpp | 23 +- src/duckdb/ub_src_common_serializer.cpp | 4 + src/duckdb/ub_src_optimizer.cpp | 2 + src/duckdb/ub_src_planner_filter.cpp | 4 - 98 files changed, 4730 insertions(+), 801 deletions(-) create mode 100644 src/duckdb/src/common/serializer/async_file_writer.cpp create mode 100644 src/duckdb/src/common/serializer/async_write_queue.cpp create mode 100644 src/duckdb/src/include/duckdb/common/enums/dialect_compatibility_mode.hpp create mode 100644 src/duckdb/src/include/duckdb/common/serializer/async_file_writer.hpp create mode 100644 src/duckdb/src/include/duckdb/common/serializer/async_write_queue.hpp create mode 100644 src/duckdb/src/include/duckdb/optimizer/grouping_sets_optimizer.hpp delete mode 100644 src/duckdb/src/include/duckdb/planner/filter/perfect_hash_join_filter.hpp create mode 100644 src/duckdb/src/optimizer/grouping_sets_optimizer.cpp delete mode 100644 src/duckdb/src/planner/filter/perfect_hash_join_filter.cpp delete mode 100644 src/duckdb/src/planner/filter/table_filter_perfect_hash_join_function.cpp 
diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp index d25802751..779dd4e06 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp @@ -274,6 +274,7 @@ AggregateFunction GetTypedMedianAbsoluteDeviationAggregateFunction(const Logical using OP = MedianAbsoluteDeviationOperation; auto fun = QuantileBufferingAggregate(input_type, target_type); fun.SetBindCallback(BindMAD); + fun.SetStructStateExport(QuantileStateLayout); fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); #ifndef DUCKDB_SMALLER_BINARY fun.SetWindowBatchCallback(OP::template Window); diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp index 41bb35d07..dc2f012cd 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp @@ -425,6 +425,7 @@ struct ScalarDiscreteQuantile { using STATE = QuantileState; using OP = QuantileScalarOperation; auto fun = QuantileBufferingAggregate(type, type); + fun.SetStructStateExport(QuantileStateLayout); #ifndef DUCKDB_SMALLER_BINARY fun.SetWindowBatchCallback(OP::Window); fun.SetWindowInitCallback(OP::WindowInit); @@ -442,6 +443,7 @@ struct ScalarDiscreteQuantile { AggregateFunction::StateVoidFinalize, nullptr, nullptr, AggregateFunction::StateDestroy); fun.SetInitLocalStateFinalizeCallback(FlattenedQuantileValues::Init); + fun.SetStructStateExport(QuantileStateLayout>>>); return fun; } }; @@ -452,6 +454,7 @@ struct ListDiscreteQuantile { using STATE = QuantileState; using OP = QuantileListOperation; auto fun = QuantileBufferingAggregate(type, LogicalType::LIST(type)); + fun.SetStructStateExport(QuantileStateLayout); fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); 
#ifndef DUCKDB_SMALLER_BINARY fun.SetWindowBatchCallback(OP::template Window); @@ -470,6 +473,7 @@ struct ListDiscreteQuantile { AggregateFunction::StateFinalize, nullptr, nullptr, AggregateFunction::StateDestroy); fun.SetInitLocalStateFinalizeCallback(FlattenedQuantileValues::Init); + fun.SetStructStateExport(QuantileStateLayout>>>); return fun; } }; @@ -544,6 +548,7 @@ struct ScalarContinuousQuantile { using STATE = QuantileState; using OP = QuantileScalarOperation; auto fun = QuantileBufferingAggregate(input_type, target_type); + fun.SetStructStateExport(QuantileStateLayout); fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); #ifndef DUCKDB_SMALLER_BINARY fun.SetWindowBatchCallback(OP::template Window); @@ -559,6 +564,7 @@ struct ListContinuousQuantile { using STATE = QuantileState; using OP = QuantileListOperation; auto fun = QuantileBufferingAggregate(input_type, LogicalType::LIST(target_type)); + fun.SetStructStateExport(QuantileStateLayout); fun.SetOrderDependent(AggregateOrderDependent::NOT_ORDER_DEPENDENT); #ifndef DUCKDB_SMALLER_BINARY fun.SetWindowBatchCallback(OP::template Window); diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_state.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_state.hpp index 3bac85bfc..8358e2201 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_state.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_state.hpp @@ -9,8 +9,10 @@ #pragma once #include "core_functions/aggregate/quantile_sort_tree.hpp" +#include "duckdb/common/operator/negate.hpp" #include "duckdb/common/types/data_chunk.hpp" #include "duckdb/common/types/list_segment.hpp" +#include "duckdb/function/aggregate_function.hpp" #include "duckdb/function/aggregate/list_aggregate.hpp" #include "SkipList.h" @@ -326,4 +328,82 @@ struct QuantileState : ListAggState { } }; 
+//===--------------------------------------------------------------------===// +// Quantile State Export +//===--------------------------------------------------------------------===// +template +inline T QuantileNeg(const T &t) { + return NegateOperator::Operation(t); +} + +//! Restores the sign of a normalized quantile parameter (see QuantileAbs) +template <> +inline Value QuantileNeg(const Value &v) { + const auto &type = v.type(); + switch (type.id()) { + case LogicalTypeId::DECIMAL: { + const auto integral = IntegralValue::Get(v); + const auto width = DecimalType::GetWidth(type); + const auto scale = DecimalType::GetScale(type); + switch (type.InternalType()) { + case PhysicalType::INT16: + return Value::DECIMAL(QuantileNeg(Cast::Operation(integral)), width, scale); + case PhysicalType::INT32: + return Value::DECIMAL(QuantileNeg(Cast::Operation(integral)), width, scale); + case PhysicalType::INT64: + return Value::DECIMAL(QuantileNeg(Cast::Operation(integral)), width, scale); + case PhysicalType::INT128: + return Value::DECIMAL(QuantileNeg(integral), width, scale); + default: + throw InternalException("Unknown DECIMAL type"); + } + } + default: + return Value::DOUBLE(QuantileNeg(v.GetValue())); + } +} + +//! Reconstructs the quantile parameter (e.g. 0.5 or [0.25, 0.75]) from the bind data, so that it can be recorded +//! in the AGGREGATE_STATE type - param_type is the declared type of the (erased) parameter argument +inline Value QuantileParameterValue(const QuantileBindData &bind_data, const LogicalType ¶m_type) { + vector quantiles; + for (auto &q : bind_data.quantiles) { + // the bind data holds the normalized (absolute) quantiles - restore the sign of descending quantiles + quantiles.push_back(bind_data.desc ? 
QuantileNeg(q.val) : q.val); + } + if (param_type.id() != LogicalTypeId::LIST && param_type.id() != LogicalTypeId::ARRAY) { + D_ASSERT(quantiles.size() == 1); + return quantiles[0]; + } + if (quantiles.empty()) { + return Value::LIST(LogicalType::DOUBLE, std::move(quantiles)); + } + auto child_type = quantiles[0].type(); + return Value::LIST(child_type, std::move(quantiles)); +} + +template >> +AggregateStateLayout QuantileStateLayout(AggregateLayoutInput &input) { + auto &function = input.function; + AggregateStateLayout layout; + if (function.GetReturnType().IsAggregateState()) { + // the function has been modified for state export (see ExportAggregateFunction::SetStateExport) - + // its return type IS the state type already + layout.type = function.GetReturnType(); + } else { + layout.type = LogicalType::LIST(function.GetArguments()[0]); + } + layout.total_state_size = AlignValue(sizeof(STATE)); + layout.field = BuildStateField(); + AggregateStateField::PopulateListFunctions(layout.type, layout.field); + if (function.GetOriginalArguments().size() == 2) { + // the quantile parameter must be a constant at bind time (its argument is erased by BindQuantile) - + // record its value so that re-binding the exported state can supply it + // median and mad have no parameter argument (their binds create the quantile themselves) and skip this + auto &bind_data = input.bind_data->Cast(); + layout.constant_parameters.emplace(1, QuantileParameterValue(bind_data, function.GetOriginalArguments()[1])); + } + return layout; +} + } // namespace duckdb diff --git a/src/duckdb/extension/core_functions/lambda_functions.cpp b/src/duckdb/extension/core_functions/lambda_functions.cpp index 1c18aaad0..2f509f1e9 100644 --- a/src/duckdb/extension/core_functions/lambda_functions.cpp +++ b/src/duckdb/extension/core_functions/lambda_functions.cpp @@ -7,7 +7,8 @@ #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/planner/expression/bound_cast_expression.hpp" 
#include "duckdb/planner/expression/bound_lambda_expression.hpp" - +#include "duckdb/common/enums/dialect_compatibility_mode.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { //===--------------------------------------------------------------------===// @@ -38,6 +39,11 @@ struct LambdaExecuteInfo { // initialize the data chunks input_chunk.InitializeEmpty(input_types); lambda_chunk.Initialize(Allocator::DefaultAllocator(), result_types); + // Spark Compatibility Mode: zero-based index for lambdas + if (Settings::Get(context) == DialectCompatibilityMode::SPARK) { + // Spark's lambda index parameter is 0-based; default SQL is 1-based + index_offset = 0; + } }; //! The expression executor that executes the lambda expression @@ -48,6 +54,8 @@ struct LambdaExecuteInfo { DataChunk lambda_chunk; //! True, if this lambda expression expects an index vector in the input chunk bool has_index; + //! Added to child_idx to form the value the lambda sees in its index parameter (1 by default). + idx_t index_offset = 1; }; //! 
A helper struct with information that is specific to the list_filter function @@ -323,7 +331,8 @@ static void ExecuteLambda(DataChunk &args, ExpressionState &state, Vector &resul // set the index vector if (info.has_index) { - index_vector.SetValue(elem_cnt, Value::BIGINT(NumericCast(child_idx + 1))); + index_vector.SetValue(elem_cnt, + Value::BIGINT(NumericCast(child_idx + execute_info.index_offset))); } elem_cnt++; diff --git a/src/duckdb/extension/parquet/include/column_writer.hpp b/src/duckdb/extension/parquet/include/column_writer.hpp index db0942e27..00309aaea 100644 --- a/src/duckdb/extension/parquet/include/column_writer.hpp +++ b/src/duckdb/extension/parquet/include/column_writer.hpp @@ -205,6 +205,7 @@ class ColumnWriter { virtual void BeginWrite(ColumnWriterState &state) = 0; virtual void Write(ColumnWriterState &state, Vector &vector, idx_t count) = 0; + virtual void PrepareWrite(ColumnWriterState &state) = 0; virtual void FinalizeWrite(ColumnWriterState &state) = 0; public: diff --git a/src/duckdb/extension/parquet/include/parquet_writer.hpp b/src/duckdb/extension/parquet/include/parquet_writer.hpp index ea297fa6a..2f4646eb7 100644 --- a/src/duckdb/extension/parquet/include/parquet_writer.hpp +++ b/src/duckdb/extension/parquet/include/parquet_writer.hpp @@ -22,7 +22,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/mutex.hpp" #include "duckdb/common/atomic.hpp" -#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/serializer/async_file_writer.hpp" #include "duckdb/common/types/column/column_data_collection.hpp" #include "duckdb/function/copy_function.hpp" #include "parquet_statistics.hpp" @@ -216,7 +216,7 @@ class ParquetWriter { LogicalType GetSQLType(idx_t schema_idx) const { return options.sql_types[schema_idx]; } - BufferedFileWriter &GetWriter() { + AsyncFileWriter &GetWriter() { return *writer; } idx_t FileSize() const { @@ -259,6 +259,9 @@ class ParquetWriter { uint32_t Write(const 
duckdb_apache::thrift::TBase &object); uint32_t WriteData(const const_data_ptr_t buffer, const uint32_t buffer_size); + unique_ptr PrepareWrite(const duckdb_apache::thrift::TBase &object); + unique_ptr PrepareWriteData(unique_ptr buffer); + uint32_t WriteData(unique_ptr buffer); GeoParquetFileMetadata &GetGeoParquetData(); @@ -292,7 +295,7 @@ class ParquetWriter { ParquetWriterOptions options; shared_ptr encryption_util; - unique_ptr writer; + unique_ptr writer; std::shared_ptr protocol; duckdb_parquet::FileMetaData file_meta_data; std::mutex lock; diff --git a/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp index 21ea168fb..c4db2c097 100644 --- a/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/list_column_writer.hpp @@ -43,6 +43,7 @@ class ListColumnWriter : public ColumnWriter { void BeginWrite(ColumnWriterState &state) override; void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; + void PrepareWrite(ColumnWriterState &state) override; void FinalizeWrite(ColumnWriterState &state) override; idx_t FinalizeSchema(vector &schemas) override; diff --git a/src/duckdb/extension/parquet/include/writer/primitive_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/primitive_column_writer.hpp index 740830ca3..308a58e27 100644 --- a/src/duckdb/extension/parquet/include/writer/primitive_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/primitive_column_writer.hpp @@ -13,6 +13,7 @@ #include #include "column_writer.hpp" +#include "duckdb/common/serializer/async_file_writer.hpp" #include "writer/parquet_write_stats.hpp" #include "duckdb/common/serializer/memory_stream.hpp" #include "parquet_statistics.hpp" @@ -29,6 +30,7 @@ class ParquetWriter; class Vector; class WriteStream; struct ParquetColumnSchema; +struct PrimitiveDictionaryTargetData; struct PageInformation { idx_t 
offset = 0; @@ -48,6 +50,8 @@ struct PageWriteInformation { size_t compressed_size; data_ptr_t compressed_data; AllocatedData compressed_buf; + unique_ptr prepared_header; + unique_ptr prepared_payload; }; class PrimitiveColumnWriterState : public ColumnWriterState { @@ -88,6 +92,7 @@ class PrimitiveColumnWriter : public ColumnWriter { bool vector_can_span_multiple_pages) override; void BeginWrite(ColumnWriterState &state) override; void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; + void PrepareWrite(ColumnWriterState &state) override; void FinalizeWrite(ColumnWriterState &state) override; idx_t FinalizeSchema(vector &schemas) override; @@ -121,6 +126,7 @@ class PrimitiveColumnWriter : public ColumnWriter { //! The number of elements in the dictionary virtual idx_t DictionarySize(PrimitiveColumnWriterState &state_p); void WriteDictionary(PrimitiveColumnWriterState &state, unique_ptr temp_writer, idx_t row_count); + void WriteDictionary(PrimitiveColumnWriterState &state, PrimitiveDictionaryTargetData target_data, idx_t row_count); virtual void FlushDictionary(PrimitiveColumnWriterState &state, ColumnWriterStatistics *stats); void SetParquetStatistics(PrimitiveColumnWriterState &state, duckdb_parquet::ColumnChunk &column); diff --git a/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp index 3deb031a8..6ae6ac459 100644 --- a/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/struct_column_writer.hpp @@ -34,6 +34,7 @@ class StructColumnWriter : public ColumnWriter { void BeginWrite(ColumnWriterState &state) override; void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; + void PrepareWrite(ColumnWriterState &state) override; void FinalizeWrite(ColumnWriterState &state) override; idx_t FinalizeSchema(vector &schemas) override; }; diff --git 
a/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp index cd3477014..0fbf2519e 100644 --- a/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp @@ -310,7 +310,8 @@ class StandardColumnWriter : public PrimitiveColumnWriter { }); // flush the dictionary page and add it to the to-be-written pages - WriteDictionary(state, state.dictionary.GetTargetMemoryStream(), state.dictionary.GetSize()); + auto dictionary_size = state.dictionary.GetSize(); + WriteDictionary(state, state.dictionary.TakeTargetData(), dictionary_size); // bloom filter will be queued for writing in ParquetWriter::BufferBloomFilter one level up } diff --git a/src/duckdb/extension/parquet/parquet_writer.cpp b/src/duckdb/extension/parquet/parquet_writer.cpp index 7bc7e8fd0..b8a282beb 100644 --- a/src/duckdb/extension/parquet/parquet_writer.cpp +++ b/src/duckdb/extension/parquet/parquet_writer.cpp @@ -9,7 +9,8 @@ #include "parquet_shredding.hpp" #include "resizable_buffer.hpp" #include "duckdb/parser/keyword_helper.hpp" -#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/serializer/async_file_writer.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" #include "duckdb/common/serializer/write_stream.hpp" #include "duckdb/common/string_util.hpp" #include "duckdb/main/client_context.hpp" @@ -85,6 +86,28 @@ class MyTransport : public TTransport { WriteStream &serializer; }; +class ParquetPreparedWriteBuffer : public AsyncWriteBuffer { +public: + explicit ParquetPreparedWriteBuffer(unique_ptr stream_p) : stream(std::move(stream_p)) { + D_ASSERT(stream); + data = stream->GetData(); + size = stream->GetPosition(); + } + + data_ptr_t Ptr() override { + return data; + } + + idx_t Size() const override { + return size; + } + +private: + unique_ptr stream; + data_ptr_t data; + idx_t 
size; +}; + bool ParquetWriter::TryGetParquetType(const LogicalType &duckdb_type, optional_ptr parquet_type_ptr, bool write_timestamp_as_int96) { Type::type parquet_type; @@ -368,6 +391,39 @@ uint32_t ParquetWriter::WriteData(const const_data_ptr_t buffer, const uint32_t } } +unique_ptr ParquetWriter::PrepareWrite(const duckdb_apache::thrift::TBase &object) { + auto stream = make_uniq(BufferAllocator::Get(context)); + TCompactProtocolFactoryT tproto_factory; + auto stream_protocol = tproto_factory.getProtocol(duckdb_base_std::make_shared(*stream)); + if (options.encryption_config) { + ParquetCrypto::Write(object, *stream_protocol, options.encryption_config->GetFooterKey(), *encryption_util); + } else { + object.write(stream_protocol.get()); + } + return make_uniq(std::move(stream)); +} + +unique_ptr ParquetWriter::PrepareWriteData(unique_ptr buffer) { + if (!options.encryption_config) { + return buffer; + } + + auto required_capacity = + buffer->Size() + ParquetCrypto::LENGTH_BYTES + ParquetCrypto::NONCE_BYTES + ParquetCrypto::TAG_BYTES; + auto stream = make_uniq(BufferAllocator::Get(context), NextPowerOfTwo(required_capacity)); + TCompactProtocolFactoryT tproto_factory; + auto stream_protocol = tproto_factory.getProtocol(duckdb_base_std::make_shared(*stream)); + ParquetCrypto::WriteData(*stream_protocol, buffer->Ptr(), NumericCast(buffer->Size()), + options.encryption_config->GetFooterKey(), *encryption_util); + return make_uniq(std::move(stream)); +} + +uint32_t ParquetWriter::WriteData(unique_ptr buffer) { + auto buffer_size = NumericCast(buffer->Size()); + writer->WriteData(std::move(buffer)); + return buffer_size; +} + static void VerifyUniqueNames(const vector &names) { #ifdef DEBUG unordered_set name_set; @@ -437,8 +493,8 @@ ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, ParquetWrit const vector> &kv_metadata) : context(context), options(std::move(options_p)) { // initialize the file writer - writer = make_uniq(fs, 
options.file_name.c_str(), - FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW); + writer = make_uniq(context, fs, options.file_name.c_str(), + FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW); if (options.encryption_config) { // Get the encryption util @@ -610,10 +666,10 @@ void ParquetWriter::PrepareRowGroup(ColumnDataCollection &raw_buffer, PreparedRo AnalyzeSchema(raw_buffer, column_writers); bool requires_transform = false; - for (auto &writer_p : column_writers) { - auto &writer = *writer_p; + for (auto &col_writer_p : column_writers) { + auto &col_writer = *col_writer_p; - if (writer.HasTransform()) { + if (col_writer.HasTransform()) { requires_transform = true; break; } @@ -692,6 +748,10 @@ void ParquetWriter::PrepareRowGroup(ColumnDataCollection &raw_buffer, PreparedRo } } + for (idx_t i = 0; i < next; i++) { + col_writers[i].get().PrepareWrite(*write_states[i]); + } + for (auto &write_state : write_states) { states.push_back(std::move(write_state)); } @@ -735,38 +795,42 @@ static void ValidateColumnOffsets(const string &filename, idx_t file_length, con } void ParquetWriter::FlushRowGroup(PreparedRowGroup &prepared) { - lock_guard glock(lock); - auto &row_group = prepared.row_group; - auto &states = prepared.states; - if (states.empty()) { - throw InternalException("Attempting to flush a row group with no rows"); - } - InitializeSchemaFromPreparedRowGroup(prepared); - VerifyPreparedRowGroup(prepared); - row_group.file_offset = NumericCast(writer->GetTotalWritten()); - for (idx_t col_idx = 0; col_idx < states.size(); col_idx++) { - const auto &col_writer = column_writers[col_idx]; - auto write_state = std::move(states[col_idx]); - col_writer->FinalizeWrite(*write_state); - } - // let's make sure all offsets are ay-okay - ValidateColumnOffsets(options.file_name, writer->GetTotalWritten(), row_group); + auto batch_guard = writer->StartBatch(); + { + lock_guard glock(lock); + auto &row_group = prepared.row_group; + auto 
&states = prepared.states; + if (states.empty()) { + throw InternalException("Attempting to flush a row group with no rows"); + } + InitializeSchemaFromPreparedRowGroup(prepared); + VerifyPreparedRowGroup(prepared); + row_group.file_offset = NumericCast(writer->GetTotalWritten()); + for (idx_t col_idx = 0; col_idx < states.size(); col_idx++) { + const auto &col_writer = column_writers[col_idx]; + auto write_state = std::move(states[col_idx]); + col_writer->FinalizeWrite(*write_state); + } + // let's make sure all offsets are ay-okay + ValidateColumnOffsets(options.file_name, writer->GetTotalWritten(), row_group); - row_group.total_compressed_size = NumericCast(writer->GetTotalWritten()) - row_group.file_offset; - row_group.__isset.total_compressed_size = true; + row_group.total_compressed_size = NumericCast(writer->GetTotalWritten()) - row_group.file_offset; + row_group.__isset.total_compressed_size = true; - if (options.encryption_config) { - const auto row_group_ordinal = file_meta_data.row_groups.size(); - if (row_group_ordinal > std::numeric_limits::max()) { - throw InvalidInputException("RowGroup ordinal exceeds 32767 when encryption enabled"); + if (options.encryption_config) { + const auto row_group_ordinal = file_meta_data.row_groups.size(); + if (row_group_ordinal > std::numeric_limits::max()) { + throw InvalidInputException("RowGroup ordinal exceeds 32767 when encryption enabled"); + } + row_group.ordinal = NumericCast(row_group_ordinal); + row_group.__isset.ordinal = true; } - row_group.ordinal = NumericCast(row_group_ordinal); - row_group.__isset.ordinal = true; - } - // append the row group to the file metadata - file_meta_data.row_groups.push_back(row_group); - file_meta_data.num_rows += row_group.num_rows; + // append the row group to the file metadata + file_meta_data.row_groups.push_back(row_group); + file_meta_data.num_rows += row_group.num_rows; + } + batch_guard.Finish(); } void ParquetWriter::Flush(ColumnDataCollection &buffer, unique_ptr 
&transform_data) { diff --git a/src/duckdb/extension/parquet/writer/list_column_writer.cpp b/src/duckdb/extension/parquet/writer/list_column_writer.cpp index 1ff891f89..88c5dbf92 100644 --- a/src/duckdb/extension/parquet/writer/list_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/list_column_writer.cpp @@ -157,6 +157,11 @@ void ListColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t c GetChildWriter().Write(*state.child_state, child_list, child_length); } +void ListColumnWriter::PrepareWrite(ColumnWriterState &state_p) { + auto &state = state_p.Cast(); + GetChildWriter().PrepareWrite(*state.child_state); +} + void ListColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { auto &state = state_p.Cast(); GetChildWriter().FinalizeWrite(*state.child_state); diff --git a/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp b/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp index b4e52d028..3da3271f8 100644 --- a/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/primitive_column_writer.cpp @@ -14,7 +14,7 @@ #include "duckdb/common/limits.hpp" #include "duckdb/common/numeric_utils.hpp" #include "duckdb/common/optional_ptr.hpp" -#include "duckdb/common/serializer/buffered_file_writer.hpp" +#include "duckdb/common/primitive_dictionary.hpp" #include "duckdb/common/serializer/write_stream.hpp" #include "duckdb/common/types.hpp" #include "duckdb/common/types/validity_mask.hpp" @@ -30,6 +30,47 @@ using duckdb_parquet::PageType; constexpr const idx_t PrimitiveColumnWriter::MAX_UNCOMPRESSED_PAGE_SIZE; constexpr const idx_t PrimitiveColumnWriter::MAX_UNCOMPRESSED_DICT_PAGE_SIZE; +class ParquetPagePayloadBuffer : public AsyncWriteBuffer { +public: + ParquetPagePayloadBuffer(idx_t size_p, unique_ptr temp_writer_p, AllocatedData compressed_buf_p) + : size(size_p), temp_writer(std::move(temp_writer_p)), compressed_buf(std::move(compressed_buf_p)) { + D_ASSERT(temp_writer 
|| compressed_buf.IsSet()); + } + + data_ptr_t Ptr() override { + if (compressed_buf.IsSet()) { + return compressed_buf.get(); + } + D_ASSERT(temp_writer); + return temp_writer->GetData(); + } + + idx_t Size() const override { + return size; + } + +private: + idx_t size; + unique_ptr temp_writer; + AllocatedData compressed_buf; +}; + +static PageWriteInformation CreateDictionaryPageWriteInformation(idx_t uncompressed_size, idx_t row_count) { + PageWriteInformation write_info; + auto &hdr = write_info.page_header; + hdr.uncompressed_page_size = UnsafeNumericCast(uncompressed_size); + hdr.type = PageType::DICTIONARY_PAGE; + hdr.__isset.dictionary_page_header = true; + + hdr.dictionary_page_header.encoding = Encoding::PLAIN; + hdr.dictionary_page_header.is_sorted = false; + hdr.dictionary_page_header.num_values = UnsafeNumericCast(row_count); + + write_info.write_count = 0; + write_info.max_write_count = 0; + return write_info; +} + PrimitiveColumnWriter::PrimitiveColumnWriter(ParquetWriter &writer, ParquetColumnSchema &&column_schema, vector schema_path) : ColumnWriter(writer, std::move(column_schema), std::move(schema_path)) { @@ -376,22 +417,38 @@ void PrimitiveColumnWriter::SetParquetStatistics(PrimitiveColumnWriterState &sta } } -void PrimitiveColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { +void PrimitiveColumnWriter::PrepareWrite(ColumnWriterState &state_p) { auto &state = state_p.Cast(); - auto &column_chunk = state.row_group.columns[state.col_idx]; - // flush the last page (if any remains) + // Flush the last page and materialize any dictionary page before the row-group flush path. 
FlushPage(state); + if (HasDictionary(state)) { + FlushDictionary(state, state.stats_state.get()); + } + + for (auto &write_info : state.write_info) { + D_ASSERT(write_info.page_header.uncompressed_page_size > 0); + D_ASSERT(!write_info.prepared_header); + D_ASSERT(!write_info.prepared_payload); + + write_info.prepared_header = writer.PrepareWrite(write_info.page_header); + auto payload_buffer = make_uniq( + write_info.compressed_size, std::move(write_info.temp_writer), std::move(write_info.compressed_buf)); + write_info.prepared_payload = writer.PrepareWriteData(std::move(payload_buffer)); + } +} + +void PrimitiveColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { + auto &state = state_p.Cast(); + auto &column_chunk = state.row_group.columns[state.col_idx]; auto &column_writer = writer.GetWriter(); auto start_offset = column_writer.GetTotalWritten(); - // flush the dictionary if (HasDictionary(state)) { column_chunk.meta_data.statistics.distinct_count = UnsafeNumericCast(DictionarySize(state)); column_chunk.meta_data.statistics.__isset.distinct_count = true; column_chunk.meta_data.dictionary_page_offset = UnsafeNumericCast(column_writer.GetTotalWritten()); column_chunk.meta_data.__isset.dictionary_page_offset = true; - FlushDictionary(state, state.stats_state.get()); } // record the start position of the pages for this column @@ -407,12 +464,13 @@ void PrimitiveColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { column_chunk.meta_data.data_page_offset = UnsafeNumericCast(column_writer.GetTotalWritten()); } D_ASSERT(write_info.page_header.uncompressed_page_size > 0); - auto header_start_offset = column_writer.GetTotalWritten(); - writer.Write(write_info.page_header); + D_ASSERT(write_info.prepared_header); + D_ASSERT(write_info.prepared_payload); // total uncompressed size in the column chunk includes the header size (!) 
- total_uncompressed_size += column_writer.GetTotalWritten() - header_start_offset; + total_uncompressed_size += write_info.prepared_header->Size(); total_uncompressed_size += write_info.page_header.uncompressed_page_size; - writer.WriteData(write_info.compressed_data, write_info.compressed_size); + writer.WriteData(std::move(write_info.prepared_header)); + writer.WriteData(std::move(write_info.prepared_payload)); } column_chunk.meta_data.total_compressed_size = UnsafeNumericCast(column_writer.GetTotalWritten() - start_offset); @@ -439,26 +497,13 @@ void PrimitiveColumnWriter::WriteDictionary(PrimitiveColumnWriterState &state, u D_ASSERT(temp_writer); D_ASSERT(temp_writer->GetPosition() > 0); - // write the dictionary page header - PageWriteInformation write_info; - // set up the header - auto &hdr = write_info.page_header; - hdr.uncompressed_page_size = UnsafeNumericCast(temp_writer->GetPosition()); - hdr.type = PageType::DICTIONARY_PAGE; - hdr.__isset.dictionary_page_header = true; - - hdr.dictionary_page_header.encoding = Encoding::PLAIN; - hdr.dictionary_page_header.is_sorted = false; - hdr.dictionary_page_header.num_values = UnsafeNumericCast(row_count); - + auto write_info = CreateDictionaryPageWriteInformation(temp_writer->GetPosition(), row_count); write_info.temp_writer = std::move(temp_writer); - write_info.write_count = 0; - write_info.max_write_count = 0; // compress the contents of the dictionary page CompressPage(*write_info.temp_writer, write_info.compressed_size, write_info.compressed_data, write_info.compressed_buf); - hdr.compressed_page_size = UnsafeNumericCast(write_info.compressed_size); + write_info.page_header.compressed_page_size = UnsafeNumericCast(write_info.compressed_size); if (write_info.compressed_buf.IsSet()) { // if the data has been compressed, we no longer need the uncompressed data @@ -470,6 +515,28 @@ void PrimitiveColumnWriter::WriteDictionary(PrimitiveColumnWriterState &state, u state.write_info.insert(state.write_info.begin(), 
std::move(write_info)); } +void PrimitiveColumnWriter::WriteDictionary(PrimitiveColumnWriterState &state, + PrimitiveDictionaryTargetData target_data, idx_t row_count) { + D_ASSERT(target_data.data.IsSet()); + D_ASSERT(target_data.size > 0); + + auto write_info = CreateDictionaryPageWriteInformation(target_data.size, row_count); + + // compress the contents of the dictionary page + MemoryStream temp_writer(target_data.data.get(), target_data.data.GetSize()); + temp_writer.SetPosition(target_data.size); + CompressPage(temp_writer, write_info.compressed_size, write_info.compressed_data, write_info.compressed_buf); + write_info.page_header.compressed_page_size = UnsafeNumericCast(write_info.compressed_size); + + if (!write_info.compressed_buf.IsSet()) { + D_ASSERT(write_info.compressed_data == target_data.data.get()); + write_info.compressed_buf = std::move(target_data.data); + } + + // insert the dictionary page as the first page to write for this column + state.write_info.insert(state.write_info.begin(), std::move(write_info)); +} + idx_t PrimitiveColumnWriter::FinalizeSchema(vector &schemas) { idx_t schema_idx = schemas.size(); diff --git a/src/duckdb/extension/parquet/writer/struct_column_writer.cpp b/src/duckdb/extension/parquet/writer/struct_column_writer.cpp index fa7a9e3da..d16413a66 100644 --- a/src/duckdb/extension/parquet/writer/struct_column_writer.cpp +++ b/src/duckdb/extension/parquet/writer/struct_column_writer.cpp @@ -115,11 +115,18 @@ void StructColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t } } -void StructColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { +void StructColumnWriter::PrepareWrite(ColumnWriterState &state_p) { auto &state = state_p.Cast(); for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { // we add the null count of the struct to the null count of the children state.child_states[child_idx]->null_count += state_p.null_count; + 
child_writers[child_idx]->PrepareWrite(*state.child_states[child_idx]); + } +} + +void StructColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { + auto &state = state_p.Cast(); + for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { child_writers[child_idx]->FinalizeWrite(*state.child_states[child_idx]); } } diff --git a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp index e5d85fce4..4e6143ea9 100644 --- a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -6,6 +6,7 @@ #include "duckdb/common/exception.hpp" #include "duckdb/common/extra_type_info.hpp" #include "duckdb/main/database.hpp" +#include "duckdb/main/settings.hpp" #include "duckdb/parser/constraints/list.hpp" #include "duckdb/parser/expression/cast_expression.hpp" #include "duckdb/parser/parsed_data/create_table_info.hpp" @@ -307,7 +308,7 @@ void TableCatalogEntry::BindUpdateConstraints(Binder &binder, LogicalGet &get, L // we thus need all the columns to be available, hence we check if the update touches any index columns // If the returning keyword is used, we need access to the whole row in case the user requests it. // Therefore switch the update to a delete and insert. 
- update.update_is_del_and_insert = false; + update.update_is_del_and_insert = Settings::Get(context); TableStorageInfo table_storage_info = GetStorageInfo(context); for (auto index : table_storage_info.index_info) { for (auto &column : update.columns) { diff --git a/src/duckdb/src/common/enum_util.cpp b/src/duckdb/src/common/enum_util.cpp index fabd9dccc..5547f7de3 100644 --- a/src/duckdb/src/common/enum_util.cpp +++ b/src/duckdb/src/common/enum_util.cpp @@ -34,6 +34,7 @@ #include "duckdb/common/enums/debug_verification_mode.hpp" #include "duckdb/common/enums/deprecated_using_key_syntax.hpp" #include "duckdb/common/enums/destroy_buffer_upon.hpp" +#include "duckdb/common/enums/dialect_compatibility_mode.hpp" #include "duckdb/common/enums/expression_type.hpp" #include "duckdb/common/enums/file_compression_type.hpp" #include "duckdb/common/enums/file_glob_options.hpp" @@ -1773,6 +1774,24 @@ DestroyBufferUpon EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetDestroyBufferUponValues(), 3, "DestroyBufferUpon", value)); } +const StringUtil::EnumStringLiteral *GetDialectCompatibilityModeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(DialectCompatibilityMode::NONE), "NONE" }, + { static_cast(DialectCompatibilityMode::SPARK), "SPARK" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(DialectCompatibilityMode value) { + return StringUtil::EnumToString(GetDialectCompatibilityModeValues(), 2, "DialectCompatibilityMode", static_cast(value)); +} + +template<> +DialectCompatibilityMode EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetDialectCompatibilityModeValues(), 2, "DialectCompatibilityMode", value)); +} + const StringUtil::EnumStringLiteral *GetDistinctCountSourceValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(DistinctCountSource::CARDINALITY), "CARDINALITY" }, @@ -3678,19 +3697,20 @@ const 
StringUtil::EnumStringLiteral *GetOptimizerTypeValues() { { static_cast(OptimizerType::ROW_NUMBER_REWRITER), "ROW_NUMBER_REWRITER" }, { static_cast(OptimizerType::PARTITIONED_EXECUTION), "PARTITIONED_EXECUTION" }, { static_cast(OptimizerType::PARTIAL_AGGREGATE_PUSHDOWN), "PARTIAL_AGGREGATE_PUSHDOWN" }, - { static_cast(OptimizerType::REMOTE_PUSHDOWN), "REMOTE_PUSHDOWN" } + { static_cast(OptimizerType::REMOTE_PUSHDOWN), "REMOTE_PUSHDOWN" }, + { static_cast(OptimizerType::GROUPING_SETS), "GROUPING_SETS" } }; return values; } template<> const char* EnumUtil::ToChars(OptimizerType value) { - return StringUtil::EnumToString(GetOptimizerTypeValues(), 40, "OptimizerType", static_cast(value)); + return StringUtil::EnumToString(GetOptimizerTypeValues(), 41, "OptimizerType", static_cast(value)); } template<> OptimizerType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetOptimizerTypeValues(), 40, "OptimizerType", value)); + return static_cast(StringUtil::StringToEnum(GetOptimizerTypeValues(), 41, "OptimizerType", value)); } const StringUtil::EnumStringLiteral *GetOrderByColumnTypeValues() { @@ -4742,7 +4762,6 @@ const StringUtil::EnumStringLiteral *GetSelectivityOptionalFilterTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(SelectivityOptionalFilterType::MIN_MAX), "MIN_MAX" }, { static_cast(SelectivityOptionalFilterType::BF), "BF" }, - { static_cast(SelectivityOptionalFilterType::PHJ), "PHJ" }, { static_cast(SelectivityOptionalFilterType::PRF), "PRF" } }; return values; @@ -4750,12 +4769,12 @@ const StringUtil::EnumStringLiteral *GetSelectivityOptionalFilterTypeValues() { template<> const char* EnumUtil::ToChars(SelectivityOptionalFilterType value) { - return StringUtil::EnumToString(GetSelectivityOptionalFilterTypeValues(), 4, "SelectivityOptionalFilterType", static_cast(value)); + return StringUtil::EnumToString(GetSelectivityOptionalFilterTypeValues(), 3, "SelectivityOptionalFilterType", 
static_cast(value)); } template<> SelectivityOptionalFilterType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetSelectivityOptionalFilterTypeValues(), 4, "SelectivityOptionalFilterType", value)); + return static_cast(StringUtil::StringToEnum(GetSelectivityOptionalFilterTypeValues(), 3, "SelectivityOptionalFilterType", value)); } const StringUtil::EnumStringLiteral *GetSequenceInfoValues() { @@ -5478,7 +5497,6 @@ const StringUtil::EnumStringLiteral *GetTableFilterTypeValues() { { static_cast(TableFilterType::LEGACY_DYNAMIC_FILTER), "LEGACY_DYNAMIC_FILTER" }, { static_cast(TableFilterType::EXPRESSION_FILTER), "EXPRESSION_FILTER" }, { static_cast(TableFilterType::LEGACY_BLOOM_FILTER), "LEGACY_BLOOM_FILTER" }, - { static_cast(TableFilterType::LEGACY_PERFECT_HASH_JOIN_FILTER), "LEGACY_PERFECT_HASH_JOIN_FILTER" }, { static_cast(TableFilterType::LEGACY_PREFIX_RANGE_FILTER), "LEGACY_PREFIX_RANGE_FILTER" } }; return values; @@ -5486,12 +5504,12 @@ const StringUtil::EnumStringLiteral *GetTableFilterTypeValues() { template<> const char* EnumUtil::ToChars(TableFilterType value) { - return StringUtil::EnumToString(GetTableFilterTypeValues(), 13, "TableFilterType", static_cast(value)); + return StringUtil::EnumToString(GetTableFilterTypeValues(), 12, "TableFilterType", static_cast(value)); } template<> TableFilterType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTableFilterTypeValues(), 13, "TableFilterType", value)); + return static_cast(StringUtil::StringToEnum(GetTableFilterTypeValues(), 12, "TableFilterType", value)); } const StringUtil::EnumStringLiteral *GetTableFunctionParallelismValues() { diff --git a/src/duckdb/src/common/enums/optimizer_type.cpp b/src/duckdb/src/common/enums/optimizer_type.cpp index adbee2538..7954df68e 100644 --- a/src/duckdb/src/common/enums/optimizer_type.cpp +++ b/src/duckdb/src/common/enums/optimizer_type.cpp @@ -52,6 +52,7 @@ static const 
DefaultOptimizerType internal_optimizer_types[] = { {"partitioned_execution", OptimizerType::PARTITIONED_EXECUTION}, {"partial_aggregate_pushdown", OptimizerType::PARTIAL_AGGREGATE_PUSHDOWN}, {"remote_pushdown", OptimizerType::REMOTE_PUSHDOWN}, + {"grouping_sets", OptimizerType::GROUPING_SETS}, {nullptr, OptimizerType::INVALID}}; string OptimizerTypeToString(OptimizerType type) { diff --git a/src/duckdb/src/common/file_system.cpp b/src/duckdb/src/common/file_system.cpp index 94e4c3137..4ff6c5b69 100644 --- a/src/duckdb/src/common/file_system.cpp +++ b/src/duckdb/src/common/file_system.cpp @@ -729,6 +729,10 @@ bool FileSystem::IsManuallySet() { return false; } +bool FileSystem::SupportsPositionalWrites(FileHandle &handle) { + return false; +} + unique_ptr FileSystem::OpenCompressedFile(QueryContext context, unique_ptr handle, bool write) { throw NotImplementedException("%s: OpenCompressedFile is not implemented!", GetName()); } @@ -817,6 +821,10 @@ bool FileHandle::CanSeek() { return file_system.CanSeek(); } +bool FileHandle::SupportsPositionalWrites() { + return file_system.SupportsPositionalWrites(*this); +} + FileCompressionType FileHandle::GetFileCompressionType() { return FileCompressionType::UNCOMPRESSED; } diff --git a/src/duckdb/src/common/local_file_system.cpp b/src/duckdb/src/common/local_file_system.cpp index c5452a069..3928f6f2c 100644 --- a/src/duckdb/src/common/local_file_system.cpp +++ b/src/duckdb/src/common/local_file_system.cpp @@ -1845,6 +1845,10 @@ bool LocalFileSystem::CanSeek() { return true; } +bool LocalFileSystem::SupportsPositionalWrites(FileHandle &handle) { + return true; +} + bool LocalFileSystem::OnDiskFile(FileHandle &handle) { return true; } diff --git a/src/duckdb/src/common/serializer/async_file_writer.cpp b/src/duckdb/src/common/serializer/async_file_writer.cpp new file mode 100644 index 000000000..23fddacaf --- /dev/null +++ b/src/duckdb/src/common/serializer/async_file_writer.cpp @@ -0,0 +1,366 @@ +#include 
"duckdb/common/serializer/async_file_writer.hpp" + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/storage/buffer_manager.hpp" + +#include + +namespace duckdb { + +class CopiedAsyncWriteBuffer : public AsyncWriteBuffer { +public: + CopiedAsyncWriteBuffer(ClientContext &context, idx_t capacity_p) + : data(BufferAllocator::Get(context).Allocate(capacity_p)), capacity(capacity_p) { + } + + data_ptr_t Ptr() override { + return data.get(); + } + + idx_t Size() const override { + return size; + } + + idx_t Remaining() const { + return capacity - size; + } + + void Append(const_data_ptr_t buffer, idx_t append_size) { + D_ASSERT(append_size <= Remaining()); + memcpy(data.get() + size, buffer, append_size); + size += append_size; + } + +private: + AllocatedData data; + idx_t capacity; + idx_t size = 0; +}; + +static ClientContext &RequireClientContext(QueryContext context) { + auto client_context = context.GetClientContext(); + if (!client_context) { + throw InvalidInputException("AsyncFileWriter requires a ClientContext"); + } + return *client_context; +} + +AsyncFileWriter::AsyncFileWriter(QueryContext context_p, FileSystem &fs_p, const string &path_p, + FileOpenFlags open_flags) + : context(context_p), client_context(RequireClientContext(context_p)), fs(fs_p), path(path_p) { + handle = fs.OpenFile(path, open_flags | FileLockType::WRITE_LOCK); + + ManagedAsyncWriteStreamTarget &target = *this; + write_queue = make_uniq(client_context, target); +} + +AsyncFileWriter::~AsyncFileWriter() { + if (!closed && handle) { + try { + Close(); + } catch (...) 
{ + } + } +} + +AsyncFileWriter::BatchGuard::BatchGuard(AsyncFileWriter &writer_p) : writer(writer_p) { + writer->BeginBatch(); +} + +AsyncFileWriter::BatchGuard::BatchGuard(BatchGuard &&other) noexcept : writer(other.writer) { + other.writer = nullptr; +} + +AsyncFileWriter::BatchGuard::~BatchGuard() { + // We would call Finish() here, but that can throw, instead we assert it has been called. + D_ASSERT(Exception::UncaughtException() || !writer); + if (writer) { + writer->LeaveBatch(); + } +} + +void AsyncFileWriter::BatchGuard::Finish() { + if (!writer) { + return; + } + auto &writer_ref = *writer; + writer = nullptr; + auto apply_backpressure = !writer_ref.closed; + writer_ref.LeaveBatch(); + if (apply_backpressure) { + writer_ref.ApplyBackpressure(); + } +} + +idx_t AsyncFileWriter::GetFileSize() { + return GetTotalWritten(); +} + +idx_t AsyncFileWriter::GetTotalWritten() const { + return total_written; +} + +void AsyncFileWriter::WriteData(const_data_ptr_t buffer, idx_t write_size) { + if (write_size == 0) { + return; + } + RethrowTaskError(); + if (closed) { + throw IOException("Cannot write to closed file \"%s\"", path); + } + + // Caller-owned memory cannot outlive async scheduling, so even large const inputs are copied before registration. 
+ if (write_size >= AsyncWriteConfig::COPIED_BUFFER_CAPACITY) { + SealCopiedBuffer(ScheduleMode::DEFER); + auto owned_buffer = make_uniq(client_context, write_size); + owned_buffer->Append(buffer, write_size); + RegisterWrite(std::move(owned_buffer)); + return; + } + + idx_t offset = 0; + while (offset < write_size) { + unique_ptr sealed_buffer; + idx_t sealed_buffer_offset = 0; + if (!copied_buffer) { + copied_buffer_offset = total_written; + copied_buffer = make_uniq(client_context, AsyncWriteConfig::COPIED_BUFFER_CAPACITY); + } + auto append_size = MinValue(write_size - offset, copied_buffer->Remaining()); + copied_buffer->Append(buffer + offset, append_size); + total_written += append_size; + offset += append_size; + if (copied_buffer->Remaining() == 0) { + sealed_buffer_offset = copied_buffer_offset; + sealed_buffer = std::move(copied_buffer); + } + if (sealed_buffer) { + RegisterStagedWrite(std::move(sealed_buffer), sealed_buffer_offset); + } + } +} + +void AsyncFileWriter::WriteData(unique_ptr buffer) { + if (!buffer || buffer->Size() == 0) { + return; + } + RethrowTaskError(); + if (closed) { + throw IOException("Cannot write to closed file \"%s\"", path); + } + if (!write_queue->IsAsync()) { + // Keep the no-async path buffered like BufferedFileWriter instead of turning every owned buffer into a syscall. 
+ WriteDataSynchronously(buffer->Ptr(), buffer->Size()); + return; + } + SealCopiedBuffer(ScheduleMode::DEFER); + RegisterWrite(std::move(buffer)); +} + +void AsyncFileWriter::RegisterWrite(unique_ptr buffer, ScheduleMode schedule_mode) { + RethrowTaskError(); + if (closed) { + throw IOException("Cannot write to closed file \"%s\"", path); + } + + auto write_size = buffer->Size(); + auto offset = total_written; + total_written += write_size; + RegisterWriteInternal(std::move(buffer), offset, schedule_mode); +} + +void AsyncFileWriter::RegisterStagedWrite(unique_ptr buffer, idx_t offset, + ScheduleMode schedule_mode) { + RethrowTaskError(); + if (closed) { + throw IOException("Cannot write to closed file \"%s\"", path); + } + RegisterWriteInternal(std::move(buffer), offset, schedule_mode); +} + +void AsyncFileWriter::RegisterWriteInternal(unique_ptr buffer, idx_t offset, + ScheduleMode schedule_mode) { + write_queue->RegisterWrite(std::move(buffer), offset, schedule_mode); +} + +void AsyncFileWriter::WriteDataSynchronously(data_ptr_t buffer, idx_t write_size) { + auto copied_size = copied_buffer ? 
copied_buffer->Size() : 0; + if (write_size >= 2 * AsyncWriteConfig::COPIED_BUFFER_CAPACITY - copied_size) { + idx_t copied_prefix = 0; + if (copied_size > 0) { + copied_prefix = copied_buffer->Remaining(); + D_ASSERT(copied_prefix <= write_size); + copied_buffer->Append(buffer, copied_prefix); + total_written += copied_prefix; + SealCopiedBuffer(ScheduleMode::DEFER); + } + auto remaining_size = write_size - copied_prefix; + if (remaining_size > 0) { + auto remaining_offset = total_written; + total_written += remaining_size; + if (SupportsPositionalWrites()) { + Write(buffer + copied_prefix, remaining_size, remaining_offset); + } else { + Write(buffer + copied_prefix, remaining_size); + } + write_queue->ResetNextOffset(total_written); + } + return; + } + + idx_t input_offset = 0; + while (input_offset < write_size) { + unique_ptr sealed_buffer; + idx_t sealed_buffer_offset = 0; + if (!copied_buffer) { + copied_buffer_offset = total_written; + copied_buffer = make_uniq(client_context, AsyncWriteConfig::COPIED_BUFFER_CAPACITY); + } + auto append_size = MinValue(write_size - input_offset, copied_buffer->Remaining()); + copied_buffer->Append(buffer + input_offset, append_size); + total_written += append_size; + input_offset += append_size; + if (copied_buffer->Remaining() == 0) { + sealed_buffer_offset = copied_buffer_offset; + sealed_buffer = std::move(copied_buffer); + } + if (sealed_buffer) { + RegisterStagedWrite(std::move(sealed_buffer), sealed_buffer_offset); + } + } +} + +void AsyncFileWriter::SealCopiedBuffer(ScheduleMode schedule_mode) { + if (!copied_buffer || copied_buffer->Size() == 0) { + return; + } + auto sealed_buffer_offset = copied_buffer_offset; + auto sealed_buffer = std::move(copied_buffer); + RegisterStagedWrite(std::move(sealed_buffer), sealed_buffer_offset, schedule_mode); +} + +AsyncFileWriter::BatchGuard AsyncFileWriter::StartBatch() { + return BatchGuard(*this); +} + +void AsyncFileWriter::SchedulePendingWrites(SchedulePolicy policy) { + if 
(!write_queue->IsAsync()) { + return; + } + SealCopiedBuffer(ScheduleMode::DEFER); + write_queue->SchedulePendingWrites(policy); +} + +void AsyncFileWriter::BeginBatch() { + write_queue->BeginBatch(); +} + +void AsyncFileWriter::LeaveBatch() noexcept { + write_queue->LeaveBatch(); +} + +bool AsyncFileWriter::SupportsPositionalWrites() { + return handle->SupportsPositionalWrites(); +} + +bool AsyncFileWriter::IsLocalFile() { + auto local_file = fs.IsLocalFileSystem(); + if (!local_file && handle) { + try { + local_file = handle->OnDiskFile(); + } catch (...) { + local_file = false; + } + } + return local_file; +} + +void AsyncFileWriter::Write(data_ptr_t buffer, idx_t size, idx_t offset) { + if (size == 0) { + return; + } + handle->Write(context, buffer, size, offset); +} + +void AsyncFileWriter::Write(data_ptr_t buffer, idx_t size) { + if (size == 0) { + return; + } + handle->Write(context, buffer, size); +} + +void AsyncFileWriter::RethrowTaskError() { + if (write_queue) { + write_queue->RethrowTaskError(); + } +} + +void AsyncFileWriter::Flush() { + WaitAll(); +} + +void AsyncFileWriter::ApplyBackpressure() { + if (!write_queue->IsAsync()) { + return; + } + RethrowTaskError(); + if (write_queue->HasOpenBatch()) { + return; + } + SealCopiedBuffer(ScheduleMode::DEFER); + write_queue->ApplyBackpressure(); +} + +void AsyncFileWriter::WaitAll() { + WaitAllInternal(BatchDrainMode::PRESERVE_BATCH); +} + +void AsyncFileWriter::WaitAllInternal(BatchDrainMode batch_drain_mode) { + if (!write_queue->IsAsync()) { + SealCopiedBuffer(ScheduleMode::DEFER); + RethrowTaskError(); + return; + } + + if (!write_queue->HasError()) { + SealCopiedBuffer(ScheduleMode::DEFER); + } + write_queue->WaitAll(batch_drain_mode); +} + +void AsyncFileWriter::Close() { + if (closed) { + return; + } + try { + if (!write_queue->HasError()) { + SealCopiedBuffer(ScheduleMode::DEFER); + } + write_queue->Close(); + handle->Close(); + handle.reset(); + closed = true; + } catch (...) 
{ + write_queue->ReleaseMemoryReservation(); + throw; + } +} + +void AsyncFileWriter::Sync() { + WaitAll(); + handle->Sync(); +} + +void AsyncFileWriter::Truncate(idx_t size) { + WaitAll(); + handle->Truncate(NumericCast(size)); + total_written = size; + write_queue->ResetNextOffset(total_written); + if (handle->CanSeek() && handle->SeekPosition() > size) { + handle->Seek(size); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/async_write_queue.cpp b/src/duckdb/src/common/serializer/async_write_queue.cpp new file mode 100644 index 000000000..30713e67a --- /dev/null +++ b/src/duckdb/src/common/serializer/async_write_queue.cpp @@ -0,0 +1,1485 @@ +#include "duckdb/common/serializer/async_write_queue.hpp" + +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/exception/http_exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/parallel/task_executor.hpp" +#include "duckdb/parallel/task_scheduler.hpp" +#include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/storage/temporary_memory_manager.hpp" + +#include +#include + +namespace duckdb { + +static ErrorData ErrorDataFromExceptionPtr(std::exception_ptr error_ptr) { + try { + std::rethrow_exception(std::move(error_ptr)); + } catch (const std::exception &ex) { + return ErrorData(ex); + } catch (...) { // LCOV_EXCL_START + return ErrorData("Unknown exception during async write"); + } // LCOV_EXCL_STOP +} + +AsyncWriteRequest::AsyncWriteRequest(unique_ptr payload_p, idx_t offset_p, + AsyncWriteCompletionCallback completion_p) + : payload(std::move(payload_p)), offset(offset_p), completion(std::move(completion_p)) { +} + +idx_t AsyncWriteRequest::Size() const { + return payload ? 
payload->Size() : 0; +} + +AsyncWriteQueue::PendingRequest::PendingRequest(AsyncWriteRequest request_p) + : request(std::move(request_p)), size(request.Size()) { +} + +idx_t AsyncWriteQueue::PendingRequest::Size() const { + return size; +} + +class AsyncWriteQueueTaskGuard { +public: + explicit AsyncWriteQueueTaskGuard(AsyncWriteQueue &queue_p) : queue(queue_p) { + } + + ~AsyncWriteQueueTaskGuard() { + Finish(); + } + + void SetRequestSize(idx_t request_size_p) { + D_ASSERT(!finished); + request_size = request_size_p; + } + + void Finish() { + if (!finished) { + queue.FinishTask(request_size); + finished = true; + } + } + +private: + AsyncWriteQueue &queue; + idx_t request_size = 0; + bool finished = false; +}; + +class AsyncWriteQueueTask : public BaseExecutorTask { +public: + AsyncWriteQueueTask(AsyncWriteQueue &queue_p, TaskExecutor &executor) : BaseExecutorTask(executor), queue(queue_p) { + } + + ~AsyncWriteQueueTask() override { + if (!started) { + queue.CancelScheduledTask(); + } + } + + void ExecuteTask() override { + started = true; + queue.DrainRequests(); + } + +private: + AsyncWriteQueue &queue; + bool started = false; +}; + +AsyncWriteQueue::AsyncWriteQueue(ClientContext &client_context_p, AsyncWriteTarget &target_p) + : client_context(client_context_p), target(target_p) { + auto &scheduler = TaskScheduler::GetScheduler(client_context); + auto async_threads = NumericCast(scheduler.NumberOfAsyncThreads()); + max_active_tasks = MaxValue(async_threads, 1); + if (async_threads == 0) { + return; + } + executor = make_uniq(client_context, TaskSchedulerType::ASYNC); +} + +AsyncWriteQueue::~AsyncWriteQueue() { + lock_guard guard(lock); + auto drained = pending_requests.empty() && pending_bytes == 0 && in_flight_bytes == 0 && active_tasks == 0 && + pending_tasks == 0 && scheduled_pending_bytes == 0 && pending_task_bytes.empty(); + D_ASSERT(closed || drained); + D_ASSERT(!closed || drained); +} + +bool AsyncWriteQueue::IsAsync() const { + return executor != 
nullptr; +} + +bool AsyncWriteQueue::HasError() { + return executor && executor->HasError(); +} + +void AsyncWriteQueue::Submit(AsyncWriteRequest request) { + if (!request.payload || request.Size() == 0) { + return; + } + auto request_size = request.Size(); + if (executor && executor->HasError()) { + ErrorData error; + try { + executor->ThrowError(); + } catch (const std::exception &ex) { + error = ErrorData(ex); + } catch (...) { // LCOV_EXCL_START + error = ErrorData("Unknown exception during async write"); + } // LCOV_EXCL_STOP + request.payload.reset(); + CompleteRequest(request, request_size, error); + error.Throw(); + } + if (!executor) { + VerifyOpen(); + WriteRequest(std::move(request)); + return; + } + + { + lock_guard guard(lock); + VerifyOpen(); + pending_requests.emplace_back(std::move(request)); + pending_bytes += request_size; + } + ScheduleTasksInternal(); +} + +idx_t AsyncWriteQueue::PendingBytes() { + lock_guard guard(lock); + return pending_bytes + in_flight_bytes; +} + +idx_t AsyncWriteQueue::TaskByteBudget() const { + return task_byte_budget; +} + +idx_t AsyncWriteQueue::SelectPendingRequestBytes(idx_t skip_bytes) const { + D_ASSERT(skip_bytes <= pending_bytes); + idx_t skipped_bytes = 0; + idx_t selected_bytes = 0; + auto byte_budget = TaskByteBudget(); + D_ASSERT(byte_budget > 0); + + for (auto &request : pending_requests) { + auto request_size = request.Size(); + if (skipped_bytes < skip_bytes) { + skipped_bytes += request_size; + continue; + } + + if (selected_bytes > 0 && selected_bytes + request_size > byte_budget) { + break; + } + selected_bytes += request_size; + if (selected_bytes >= byte_budget) { + break; + } + } + + D_ASSERT(skipped_bytes == skip_bytes); + D_ASSERT(selected_bytes > 0); + return selected_bytes; +} + +void AsyncWriteQueue::ScheduleTasksInternal(bool force) { + if (!executor) { + return; + } + idx_t schedule_count = 0; + deque task_bytes; + { + lock_guard guard(lock); + VerifyOpen(); + idx_t scheduled_bytes = 0; + while 
(scheduled_pending_bytes + scheduled_bytes < pending_bytes && + active_tasks + schedule_count < max_active_tasks) { + auto selected_bytes = SelectPendingRequestBytes(scheduled_pending_bytes + scheduled_bytes); + auto task_budget = TaskByteBudget(); + auto has_active_task = active_tasks + schedule_count > 0; + if (!force && has_active_task && selected_bytes < task_budget) { + break; + } + schedule_count++; + scheduled_bytes += selected_bytes; + task_bytes.push_back(selected_bytes); + } + active_tasks += schedule_count; + pending_tasks += schedule_count; + scheduled_pending_bytes += scheduled_bytes; + for (auto bytes : task_bytes) { + pending_task_bytes.push_back(bytes); + } + } + for (idx_t task_idx = 0; task_idx < schedule_count; task_idx++) { + unique_ptr task; + try { + task = make_uniq(*this, *executor); + } catch (...) { + CancelScheduledTasks(schedule_count - task_idx); + throw; + } + try { + executor->ScheduleTask(std::move(task)); + } catch (...) { + // The task destructor releases this task's slot. Release the slots for tasks not yet created. 
+ CancelScheduledTasks(schedule_count - task_idx - 1); + throw; + } + } +} + +idx_t AsyncWriteQueue::TakeRequests(deque &requests) { + lock_guard guard(lock); + D_ASSERT(active_tasks > 0); + D_ASSERT(pending_tasks > 0); + D_ASSERT(!pending_task_bytes.empty()); + pending_tasks--; + auto selected_bytes = pending_task_bytes.front(); + pending_task_bytes.pop_front(); + D_ASSERT(scheduled_pending_bytes >= selected_bytes); + scheduled_pending_bytes -= selected_bytes; + if (pending_requests.empty()) { + return 0; + } + + idx_t request_bytes = 0; + while (!pending_requests.empty() && request_bytes < selected_bytes) { + request_bytes += pending_requests.front().Size(); + requests.push_back(std::move(pending_requests.front())); + pending_requests.pop_front(); + } + D_ASSERT(request_bytes == selected_bytes); + D_ASSERT(pending_bytes >= request_bytes); + pending_bytes -= request_bytes; + in_flight_bytes += request_bytes; + return request_bytes; +} + +void AsyncWriteQueue::FinishTask(idx_t task_size) { + lock_guard guard(lock); + D_ASSERT(active_tasks > 0); + active_tasks--; + D_ASSERT(in_flight_bytes >= task_size); + in_flight_bytes -= task_size; +} + +void AsyncWriteQueue::CancelScheduledTask() { + CancelScheduledTasks(1); +} + +void AsyncWriteQueue::CancelScheduledTasks(idx_t task_count) { + if (task_count == 0) { + return; + } + lock_guard guard(lock); + D_ASSERT(active_tasks >= task_count); + D_ASSERT(pending_tasks >= task_count); + active_tasks -= task_count; + pending_tasks -= task_count; + for (idx_t task_idx = 0; task_idx < task_count; task_idx++) { + D_ASSERT(!pending_task_bytes.empty()); + auto task_bytes = pending_task_bytes.back(); + pending_task_bytes.pop_back(); + D_ASSERT(scheduled_pending_bytes >= task_bytes); + scheduled_pending_bytes -= task_bytes; + } +} + +void AsyncWriteQueue::DrainRequests() { + deque requests; + AsyncWriteQueueTaskGuard guard(*this); + auto task_size = TakeRequests(requests); + if (requests.empty()) { + guard.Finish(); + return; + } + 
guard.SetRequestSize(task_size); + idx_t request_idx = 0; + try { + for (; request_idx < requests.size(); request_idx++) { + WriteRequest(std::move(requests[request_idx].request)); + } + } catch (...) { + auto error_ptr = std::current_exception(); + ErrorData error; + try { + std::rethrow_exception(error_ptr); + } catch (const std::exception &ex) { + error = ErrorData(ex); + } catch (...) { // LCOV_EXCL_START + error = ErrorData("Unknown exception during async write"); + } // LCOV_EXCL_STOP + request_idx++; + for (; request_idx < requests.size(); request_idx++) { + auto &request = requests[request_idx].request; + auto request_size = requests[request_idx].Size(); + request.payload.reset(); + CompleteRequest(request, request_size, error); + } + std::rethrow_exception(error_ptr); + } + guard.Finish(); + ScheduleTasksInternal(); +} + +void AsyncWriteQueue::CompleteRequest(AsyncWriteRequest &request, idx_t size, optional_ptr error) { + if (request.completion) { + request.completion(request.offset, size, error); + } +} + +void AsyncWriteQueue::WriteRequest(AsyncWriteRequest request) { + auto request_size = request.Size(); + ErrorData write_error; + bool has_error = false; + try { + WriteBuffer(request.payload->Ptr(), request_size, request.offset); + } catch (const std::exception &ex) { + write_error = ErrorData(ex); + has_error = true; + } catch (...) 
{ // LCOV_EXCL_START + write_error = ErrorData("Unknown exception during async write"); + has_error = true; + } // LCOV_EXCL_STOP + + request.payload.reset(); + if (has_error) { + CompleteRequest(request, request_size, write_error); + write_error.Throw(); + } + CompleteRequest(request, request_size, nullptr); +} + +void AsyncWriteQueue::WriteBuffer(data_ptr_t buffer, idx_t size, idx_t offset) { + if (size == 0) { + return; + } + try { + target.Write(buffer, size, offset); + } catch (const IOException &ex) { + throw IOException("Async write failed for range [offset=%llu, size=%llu]: %s", offset, size, ex.what()); + } catch (const HTTPException &ex) { + throw HTTPException(Exception::ConstructMessage("Async write failed for range [offset=%llu, size=%llu]: %s", + offset, size, ex.what())); + } +} + +void AsyncWriteQueue::RethrowTaskError() { + if (executor && executor->HasError()) { + executor->ThrowError(); + } +} + +void AsyncWriteQueue::WorkOnPendingTask() { + if (!executor) { + VerifyOpen(); + return; + } + shared_ptr task; + if (!executor->GetTask(task)) { + TaskScheduler::YieldThread(); + return; + } + auto result = task->Execute(TaskExecutionMode::PROCESS_ALL); + D_ASSERT(result != TaskExecutionResult::TASK_BLOCKED); + task.reset(); +} + +void AsyncWriteQueue::Flush() { + bool already_closed; + { + lock_guard guard(lock); + already_closed = closed; + } + if (already_closed) { + RethrowTaskError(); + return; + } + if (!executor) { + RethrowTaskError(); + return; + } + + try { + ScheduleTasksInternal(true); + executor->WorkOnTasks(); + } catch (...) { + try { + executor->WorkOnTasks(); + } catch (...) 
{ + } + throw; + } + RethrowTaskError(); +} + +void AsyncWriteQueue::VerifyDrained() const { + if (!pending_requests.empty() || pending_bytes != 0 || in_flight_bytes != 0 || active_tasks != 0 || + pending_tasks != 0 || scheduled_pending_bytes != 0 || !pending_task_bytes.empty()) { + throw InternalException("AsyncWriteQueue still owns submitted writes"); + } +} + +void AsyncWriteQueue::CancelPendingRequestsAfterFailure(const ErrorData &error) noexcept { + deque requests; + { + lock_guard guard(lock); + D_ASSERT(active_tasks == 0); + D_ASSERT(pending_tasks == 0); + D_ASSERT(in_flight_bytes == 0); + D_ASSERT(scheduled_pending_bytes == 0); + D_ASSERT(pending_task_bytes.empty()); + if (active_tasks != 0 || pending_tasks != 0 || in_flight_bytes != 0 || scheduled_pending_bytes != 0 || + !pending_task_bytes.empty()) { + return; + } + + requests = std::move(pending_requests); + pending_bytes = 0; + closed = true; + } + + for (auto &pending : requests) { + auto request_size = pending.Size(); + auto &request = pending.request; + request.payload.reset(); + try { + CompleteRequest(request, request_size, error); + } catch (...) { + } + } +} + +void AsyncWriteQueue::Close() { + bool already_closed; + { + lock_guard guard(lock); + already_closed = closed; + } + if (already_closed) { + RethrowTaskError(); + return; + } + + try { + Flush(); + } catch (...) 
{ + auto error = std::current_exception(); + auto error_data = ErrorDataFromExceptionPtr(error); + CancelPendingRequestsAfterFailure(error_data); + std::rethrow_exception(error); + } + + lock_guard guard(lock); + VerifyDrained(); + closed = true; +} + +void AsyncWriteQueue::VerifyOpen() const { + if (closed) { + throw InternalException("Cannot use closed AsyncWriteQueue"); + } +} + +ManagedAsyncWriteQueue::PendingWrite::PendingWrite(AsyncWriteRequest request_p) + : request(std::move(request_p)), size(request.Size()) { +} + +idx_t ManagedAsyncWriteQueue::PendingWrite::Size() const { + return size; +} + +ManagedAsyncWriteQueue::ManagedAsyncWriteQueue(ClientContext &client_context_p, AsyncWriteTarget &target_p) + : client_context(client_context_p), target(target_p) { + auto &scheduler = TaskScheduler::GetScheduler(client_context); + auto regular_threads = MaxValue(NumericCast(scheduler.NumberOfThreads()), 1); + auto async_threads = NumericCast(scheduler.NumberOfAsyncThreads()); + max_active_drain_tasks = MaxValue(async_threads, 1); + max_pending_bytes = AsyncWriteConfig::MAX_PENDING_BYTES_PER_THREAD * regular_threads; + min_pending_bytes = MinValue(max_pending_bytes, AsyncWriteConfig::MIN_PENDING_BYTES_PER_THREAD * regular_threads); + + AsyncWriteTarget &async_target = *this; + write_queue = make_uniq(client_context, async_target); + if (write_queue->IsAsync() && max_pending_bytes > 0) { + memory_state = TemporaryMemoryManager::Get(client_context).Register(client_context); + memory_state->SetMinimumReservation(min_pending_bytes); + memory_state->SetZero(); + } +} + +ManagedAsyncWriteQueue::~ManagedAsyncWriteQueue() { + lock_guard guard(lock); + auto drained = pending_writes.empty() && pending_bytes == 0 && external_pending_bytes == 0 && + submitted_bytes == 0 && submitted_requests == 0; + D_ASSERT(closed || drained); + D_ASSERT(!closed || drained); +} + +bool ManagedAsyncWriteQueue::IsAsync() const { + return write_queue->IsAsync(); +} + +bool 
ManagedAsyncWriteQueue::HasError() { + return write_queue->HasError(); +} + +void ManagedAsyncWriteQueue::RegisterWrite(unique_ptr payload, idx_t offset, + ScheduleMode schedule_mode) { + if (!payload || payload->Size() == 0) { + return; + } + RegisterWrite(AsyncWriteRequest(std::move(payload), offset), schedule_mode); +} + +void ManagedAsyncWriteQueue::RegisterWrite(AsyncWriteRequest request, ScheduleMode schedule_mode) { + RegisterWriteInternal(std::move(request), 0, schedule_mode); +} + +void ManagedAsyncWriteQueue::RegisterAccountedWrite(AsyncWriteRequest request, ScheduleMode schedule_mode) { + auto request_size = request.Size(); + RegisterWriteInternal(std::move(request), request_size, ScheduleMode::DEFER); + if (schedule_mode == ScheduleMode::ALLOW) { + SchedulePendingWrites(SchedulePolicy::FORCE); + } +} + +void ManagedAsyncWriteQueue::AddExternalPendingBytes(idx_t bytes, bool update_memory) { + if (bytes == 0 || !write_queue->IsAsync()) { + return; + } + { + lock_guard guard(lock); + VerifyOpen(); + external_pending_bytes += bytes; + } + if (update_memory) { + UpdateMemoryState(); + } +} + +void ManagedAsyncWriteQueue::DiscardExternalPendingBytes(idx_t bytes) noexcept { + if (bytes == 0 || !write_queue->IsAsync()) { + return; + } + lock_guard guard(lock); + D_ASSERT(external_pending_bytes >= bytes); + if (external_pending_bytes >= bytes) { + external_pending_bytes -= bytes; + } +} + +void ManagedAsyncWriteQueue::RegisterWriteInternal(AsyncWriteRequest request, idx_t accounted_external_bytes, + ScheduleMode schedule_mode) { + if (!request.payload || request.Size() == 0) { + return; + } + RethrowTaskError(); + + auto request_size = request.Size(); + if (!write_queue->IsAsync()) { + VerifyOpen(); + write_queue->Submit(std::move(request)); + return; + } + + AddCompletionAccounting(request); + { + lock_guard guard(lock); + VerifyOpen(); + if (accounted_external_bytes > 0) { + D_ASSERT(external_pending_bytes >= accounted_external_bytes); + external_pending_bytes 
-= accounted_external_bytes; + } + pending_writes.emplace_back(std::move(request)); + pending_bytes += request_size; + } + UpdateMemoryState(); + if (schedule_mode == ScheduleMode::ALLOW) { + SchedulePendingWrites(); + } +} + +void ManagedAsyncWriteQueue::SchedulePendingWrites(SchedulePolicy policy) { + if (!write_queue->IsAsync()) { + VerifyOpen(); + return; + } + SchedulePendingWritesInternal(policy); +} + +void ManagedAsyncWriteQueue::SchedulePendingWritesInternal(SchedulePolicy policy) { + if (!write_queue->IsAsync()) { + return; + } + + while (true) { + AsyncWriteRequest request; + if (!TakePendingWriteRequest(request, policy)) { + return; + } + write_queue->Submit(std::move(request)); + } +} + +void ManagedAsyncWriteQueue::UpdateMemoryState(MemoryUpdateMode mode) { + (void)mode; + if (!memory_state) { + return; + } + + idx_t current_pending_bytes; + { + lock_guard guard(lock); + current_pending_bytes = TotalPendingBytes(); + } + if (current_pending_bytes == 0) { + return; + } + + auto current_reservation = memory_state->GetReservation(); + while (current_pending_bytes > MinValue(current_reservation, max_pending_bytes)) { + idx_t next_request; + if (memory_request_bytes > current_reservation) { + // TMM did not fully grant the previous request. Keep retrying it on later growth checks. + next_request = memory_request_bytes; + } else if (memory_request_bytes == 0) { + // Grow coarsely and only release on Close(). + // Repeatedly shrinking here would touch shared TMM state on the write-registration hot path. 
+ next_request = min_pending_bytes; + } else if (memory_request_bytes >= max_pending_bytes) { + return; + } else if (memory_request_bytes > max_pending_bytes / 2) { + next_request = max_pending_bytes; + } else { + next_request = memory_request_bytes * 2; + } + next_request = MinValue(MaxValue(next_request, min_pending_bytes), max_pending_bytes); + if (next_request <= memory_request_bytes) { + return; + } + + auto previous_reservation = current_reservation; + memory_state->SetRemainingSizeAndUpdateReservation(client_context, next_request); + memory_request_bytes = next_request; + current_reservation = memory_state->GetReservation(); + if (current_reservation <= previous_reservation) { + return; + } + if (current_reservation < next_request) { + return; + } + } +} + +idx_t ManagedAsyncWriteQueue::BackpressureBudget() { + if (!memory_state) { + return NumericLimits::Maximum(); + } + auto reservation = MinValue(memory_state->GetReservation(), max_pending_bytes); + // If TMM only grants a tiny reservation, do not retain an async backlog. This makes low-memory execution + // behave close to synchronous writes, but automatically allows overlap again if the reservation grows later. 
+ if (reservation < AsyncWriteConfig::REMOTE_COALESCE_THRESHOLD) { + return 0; + } + return reservation; +} + +idx_t ManagedAsyncWriteQueue::DrainTaskByteBudget() const { + return drain_task_byte_budget; +} + +idx_t ManagedAsyncWriteQueue::TotalPendingBytes() const { + return pending_bytes + external_pending_bytes + submitted_bytes; +} + +idx_t ManagedAsyncWriteQueue::SubmittedByteWindow() const { + auto task_budget = DrainTaskByteBudget(); + if (task_budget == 0) { + return NumericLimits::Maximum(); + } + auto max_tasks = MaxValue(max_active_drain_tasks, 1); + if (max_tasks > NumericLimits::Maximum() / task_budget) { + return NumericLimits::Maximum(); + } + return max_tasks * task_budget; +} + +bool ManagedAsyncWriteQueue::TakePendingWriteRequest(AsyncWriteRequest &request, SchedulePolicy policy) { + lock_guard guard(lock); + if (pending_writes.empty()) { + return false; + } + if (policy == SchedulePolicy::THRESHOLD) { + if (submitted_bytes >= SubmittedByteWindow()) { + return false; + } + if (submitted_requests > 0 && pending_bytes < DrainTaskByteBudget()) { + return false; + } + } + + auto request_size = pending_writes.front().Size(); + request = std::move(pending_writes.front().request); + pending_writes.pop_front(); + D_ASSERT(pending_bytes >= request_size); + pending_bytes -= request_size; + submitted_bytes += request_size; + submitted_requests++; + return true; +} + +void ManagedAsyncWriteQueue::AddCompletionAccounting(AsyncWriteRequest &request) { + auto user_completion = request.completion; + request.completion = [this, user_completion](idx_t offset, idx_t size, optional_ptr error) { + CompleteSubmittedWrite(offset, size, error); + if (user_completion) { + user_completion(offset, size, error); + } + }; +} + +void ManagedAsyncWriteQueue::CompleteSubmittedWrite(idx_t offset, idx_t size, optional_ptr error) { + (void)offset; + bool refill = false; + { + lock_guard guard(lock); + D_ASSERT(submitted_requests > 0); + submitted_requests--; + 
D_ASSERT(submitted_bytes >= size); + submitted_bytes -= size; + refill = !error && !closed && !pending_writes.empty(); + } + if (refill) { + SchedulePendingWritesInternal(); + } +} + +void ManagedAsyncWriteQueue::ApplyBackpressure() { + if (!write_queue->IsAsync()) { + VerifyOpen(); + return; + } + RethrowTaskError(); + UpdateMemoryState(MemoryUpdateMode::FORCE); + SchedulePendingWrites(); + while (true) { + idx_t current_pending_bytes; + { + lock_guard guard(lock); + D_ASSERT(external_pending_bytes == 0); + current_pending_bytes = TotalPendingBytes(); + } + if (current_pending_bytes <= BackpressureBudget()) { + return; + } + SchedulePendingWrites(SchedulePolicy::FORCE); + write_queue->WorkOnPendingTask(); + RethrowTaskError(); + } +} + +void ManagedAsyncWriteQueue::WaitAll() { + bool already_closed; + { + lock_guard guard(lock); + already_closed = closed; + } + if (already_closed) { + RethrowTaskError(); + return; + } + if (!write_queue->IsAsync()) { + RethrowTaskError(); + return; + } + + try { + UpdateMemoryState(MemoryUpdateMode::FORCE); + while (true) { + if (!write_queue->HasError()) { + SchedulePendingWritesInternal(SchedulePolicy::FORCE); + } + write_queue->Flush(); + lock_guard guard(lock); + if (pending_writes.empty() && pending_bytes == 0 && external_pending_bytes == 0 && submitted_bytes == 0 && + submitted_requests == 0) { + break; + } + if (pending_writes.empty() && pending_bytes == 0 && submitted_bytes == 0 && submitted_requests == 0) { + throw InternalException("ManagedAsyncWriteQueue still tracks external pending writes"); + } + } + } catch (...) { + try { + write_queue->Flush(); + } catch (...) 
{ + } + throw; + } + + RethrowTaskError(); +} + +void ManagedAsyncWriteQueue::VerifyDrained() const { + if (!pending_writes.empty() || pending_bytes != 0 || external_pending_bytes != 0 || submitted_bytes != 0 || + submitted_requests != 0) { + throw InternalException("ManagedAsyncWriteQueue still owns registered writes"); + } +} + +void ManagedAsyncWriteQueue::CancelPendingWritesAfterFailure(const ErrorData &error) noexcept { + deque writes; + { + lock_guard guard(lock); + D_ASSERT(submitted_requests == 0); + D_ASSERT(submitted_bytes == 0); + if (submitted_requests != 0 || submitted_bytes != 0) { + return; + } + + writes = std::move(pending_writes); + pending_bytes = 0; + external_pending_bytes = 0; + closed = true; + } + + for (auto &pending : writes) { + auto request_size = pending.Size(); + auto &request = pending.request; + request.payload.reset(); + if (request.completion) { + try { + request.completion(request.offset, request_size, error); + } catch (...) { + } + } + } +} + +void ManagedAsyncWriteQueue::Close() { + bool already_closed; + { + lock_guard guard(lock); + already_closed = closed; + } + if (already_closed) { + RethrowTaskError(); + return; + } + + try { + WaitAll(); + write_queue->Close(); + ReleaseMemoryReservation(); + } catch (...) { + auto error = std::current_exception(); + try { + write_queue->Close(); + } catch (...) { + } + auto error_data = ErrorDataFromExceptionPtr(error); + CancelPendingWritesAfterFailure(error_data); + try { + ReleaseMemoryReservation(); + } catch (...) 
{ + } + std::rethrow_exception(error); + } + + lock_guard guard(lock); + VerifyDrained(); + closed = true; +} + +void ManagedAsyncWriteQueue::ReleaseMemoryReservation() { + if (!memory_state || memory_request_bytes == 0) { + return; + } + memory_state->SetZero(); + memory_request_bytes = 0; +} + +void ManagedAsyncWriteQueue::RethrowTaskError() { + write_queue->RethrowTaskError(); +} + +void ManagedAsyncWriteQueue::VerifyOpen() const { + if (closed) { + throw InternalException("Cannot use closed ManagedAsyncWriteQueue"); + } +} + +void ManagedAsyncWriteQueue::Write(data_ptr_t buffer, idx_t size, idx_t offset) { + if (size == 0) { + return; + } + target.Write(buffer, size, offset); +} + +ManagedAsyncWriteStreamQueue::PendingWrite::PendingWrite(unique_ptr payload_p, idx_t offset_p) + : payload(std::move(payload_p)), offset(offset_p) { +} + +idx_t ManagedAsyncWriteStreamQueue::PendingWrite::Size() const { + return payload->Size(); +} + +class ManagedAsyncWriteStreamQueue::CoalescedWritePayload : public AsyncWritePayload { +public: + CoalescedWritePayload(ClientContext &client_context_p, deque writes_p, idx_t size_p) + : client_context(client_context_p), writes(std::move(writes_p)), size(size_p) { + } + + data_ptr_t Ptr() override { + if (writes.size() == 1) { + return writes.front().payload->Ptr(); + } + if (!coalesced.get()) { + coalesced = BufferAllocator::Get(client_context).Allocate(size); + idx_t offset = 0; + for (auto &write : writes) { + auto current_size = write.Size(); + memcpy(coalesced.get() + offset, write.payload->Ptr(), current_size); + offset += current_size; + write.payload.reset(); + } + D_ASSERT(offset == size); + writes.clear(); + } + return coalesced.get(); + } + + idx_t Size() const override { + return size; + } + +private: + ClientContext &client_context; + deque writes; + AllocatedData coalesced; + idx_t size; +}; + +ManagedAsyncWriteStreamQueue::ManagedAsyncWriteStreamQueue(ClientContext &client_context_p, + ManagedAsyncWriteStreamTarget 
&target_p) + : client_context(client_context_p), target(target_p) { + auto local_file = target.IsLocalFile(); + coalesce_threshold = + local_file ? AsyncWriteConfig::LOCAL_COALESCE_THRESHOLD : AsyncWriteConfig::REMOTE_COALESCE_THRESHOLD; + first_task_schedule_threshold = local_file ? 1 : coalesce_threshold; + drain_task_byte_budget = MaxValue(AsyncWriteConfig::DRAIN_TASK_BYTE_BUDGET, coalesce_threshold); + limit_coalesced_write_size = local_file; + + auto &scheduler = TaskScheduler::GetScheduler(client_context); + auto async_threads = NumericCast(scheduler.NumberOfAsyncThreads()); + + // Positional writes let multiple async requests drain one logical write queue concurrently. + // Otherwise the stream queue keeps one sequential request active so target ordering remains correct. + if (target.SupportsPositionalWrites()) { + drain_mode = DrainMode::POSITIONAL; + max_active_drain_tasks = MaxValue(async_threads, 1); + } + + AsyncWriteTarget &async_target = *this; + write_queue = make_uniq(client_context, async_target); +} + +ManagedAsyncWriteStreamQueue::~ManagedAsyncWriteStreamQueue() { + lock_guard guard(lock); + auto drained = batch_depth == 0 && pending_writes.empty() && pending_bytes == 0 && submitted_bytes == 0 && + submitted_requests == 0; + D_ASSERT(closed || drained); + D_ASSERT(!closed || drained); +} + +bool ManagedAsyncWriteStreamQueue::IsAsync() const { + return write_queue->IsAsync(); +} + +bool ManagedAsyncWriteStreamQueue::HasError() { + return write_queue->HasError(); +} + +void ManagedAsyncWriteStreamQueue::RegisterWrite(unique_ptr payload, idx_t offset, + ScheduleMode schedule_mode) { + if (!payload || payload->Size() == 0) { + return; + } + RethrowTaskError(); + + auto write_size = payload->Size(); + if (!write_queue->IsAsync()) { + VerifyOpen(); + auto next_offset = ValidateRegistrationOffset(offset, write_size); + Write(payload->Ptr(), write_size, offset); + next_registration_offset = next_offset; + return; + } + + // Completion-driven refills may 
schedule pending_writes as soon as they are visible. + write_queue->AddExternalPendingBytes(write_size, false); + bool inserted = false; + bool update_memory = true; + try { + { + lock_guard guard(lock); + VerifyOpen(); + auto next_offset = ValidateRegistrationOffset(offset, write_size); + pending_writes.emplace_back(std::move(payload), offset); + pending_bytes += write_size; + next_registration_offset = next_offset; + update_memory = batch_depth == 0; + inserted = true; + } + } catch (...) { + if (!inserted) { + write_queue->DiscardExternalPendingBytes(write_size); + } + throw; + } + if (update_memory) { + write_queue->UpdateMemoryState(ManagedAsyncWriteQueue::MemoryUpdateMode::COARSE); + } + if (schedule_mode == ScheduleMode::ALLOW) { + SchedulePendingWrites(); + } +} + +void ManagedAsyncWriteStreamQueue::BeginBatch() { + if (!write_queue->IsAsync()) { + VerifyOpen(); + return; + } + lock_guard guard(lock); + VerifyOpen(); + batch_depth++; +} + +void ManagedAsyncWriteStreamQueue::LeaveBatch() noexcept { + if (!write_queue->IsAsync()) { + return; + } + lock_guard guard(lock); + if (batch_depth == 0) { + return; + } + batch_depth--; +} + +bool ManagedAsyncWriteStreamQueue::HasOpenBatch() { + if (!write_queue->IsAsync()) { + return false; + } + lock_guard guard(lock); + return batch_depth > 0; +} + +void ManagedAsyncWriteStreamQueue::SchedulePendingWrites(SchedulePolicy policy) { + if (!write_queue->IsAsync()) { + VerifyOpen(); + return; + } + SchedulePendingWritesInternal(policy); +} + +void ManagedAsyncWriteStreamQueue::SchedulePendingWritesInternal(SchedulePolicy policy) { + if (!write_queue->IsAsync()) { + return; + } + + while (true) { + AsyncWriteRequest request; + if (!TakePendingWriteRequest(request, policy)) { + return; + } + write_queue->RegisterAccountedWrite(std::move(request)); + } +} + +idx_t ManagedAsyncWriteStreamQueue::DrainTaskByteBudget() const { + return MaxValue(drain_task_byte_budget, coalesce_threshold); +} + +idx_t 
ManagedAsyncWriteStreamQueue::TotalPendingBytes() const { + return pending_bytes + submitted_bytes; +} + +idx_t ManagedAsyncWriteStreamQueue::SelectPendingWriteEnd(idx_t start, idx_t &selected_bytes) const { + D_ASSERT(start < pending_writes.size()); + auto byte_budget = DrainTaskByteBudget(); + selected_bytes = 0; + idx_t end = start; + while (end < pending_writes.size()) { + auto write_size = pending_writes[end].Size(); + if (selected_bytes > 0 && selected_bytes + write_size > byte_budget) { + break; + } + selected_bytes += write_size; + end++; + if (selected_bytes >= byte_budget) { + break; + } + } + D_ASSERT(end > start); + return end; +} + +idx_t ManagedAsyncWriteStreamQueue::SelectPhysicalWriteEnd(idx_t start, idx_t &selected_bytes) const { + D_ASSERT(start < pending_writes.size()); + selected_bytes = 0; + if (coalesce_threshold == 0) { + selected_bytes = pending_writes[start].Size(); + return start + 1; + } + + auto write_size = pending_writes[start].Size(); + if (write_size >= coalesce_threshold) { + selected_bytes = write_size; + return start + 1; + } + + auto write_offset = pending_writes[start].offset; + if (limit_coalesced_write_size) { + idx_t end = start; + while (end < pending_writes.size()) { + auto next_size = pending_writes[end].Size(); + if (next_size >= coalesce_threshold || selected_bytes + next_size > coalesce_threshold) { + break; + } + VerifyContiguousWrite(pending_writes[end], write_offset + selected_bytes); + selected_bytes += next_size; + end++; + } + D_ASSERT(end > start); + return end; + } + + idx_t selected_budget_bytes; + auto budget_end = SelectPendingWriteEnd(start, selected_budget_bytes); + idx_t small_run_size = 0; + idx_t small_run_end = start; + while (small_run_end < budget_end && pending_writes[small_run_end].Size() < coalesce_threshold) { + VerifyContiguousWrite(pending_writes[small_run_end], write_offset + small_run_size); + small_run_size += pending_writes[small_run_end].Size(); + small_run_end++; + } + + idx_t end = start; 
+ while (end < small_run_end) { + auto next_size = pending_writes[end].Size(); + selected_bytes += next_size; + end++; + auto remaining_after_next = small_run_size - selected_bytes; + if (selected_bytes >= coalesce_threshold && + (remaining_after_next == 0 || remaining_after_next >= coalesce_threshold)) { + break; + } + } + D_ASSERT(end > start); + return end; +} + +idx_t ManagedAsyncWriteStreamQueue::SubmittedByteWindow() const { + auto task_budget = DrainTaskByteBudget(); + if (task_budget == 0) { + return NumericLimits::Maximum(); + } + auto max_tasks = MaxValue(max_active_drain_tasks, 1); + if (max_tasks > NumericLimits::Maximum() / task_budget) { + return NumericLimits::Maximum(); + } + return max_tasks * task_budget; +} + +bool ManagedAsyncWriteStreamQueue::TakePendingWriteRequest(AsyncWriteRequest &request, SchedulePolicy policy) { + AsyncWriteCompletionCallback completion = [this](idx_t offset, idx_t size, optional_ptr error) { + CompleteSubmittedWrite(offset, size, error); + }; + + lock_guard guard(lock); + if (pending_writes.empty() || batch_depth > 0) { + return false; + } + if (drain_mode == DrainMode::SEQUENTIAL && submitted_requests > 0) { + return false; + } + if (policy == SchedulePolicy::THRESHOLD) { + if (pending_bytes < first_task_schedule_threshold && submitted_bytes == 0) { + return false; + } + if (submitted_bytes >= SubmittedByteWindow()) { + return false; + } + if (submitted_requests > 0 && pending_bytes < DrainTaskByteBudget()) { + return false; + } + } + + idx_t selected_bytes = 0; + auto end = SelectPhysicalWriteEnd(0, selected_bytes); + if (policy == SchedulePolicy::THRESHOLD && selected_bytes < first_task_schedule_threshold) { + return false; + } + + auto write_offset = pending_writes.front().offset; + deque writes; + for (idx_t write_idx = 0; write_idx < end; write_idx++) { + writes.push_back(std::move(pending_writes.front())); + pending_writes.pop_front(); + } + D_ASSERT(pending_bytes >= selected_bytes); + pending_bytes -= 
selected_bytes; + submitted_bytes += selected_bytes; + submitted_requests++; + auto payload = CreatePayload(std::move(writes), selected_bytes); + request = AsyncWriteRequest(std::move(payload), write_offset, std::move(completion)); + return true; +} + +unique_ptr ManagedAsyncWriteStreamQueue::CreatePayload(deque writes, idx_t size) { + D_ASSERT(!writes.empty()); + auto expected_offset = writes.front().offset; + for (auto &write : writes) { + VerifyContiguousWrite(write, expected_offset); + expected_offset = NextWriteOffset(expected_offset, write.Size()); + } + if (writes.size() == 1) { + return std::move(writes.front().payload); + } + return make_uniq(client_context, std::move(writes), size); +} + +void ManagedAsyncWriteStreamQueue::CompleteSubmittedWrite(idx_t offset, idx_t size, + optional_ptr error) { + (void)offset; + bool refill = false; + { + lock_guard guard(lock); + D_ASSERT(submitted_requests > 0); + submitted_requests--; + D_ASSERT(submitted_bytes >= size); + submitted_bytes -= size; + refill = !error && !closed && batch_depth == 0 && !pending_writes.empty(); + } + if (refill) { + SchedulePendingWritesInternal(); + } +} + +void ManagedAsyncWriteStreamQueue::ApplyBackpressure() { + if (!write_queue->IsAsync()) { + VerifyOpen(); + return; + } + RethrowTaskError(); + if (HasOpenBatch()) { + return; + } + write_queue->UpdateMemoryState(ManagedAsyncWriteQueue::MemoryUpdateMode::FORCE); + SchedulePendingWrites(); + while (true) { + idx_t current_pending_bytes; + { + lock_guard guard(lock); + if (batch_depth > 0) { + return; + } + current_pending_bytes = TotalPendingBytes(); + } + if (current_pending_bytes <= write_queue->BackpressureBudget()) { + return; + } + SchedulePendingWrites(SchedulePolicy::FORCE); + write_queue->ApplyBackpressure(); + RethrowTaskError(); + } +} + +void ManagedAsyncWriteStreamQueue::WaitAll(BatchDrainMode batch_drain_mode) { + bool already_closed; + { + lock_guard guard(lock); + already_closed = closed; + } + if (already_closed) { + 
RethrowTaskError(); + return; + } + if (!write_queue->IsAsync()) { + RethrowTaskError(); + return; + } + + const auto preserve_batch = batch_drain_mode == BatchDrainMode::PRESERVE_BATCH; + idx_t previous_batch_depth = 0; + bool batch_opened_for_drain = false; + + // Flush/Close must drain registered writes even if the caller currently has scheduling batched off. + // Flush restores that batch state, while Close intentionally leaves it closed. + auto open_batch_for_drain = [&]() { + if (batch_opened_for_drain) { + return; + } + lock_guard guard(lock); + previous_batch_depth = batch_depth; + batch_depth = 0; + batch_opened_for_drain = true; + }; + auto restore_batch = [&]() { + if (!preserve_batch || !batch_opened_for_drain) { + return; + } + lock_guard guard(lock); + batch_depth = previous_batch_depth; + }; + + try { + open_batch_for_drain(); + write_queue->UpdateMemoryState(ManagedAsyncWriteQueue::MemoryUpdateMode::FORCE); + while (true) { + if (!write_queue->HasError()) { + SchedulePendingWritesInternal(SchedulePolicy::FORCE); + } + write_queue->WaitAll(); + lock_guard guard(lock); + if (pending_writes.empty() && pending_bytes == 0 && submitted_bytes == 0 && submitted_requests == 0) { + break; + } + } + } catch (...) { + try { + open_batch_for_drain(); + write_queue->WaitAll(); + } catch (...) 
{ + } + restore_batch(); + throw; + } + + restore_batch(); + RethrowTaskError(); +} + +void ManagedAsyncWriteStreamQueue::VerifyDrained() const { + if (batch_depth != 0 || !pending_writes.empty() || pending_bytes != 0 || submitted_bytes != 0 || + submitted_requests != 0) { + throw InternalException("ManagedAsyncWriteStreamQueue still owns registered writes"); + } +} + +void ManagedAsyncWriteStreamQueue::CancelPendingWritesAfterFailure() noexcept { + idx_t discarded_bytes; + { + lock_guard guard(lock); + D_ASSERT(submitted_requests == 0); + D_ASSERT(submitted_bytes == 0); + if (submitted_requests != 0 || submitted_bytes != 0) { + return; + } + + discarded_bytes = pending_bytes; + pending_writes.clear(); + pending_bytes = 0; + batch_depth = 0; + closed = true; + } + write_queue->DiscardExternalPendingBytes(discarded_bytes); +} + +void ManagedAsyncWriteStreamQueue::Close() { + bool already_closed; + { + lock_guard guard(lock); + already_closed = closed; + } + if (already_closed) { + RethrowTaskError(); + return; + } + + try { + WaitAll(BatchDrainMode::FORCE_CLOSE_BATCH); + write_queue->Close(); + } catch (...) { + auto error = std::current_exception(); + try { + write_queue->Close(); + } catch (...) { + } + CancelPendingWritesAfterFailure(); + try { + ReleaseMemoryReservation(); + } catch (...) 
{ + } + std::rethrow_exception(error); + } + + lock_guard guard(lock); + VerifyDrained(); + closed = true; +} + +void ManagedAsyncWriteStreamQueue::ResetNextOffset(idx_t offset) { + RethrowTaskError(); + lock_guard guard(lock); + VerifyOpen(); + VerifyDrained(); + next_registration_offset = offset; +} + +void ManagedAsyncWriteStreamQueue::ReleaseMemoryReservation() { + write_queue->ReleaseMemoryReservation(); +} + +void ManagedAsyncWriteStreamQueue::RethrowTaskError() { + write_queue->RethrowTaskError(); +} + +idx_t ManagedAsyncWriteStreamQueue::ValidateRegistrationOffset(idx_t offset, idx_t write_size) const { + if (offset != next_registration_offset) { + throw InternalException( + "ManagedAsyncWriteStreamQueue only supports contiguous writes: expected offset %llu, got %llu", + next_registration_offset, offset); + } + return NextWriteOffset(offset, write_size); +} + +void ManagedAsyncWriteStreamQueue::VerifyOpen() const { + if (closed) { + throw InternalException("Cannot use closed ManagedAsyncWriteStreamQueue"); + } +} + +void ManagedAsyncWriteStreamQueue::VerifyContiguousWrite(const PendingWrite &write, idx_t expected_offset) const { + if (write.offset != expected_offset) { + throw InternalException( + "ManagedAsyncWriteStreamQueue only supports contiguous writes: expected offset %llu, got %llu", + expected_offset, write.offset); + } +} + +idx_t ManagedAsyncWriteStreamQueue::NextWriteOffset(idx_t offset, idx_t write_size) const { + if (write_size > NumericLimits::Maximum() - offset) { + throw InternalException("ManagedAsyncWriteStreamQueue write offset overflow"); + } + return offset + write_size; +} + +void ManagedAsyncWriteStreamQueue::Write(data_ptr_t buffer, idx_t size, idx_t offset) { + if (size == 0) { + return; + } + if (drain_mode == DrainMode::POSITIONAL) { + target.Write(buffer, size, offset); + } else { + target.Write(buffer, size); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/vector/shredded_vector.cpp 
b/src/duckdb/src/common/vector/shredded_vector.cpp index a42fb0c13..faee370f0 100644 --- a/src/duckdb/src/common/vector/shredded_vector.cpp +++ b/src/duckdb/src/common/vector/shredded_vector.cpp @@ -35,6 +35,10 @@ string ShreddedVectorBuffer::ToString(const LogicalType &type, idx_t count) cons return "Shredded: " + shredded.ToString() + ", Unshredded: " + unshredded.ToString(); } +void ShreddedVectorBuffer::SetVectorType(VectorType new_vector_type) { + throw InternalException("ShreddedVectorBuffer::SetVectorType is not implemented and shouldn't be reached"); +} + Value ShreddedVectorBuffer::GetValue(const LogicalType &type, idx_t index) const { // FIXME: this is extremely inefficient auto &shredded = StructVector::GetEntries(*shredded_data)[1]; diff --git a/src/duckdb/src/execution/join_hashtable.cpp b/src/duckdb/src/execution/join_hashtable.cpp index 0ee6540f9..a6abe48f8 100644 --- a/src/duckdb/src/execution/join_hashtable.cpp +++ b/src/duckdb/src/execution/join_hashtable.cpp @@ -761,6 +761,24 @@ unique_ptr JoinHashTable::InitializePrefixRangeBu return prefix_range_filter->InitializeBuildState(context); } +void JoinHashTable::BuildPrefixRangeFilter() { + if (!ShouldBuildPrefixRangeFilter()) { + return; + } + + auto prefix_range_state = InitializePrefixRangeBuildState(); + TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::KEEP_EVERYTHING_PINNED, 0, + data_collection->ChunkCount(), false); + do { + const auto count = iterator.GetCurrentChunkCount(); + if (count == 0) { + continue; + } + InsertPrefixRangeChunk(iterator.GetChunkState(), count, *prefix_range_state); + } while (iterator.Next()); + MergePrefixRangeBuildState(*prefix_range_state); +} + void JoinHashTable::InsertPrefixRangeChunk(TupleDataChunkState &chunk_state, idx_t count, PrefixRangeFilter::BuildState &state) { D_ASSERT(prefix_range_filter); diff --git a/src/duckdb/src/execution/operator/helper/physical_streaming_sample.cpp 
b/src/duckdb/src/execution/operator/helper/physical_streaming_sample.cpp index 8c044f9e2..1adfacd2b 100644 --- a/src/duckdb/src/execution/operator/helper/physical_streaming_sample.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_streaming_sample.cpp @@ -1,15 +1,36 @@ #include "duckdb/execution/operator/helper/physical_streaming_sample.hpp" +#include "duckdb/common/helper.hpp" #include "duckdb/common/random_engine.hpp" #include "duckdb/common/to_string.hpp" #include "duckdb/common/enum_util.hpp" +#include "duckdb/common/numeric_utils.hpp" namespace duckdb { PhysicalStreamingSample::PhysicalStreamingSample(PhysicalPlan &physical_plan, vector types, unique_ptr options, idx_t estimated_cardinality) : PhysicalOperator(physical_plan, PhysicalOperatorType::STREAMING_SAMPLE, std::move(types), estimated_cardinality), - sample_options(std::move(options)) { - percentage = sample_options->sample_size.GetValue() / 100; + sample_options(std::move(options)), percentage(0.0), system_sample_phase(0.0), rows(0) { + if (sample_options->is_percentage) { + percentage = sample_options->sample_size.GetValue() / 100; + } else { + // Convert target row count to a sampling rate. + // Prefer the pre-calculated sample_rate from the planner if available (ensures + // consistency with pushdown path), otherwise derive from estimated_cardinality. + // Fallback to 1.0 (take all rows) if no estimate is available. + rows = NumericCast(sample_options->sample_size.GetValue()); + if (sample_options->sample_rate > 0) { + percentage = sample_options->sample_rate; + } else if (estimated_cardinality > 0) { + percentage = static_cast(rows) / static_cast(estimated_cardinality); + } else { + percentage = 1.0; + } + percentage = MinValue(1.0, MaxValue(0.0, percentage)); + RandomEngine random(sample_options->seed.IsValid() ? 
static_cast(sample_options->seed.GetIndex()) + : -1); + system_sample_phase = random.NextRandom(); + } } //===--------------------------------------------------------------------===// @@ -17,13 +38,16 @@ PhysicalStreamingSample::PhysicalStreamingSample(PhysicalPlan &physical_plan, ve //===--------------------------------------------------------------------===// class StreamingSampleOperatorState : public OperatorState { public: - explicit StreamingSampleOperatorState(int64_t seed) : random(seed) { + explicit StreamingSampleOperatorState(int64_t seed) : random(seed), system_rows_seen(0) { } RandomEngine random; + + // Counters for row-count SYSTEM sampling. + idx_t system_rows_seen; }; -void PhysicalStreamingSample::SystemSample(DataChunk &input, DataChunk &result, OperatorState &state_p) const { +void PhysicalStreamingSample::SystemSamplePercent(DataChunk &input, DataChunk &result, OperatorState &state_p) const { // system sampling: we throw one dice per chunk auto &state = state_p.Cast(); double rand = state.random.NextRandom(); @@ -33,6 +57,37 @@ void PhysicalStreamingSample::SystemSample(DataChunk &input, DataChunk &result, } } +void PhysicalStreamingSample::SystemSampleRows(DataChunk &input, DataChunk &result, OperatorState &state_p) const { + double rate = percentage; + if (rate <= 0) { + return; + } + if (rate >= 1) { + result.Reference(input); + return; + } + + // Emit a row whenever rows_seen * rate crosses the next integer threshold. + // Using a fresh multiply per row (rather than an accumulated sum) avoids + // floating-point drift where repeated additions of rate never quite reach + // a whole number (e.g. 10000 * 0.0001 < 1.0). The phase makes seeds affect + // where the systematic sample starts. 
+ auto &state = state_p.Cast(); + idx_t result_count = 0; + SelectionVector sel(input.size()); + for (idx_t i = 0; i < input.size(); i++) { + auto before = std::floor(LossyNumericCast(state.system_rows_seen) * rate + system_sample_phase); + state.system_rows_seen++; + auto after = std::floor(LossyNumericCast(state.system_rows_seen) * rate + system_sample_phase); + if (after > before) { + sel.set_index(result_count++, i); + } + } + if (result_count > 0) { + result.Slice(input, sel, result_count); + } +} + void PhysicalStreamingSample::BernoulliSample(DataChunk &input, DataChunk &result, OperatorState &state_p) const { // bernoulli sampling: we throw one dice per tuple // then slice the result chunk @@ -51,6 +106,11 @@ void PhysicalStreamingSample::BernoulliSample(DataChunk &input, DataChunk &resul } bool PhysicalStreamingSample::ParallelOperator() const { + if (!sample_options->is_percentage) { + // Row-count SYSTEM sampling must see the full input stream through a + // single OperatorState so the rows_seen counter advances globally. 
+ return false; + } return !sample_options->repeatable; } @@ -71,7 +131,11 @@ OperatorResultType PhysicalStreamingSample::Execute(ExecutionContext &context, D BernoulliSample(input, chunk, state); break; case SampleMethod::SYSTEM_SAMPLE: - SystemSample(input, chunk, state); + if (sample_options->is_percentage) { + SystemSamplePercent(input, chunk, state); + } else { + SystemSampleRows(input, chunk, state); + } break; default: throw InternalException("Unsupported sample method for streaming sample"); @@ -81,7 +145,11 @@ OperatorResultType PhysicalStreamingSample::Execute(ExecutionContext &context, D InsertionOrderPreservingMap PhysicalStreamingSample::ParamsToString() const { InsertionOrderPreservingMap result; - result["Sample Method"] = EnumUtil::ToString(sample_options->method) + ": " + to_string(100 * percentage) + "%"; + if (sample_options->is_percentage) { + result["Sample Method"] = EnumUtil::ToString(sample_options->method) + ": " + to_string(100 * percentage) + "%"; + } else { + result["Sample Method"] = EnumUtil::ToString(sample_options->method) + ": " + to_string(rows) + " rows"; + } SetEstimatedCardinality(result, estimated_cardinality); return result; } diff --git a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp index 7b608a3fa..c44696f51 100644 --- a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp +++ b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp @@ -1070,7 +1070,7 @@ bool JoinFilterPushdownInfo::CanUseInFilter(const ClientContext &context, option return ht && ht->Count() > 1 && ht->Count() <= dynamic_or_filter_threshold && cmp == ExpressionType::COMPARE_EQUAL; } -void JoinFilterPushdownInfo::PushInFilter(const JoinFilterPushdownFilter &info, JoinHashTable &ht, +bool JoinFilterPushdownInfo::PushInFilter(const JoinFilterPushdownFilter &info, JoinHashTable &ht, const PhysicalOperator &op, idx_t filter_idx, ProjectionIndex filter_col_idx) const { 
// generate a "OR" filter (i.e. x=1 OR x=535 OR x=997) @@ -1080,7 +1080,7 @@ void JoinFilterPushdownInfo::PushInFilter(const JoinFilterPushdownFilter &info, Vector build_vector(ht.layout_ptr->GetTypes()[build_idx], ht.Count()); auto key_count = ht.ScanKeyColumn(tuples_addresses, build_vector, build_idx); if (key_count == 0) { - return; + return false; } // generate the OR-clause - note that we only need to consider unique values here (so we use a seT) @@ -1090,7 +1090,7 @@ void JoinFilterPushdownInfo::PushInFilter(const JoinFilterPushdownFilter &info, auto value = build_vector.GetValue(k); if (info.columns[filter_idx].storage_type.IsValid() && !value.DefaultTryCastAs(info.columns[filter_idx].storage_type)) { - return; // it's all or nothing sadly + return false; // it's all or nothing sadly } unique_ht_values.insert(value); } @@ -1100,7 +1100,7 @@ void JoinFilterPushdownInfo::PushInFilter(const JoinFilterPushdownFilter &info, // not dense and that the range does not contain NULL // i.e. 
if we have the values [0, 1, 2, 3, 4] - the min/max is fully equivalent to the OR filter if (FilterCombiner::ContainsNull(in_list) || FilterCombiner::IsDenseRange(in_list)) { - return; + return false; } // we push the OR filter as an OptionalFilter so that we can use it for zonemap pruning only @@ -1110,6 +1110,7 @@ void JoinFilterPushdownInfo::PushInFilter(const JoinFilterPushdownFilter &info, auto filter = make_uniq( CreateOptionalFilterExpression(std::move(in_expr), info.columns[filter_idx].storage_type)); info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(filter)); + return true; } bool JoinFilterPushdownInfo::CanUseBloomFilter(const ClientContext &context, const PhysicalComparisonJoin &op, @@ -1160,41 +1161,19 @@ bool JoinFilterPushdownInfo::CanUseBloomFilter(const ClientContext &context, con return true; } -bool JoinFilterPushdownInfo::CanUsePrefixRangeFilter(ClientContext &context, optional_ptr ht, - const PhysicalComparisonJoin &op, const ExpressionType &cmp, - const Value &min, const Value &max) const { +bool JoinFilterPushdownInfo::CanUsePrefixRangeFilter(const ClientContext &context, const PhysicalComparisonJoin &op, + optional_ptr ht, const ExpressionType &cmp) const { if (!CanUseBloomFilter(context, op, cmp, ht)) { return false; } - if (ht->Count() == 0) { + if (cmp != ExpressionType::COMPARE_EQUAL) { return false; } - if (ht->NullValuesAreEqual(0)) { // TODO: Support "A is B" type joins return false; } - static constexpr idx_t BUILD_SIZE_THRESHOLD = 524288; - bool ht_is_small = ht->Count() <= BUILD_SIZE_THRESHOLD; - bool span_is_small = false; - - uhugeint_t span; - if (PrefixRangeFilter::TryComputeSpan(min, max, span)) { - if (span == 0) { - // Filter will not be more expressive than min/max, bail - return false; - } - static const auto SPAN_THRESHOLD = Uhugeint::Convert(1048576); - span_is_small = span <= SPAN_THRESHOLD; - } else { - return false; - } - - if (!ht_is_small && !span_is_small) { - return false; - } - const auto &key_type = 
ht->conditions[0].GetLHS().GetReturnType(); return PrefixRangeFilter::SupportedType(key_type); } @@ -1243,37 +1222,16 @@ void JoinFilterPushdownInfo::PushBloomFilter(ClientContext &context, const Physi SelectivityOptionalFilterType::BF)); } -void JoinFilterPushdownInfo::PushPerfectHashJoinFilter(ClientContext &context, const PhysicalOperator &op, - PerfectHashJoinExecutor &perfect_join_executor, - const JoinFilterPushdownFilter &info, idx_t filter_idx, - ProjectionIndex filter_col_idx) const { - const auto key_name = op.Cast().conditions[0].GetRHS().ToString(); - const auto &key_type = perfect_join_executor.GetKeyType(); - auto filter_input_type = GetRuntimeFilterInputType(info.columns[filter_idx], key_type); - float selectivity_threshold; - idx_t n_vectors_to_check; - GetThresholdAndVectorsToCheck(SelectivityOptionalFilterType::PHJ, selectivity_threshold, n_vectors_to_check); - vector> children; - children.push_back(CreateRuntimeFilterInputExpression(context, info.columns[filter_idx], key_type)); - auto filter_expr = make_uniq( - BoundScalarFunction(PerfectHashJoinScalarFun::GetFunction(filter_input_type)), std::move(children), - make_uniq(perfect_join_executor, key_name, selectivity_threshold, - n_vectors_to_check)); - info.dynamic_filters->PushFilter(op, filter_col_idx, - CreateSelectivityOptionalExpressionFilter(std::move(filter_expr), - info.columns[filter_idx].storage_type, - SelectivityOptionalFilterType::PHJ)); -} - -void JoinFilterPushdownInfo::RegisterPrefixRangeFilter(const JoinFilterPushdownFilter &info, ClientContext &context, - JoinHashTable &ht, const PhysicalOperator &op, idx_t filter_idx, - ProjectionIndex filter_col_idx, const Value &min_val, - const Value &max_val) const { +bool JoinFilterPushdownInfo::TryRegisterPrefixRangeFilter(const JoinFilterPushdownFilter &info, ClientContext &context, + JoinHashTable &ht, const PhysicalOperator &op, + idx_t filter_idx, ProjectionIndex filter_col_idx, + const Value &min_val, const Value &max_val, + idx_t 
max_bits) const { const auto key_type = ht.conditions[0].GetLHS().GetReturnType(); auto filter_input_type = GetRuntimeFilterInputType(info.columns[filter_idx], key_type); if (!ht.GetPrefixRangeFilter()) { auto prefix_filter = PrefixRangeFilter::CreatePrefixRangeFilter(key_type); - prefix_filter->Initialize(context, ht.Count(), min_val, max_val); + prefix_filter->Initialize(context, ht.Count(), min_val, max_val, max_bits); ht.SetPrefixRangeFilter(std::move(prefix_filter)); ht.SetBuildPrefixRangeFilter(); } @@ -1292,6 +1250,7 @@ void JoinFilterPushdownInfo::RegisterPrefixRangeFilter(const JoinFilterPushdownF CreateSelectivityOptionalExpressionFilter(std::move(filter_expr), info.columns[filter_idx].storage_type, SelectivityOptionalFilterType::PRF)); + return true; } unique_ptr JoinFilterPushdownInfo::FinalizeMinMax(JoinFilterGlobalState &gstate) const { @@ -1319,10 +1278,17 @@ static unique_ptr CreateSelectivityOptionalExpressionFilter(un static void CreateDynamicMinMaxFilter(const PhysicalComparisonJoin &op, const JoinFilterPushdownFilter &info, const ProjectionIndex &filter_col_idx, unique_ptr filter_expr, - const LogicalType &column_type) { - info.dynamic_filters->PushFilter(op, filter_col_idx, - CreateSelectivityOptionalExpressionFilter(std::move(filter_expr), column_type, - SelectivityOptionalFilterType::MIN_MAX)); + const LogicalType &column_type, bool selectivity_optional) { + if (selectivity_optional) { + info.dynamic_filters->PushFilter( + op, filter_col_idx, + CreateSelectivityOptionalExpressionFilter(std::move(filter_expr), column_type, + SelectivityOptionalFilterType::MIN_MAX)); + } else { + info.dynamic_filters->PushFilter( + op, filter_col_idx, + make_uniq(CreateOptionalFilterExpression(std::move(filter_expr), column_type))); + } } static unique_ptr CreateComparisonExpressionFilter(ExpressionType comparison_type, const Value &constant, @@ -1336,10 +1302,46 @@ static unique_ptr CreateComparisonExpressionFilter(ExpressionType co 
make_uniq(std::move(constant_value))); } -unique_ptr -JoinFilterPushdownInfo::FinalizeFilters(ClientContext &context, const PhysicalComparisonJoin &op, - unique_ptr final_min_max, optional_ptr ht, - optional_ptr perfect_join_executor) const { +static void CreateDynamicMinMaxFilters(const PhysicalComparisonJoin &op, const JoinFilterPushdownFilter &info, + ProjectionIndex filter_col_idx, ExpressionType cmp, const Value &min_val, + const Value &max_val, const LogicalType &condition_type, + bool selectivity_optional) { + switch (cmp) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_GREATERTHAN: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: { + CreateDynamicMinMaxFilter( + op, info, filter_col_idx, + CreateComparisonExpressionFilter(ExpressionType::COMPARE_GREATERTHANOREQUALTO, min_val, condition_type), + condition_type, selectivity_optional); + break; + } + default: + break; + } + switch (cmp) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_LESSTHANOREQUALTO: { + CreateDynamicMinMaxFilter( + op, info, filter_col_idx, + CreateComparisonExpressionFilter(ExpressionType::COMPARE_LESSTHANOREQUALTO, max_val, condition_type), + condition_type, selectivity_optional); + break; + } + default: + break; + } +} + +static idx_t BloomFilterBitBudget(idx_t ht_count) { + return BloomFilter::GetNumberOfSectors(ht_count) * 64; +} + +unique_ptr JoinFilterPushdownInfo::FinalizeFilters(ClientContext &context, const PhysicalComparisonJoin &op, + unique_ptr final_min_max, + optional_ptr ht, bool allow_bloom_filters, + bool allow_prefix_range_filters) const { if (probe_info.empty()) { return final_min_max; // There are no table sources in which we can push down filters } @@ -1379,17 +1381,10 @@ JoinFilterPushdownInfo::FinalizeFilters(ClientContext &context, const PhysicalCo auto condition_type = min_val.type(); auto runtime_filter_input_type = GetRuntimeFilterInputType(pushdown_column, condition_type); 
bool can_emit_runtime_filters = pushdown_column.mode == JoinFilterPushdownMode::RECONSTRUCT_EXPRESSION; - if (can_emit_runtime_filters && perfect_join_executor) { - can_emit_runtime_filters = runtime_filter_input_type == perfect_join_executor->GetKeyType(); - } else if (can_emit_runtime_filters && ht) { + if (can_emit_runtime_filters && ht) { can_emit_runtime_filters = runtime_filter_input_type == ht->conditions[0].GetLHS().GetReturnType(); } - // if the HT is small we can generate a complete "OR" filter - // but only if the join condition is equality. - if (ht && CanUseInFilter(context, ht, cmp)) { - PushInFilter(info, *ht, op, filter_idx, filter_col_idx); - } if (Value::NotDistinctFrom(min_val, max_val)) { // min = max - single value // generate a "one-sided" comparison filter for the LHS @@ -1398,45 +1393,48 @@ JoinFilterPushdownInfo::FinalizeFilters(ClientContext &context, const PhysicalCo op, filter_col_idx, make_uniq(CreateComparisonExpressionFilter(cmp, min_val, condition_type))); } else { - // min != max - generate a range filter or bloom filter + optional range filter - // for non-equalities, the range must be half-open - // e.g., for lhs < rhs we can only use lhs <= max - switch (cmp) { - case ExpressionType::COMPARE_EQUAL: - case ExpressionType::COMPARE_GREATERTHAN: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: { - CreateDynamicMinMaxFilter( - op, info, filter_col_idx, - CreateComparisonExpressionFilter(ExpressionType::COMPARE_GREATERTHANOREQUALTO, min_val, - condition_type), - condition_type); - break; + if (cmp != ExpressionType::COMPARE_EQUAL) { + // min != max - generate range filters for non-equality comparisons. + // For non-equalities, the range must be half-open. 
+ CreateDynamicMinMaxFilters(op, info, filter_col_idx, cmp, min_val, max_val, condition_type, true); + continue; } - default: - break; + + uhugeint_t span; + const auto can_compute_span = + PrefixRangeFilter::TryComputeSpan(min_val_before_cast, max_val_before_cast, span); + const auto can_emit_prf = allow_prefix_range_filters && can_emit_runtime_filters && + CanUsePrefixRangeFilter(context, op, ht, cmp) && can_compute_span; + + bool pushed_in_filter = false; + if (CanUseInFilter(context, ht, cmp)) { + pushed_in_filter = PushInFilter(info, *ht, op, filter_idx, filter_col_idx); } - switch (cmp) { - case ExpressionType::COMPARE_EQUAL: - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_LESSTHANOREQUALTO: { - CreateDynamicMinMaxFilter(op, info, filter_col_idx, - CreateComparisonExpressionFilter( - ExpressionType::COMPARE_LESSTHANOREQUALTO, max_val, condition_type), - condition_type); - break; + + static constexpr idx_t SMALL_EXACT_PRF_BITS = 1ULL << 26; + if (can_emit_prf && span < SMALL_EXACT_PRF_BITS && + TryRegisterPrefixRangeFilter(info, context, *ht, op, filter_idx, filter_col_idx, + min_val_before_cast, max_val_before_cast, SMALL_EXACT_PRF_BITS)) { + continue; } - default: - break; + + if (can_emit_prf) { + auto build_count = ht->Count(); + if (build_count == 0) { + build_count = ht->GetSinkCollection().Count(); + } + const auto bloom_filter_bits = BloomFilterBitBudget(build_count); + if (span <= bloom_filter_bits && + TryRegisterPrefixRangeFilter(info, context, *ht, op, filter_idx, filter_col_idx, + min_val_before_cast, max_val_before_cast, bloom_filter_bits)) { + continue; + } } - if (can_emit_runtime_filters && perfect_join_executor) { - PushPerfectHashJoinFilter(context, op, *perfect_join_executor, info, filter_idx, filter_col_idx); - } else if (can_emit_runtime_filters && - CanUsePrefixRangeFilter(context, ht, op, cmp, min_val_before_cast, max_val_before_cast)) { - // It's important that these get the min/max val before casting - 
RegisterPrefixRangeFilter(info, context, *ht, op, filter_idx, filter_col_idx, min_val_before_cast, - max_val_before_cast); - } else if (can_emit_runtime_filters && ht && CanUseBloomFilter(context, op, cmp, ht)) { + if (!pushed_in_filter) { + CreateDynamicMinMaxFilters(op, info, filter_col_idx, cmp, min_val, max_val, condition_type, false); + } + if (allow_bloom_filters && can_emit_runtime_filters && ht && CanUseBloomFilter(context, op, cmp, ht)) { PushBloomFilter(context, op, *ht, info, filter_idx, filter_col_idx); } } @@ -1445,12 +1443,11 @@ JoinFilterPushdownInfo::FinalizeFilters(ClientContext &context, const PhysicalCo return final_min_max; } -unique_ptr -JoinFilterPushdownInfo::Finalize(ClientContext &context, JoinFilterGlobalState &gstate, - const PhysicalComparisonJoin &op, optional_ptr ht, - optional_ptr perfect_hash_join_executor) const { +unique_ptr JoinFilterPushdownInfo::Finalize(ClientContext &context, JoinFilterGlobalState &gstate, + const PhysicalComparisonJoin &op, + optional_ptr ht) const { auto final_min_max = FinalizeMinMax(gstate); - return FinalizeFilters(context, op, std::move(final_min_max), ht, perfect_hash_join_executor); + return FinalizeFilters(context, op, std::move(final_min_max), ht, true); } SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, @@ -1514,8 +1511,9 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, Cl sink.owned_local_hash_tables.clear(); if (filter_pushdown && !sink.skip_filter_pushdown && ht.GetSinkCollection().Count() > 0) { auto filter_min_max = filter_pushdown->FinalizeMinMax(*sink.global_filter_state); - filter_pushdown->FinalizeFilters(context, *this, std::move(filter_min_max), &ht, nullptr); + filter_pushdown->FinalizeFilters(context, *this, std::move(filter_min_max), &ht, true, false); } + ht.PrepareBloomFilterForFinalize(); D_ASSERT(sink.temporary_memory_state->GetReservation() >= sink.probe_side_requirement); 
sink.hash_table->PrepareExternalFinalize(sink.temporary_memory_state->GetReservation() - sink.probe_side_requirement); @@ -1549,8 +1547,6 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, Cl // check for possible perfect hash table auto use_perfect_hash = sink.perfect_join_executor->CanDoPerfectHashJoin(*this, min, max); if (use_perfect_hash) { - D_ASSERT(ht.equality_types.size() == 1); - auto key_type = ht.equality_types[0]; use_perfect_hash = sink.perfect_join_executor->BuildPerfectHashTable(); } @@ -1559,14 +1555,15 @@ SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, Cl } if (filter_min_max) { - filter_pushdown->FinalizeFilters(context, *this, std::move(filter_min_max), &ht, sink.perfect_join_executor); - if (!use_perfect_hash) { - ht.PrepareBloomFilterForFinalize(); + filter_pushdown->FinalizeFilters(context, *this, std::move(filter_min_max), &ht, !use_perfect_hash); + if (use_perfect_hash) { + ht.BuildPrefixRangeFilter(); } } // In case of a large build side or duplicates, use regular hash join if (!use_perfect_hash) { + ht.PrepareBloomFilterForFinalize(); sink.ScheduleFinalize(pipeline, event); } sink.finalized = true; diff --git a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp index cd1524cf3..04dbcdac7 100644 --- a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp +++ b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp @@ -373,7 +373,8 @@ InsertionOrderPreservingMap PhysicalTableScan::ParamsToString() const { } if (extra_info.sample_options) { - result["Sample Method"] = "System: " + extra_info.sample_options->sample_size.ToString() + "%"; + result["Sample Method"] = "System: " + extra_info.sample_options->sample_size.ToString() + + (extra_info.sample_options->is_percentage ? 
"%" : " rows"); } if (!extra_info.file_filters.empty()) { result["File Filters"] = extra_info.file_filters; diff --git a/src/duckdb/src/execution/physical_plan/plan_sample.cpp b/src/duckdb/src/execution/physical_plan/plan_sample.cpp index 65aa2ea9b..76dc21951 100644 --- a/src/duckdb/src/execution/physical_plan/plan_sample.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_sample.cpp @@ -1,3 +1,5 @@ +#include "duckdb/execution/operator/helper/physical_limit.hpp" +#include "duckdb/execution/operator/helper/physical_streaming_limit.hpp" #include "duckdb/execution/operator/helper/physical_reservoir_sample.hpp" #include "duckdb/execution/operator/helper/physical_streaming_sample.hpp" #include "duckdb/execution/physical_plan_generator.hpp" @@ -10,6 +12,17 @@ namespace duckdb { PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalSample &op) { D_ASSERT(op.children.size() == 1); + // Only reached when a LogicalSample survives optimization. Sampling pushdown removes + // LogicalSample over a plain table GET, so that path uses LogicalLimit + scan instead + // and never hits the row-count LIMIT wrap below for the same sample. + + // For SYSTEM_SAMPLE with row count, we need to get the child's estimated cardinality + // BEFORE calling CreatePlan (which consumes the child). 
+ idx_t child_cardinality = 0; + if (op.sample_options->method == SampleMethod::SYSTEM_SAMPLE && !op.sample_options->is_percentage) { + auto &first_child = *op.children[0]; + child_cardinality = first_child.EstimateCardinality(context); + } auto &plan = CreatePlan(*op.children[0]); if (!op.sample_options->seed.IsValid()) { @@ -23,7 +36,6 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalSample &op) { sample.children.push_back(plan); return sample; } - case SampleMethod::SYSTEM_SAMPLE: case SampleMethod::BERNOULLI_SAMPLE: { if (!op.sample_options->is_percentage) { throw ParserException("Sample method %s cannot be used with a discrete sample count, either switch to " @@ -34,6 +46,47 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalSample &op) { sample.children.push_back(plan); return sample; } + case SampleMethod::SYSTEM_SAMPLE: { + const bool is_percentage = op.sample_options->is_percentage; + int64_t rows = 0; + if (!is_percentage) { + rows = op.sample_options->sample_size.GetValue(); + // To ensure consistency between optimized and unoptimized paths, + // we calculate the rate based on the estimated cardinality of the child. + if (child_cardinality > 0) { + op.sample_options->sample_rate = static_cast(rows) / static_cast(child_cardinality); + } else { + op.sample_options->sample_rate = 1.0; + } + } + + auto &sample = Make(op.types, std::move(op.sample_options), op.estimated_cardinality); + sample.children.push_back(plan); + + if (!is_percentage) { + // Mirror sampling_pushdown.cpp: cap row count when LogicalSample is still present. + // As the sampling operator uses a distributed chunk-based approach it may + // oversample, so we wrap it with a LIMIT to ensure we stop as soon as the target is reached + // This also happens when no estimated cardinality is available. + auto limit_val = BoundLimitNode::ConstantValue(rows); + auto offset_val = BoundLimitNode(); + // PhysicalLimit requires batch-index support from the pipeline source. 
+ // Sources like CTE scans don't provide it, so fall back to a streaming + // limit which has no such requirement. + const bool preserve_order = PreserveInsertionOrder(sample); + if (preserve_order && UseBatchIndex(sample)) { + auto &limit = Make(op.types, std::move(limit_val), std::move(offset_val), + op.estimated_cardinality); + limit.children.push_back(sample); + return limit; + } + auto &limit = Make(op.types, std::move(limit_val), std::move(offset_val), + op.estimated_cardinality, !preserve_order); + limit.children.push_back(sample); + return limit; + } + return sample; + } default: throw InternalException("Unimplemented sample method"); } diff --git a/src/duckdb/src/function/aggregate/distributive/minmax.cpp b/src/duckdb/src/function/aggregate/distributive/minmax.cpp index cdccf99cd..0d25a6a05 100644 --- a/src/duckdb/src/function/aggregate/distributive/minmax.cpp +++ b/src/duckdb/src/function/aggregate/distributive/minmax.cpp @@ -531,15 +531,6 @@ AggregateFunction GetMinMaxNFunction() { MinMaxNBind, nullptr); } -AggregateStateLayout GetExportStateType(AggregateLayoutInput &input) { - auto &function = input.function; - child_list_t struct_children_types; - struct_children_types.emplace_back("value", function.GetReturnType()); - struct_children_types.emplace_back("is_set", LogicalType::BOOLEAN); - return AggregateStateLayout(LogicalType::STRUCT(std::move(struct_children_types)), - AlignValue(function.GetStateSizeCallback()(function))); -} - } // namespace //--------------------------------------------------- // Function Registration @@ -547,14 +538,14 @@ AggregateStateLayout GetExportStateType(AggregateLayoutInput &input) { AggregateFunctionSet MinFun::GetFunctions() { AggregateFunctionSet min("min"); min.AddFunction(MinFunction::GetFunction()); - min.AddFunction(GetMinMaxNFunction().SetStructStateExport(GetExportStateType)); + min.AddFunction(GetMinMaxNFunction()); return min; } AggregateFunctionSet MaxFun::GetFunctions() { AggregateFunctionSet max("max"); 
max.AddFunction(MaxFunction::GetFunction()); - max.AddFunction(GetMinMaxNFunction().SetStructStateExport(GetExportStateType)); + max.AddFunction(GetMinMaxNFunction()); return max; } diff --git a/src/duckdb/src/function/cast/struct_cast.cpp b/src/duckdb/src/function/cast/struct_cast.cpp index 0e26f4be8..b0950ad86 100644 --- a/src/duckdb/src/function/cast/struct_cast.cpp +++ b/src/duckdb/src/function/cast/struct_cast.cpp @@ -124,14 +124,6 @@ static bool StructToStructCast(Vector &source, Vector &result, idx_t count, Cast ConstantVector::SetNull(target_vector, count_t(count)); } } - - if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - FlatVector::SetSize(result, count); - ConstantVector::SetNull(result, ConstantVector::IsNull(source)); - return all_converted; - } - FlatVector::CopyValidity(result, source, count); FlatVector::SetSize(result, count); result.Verify(); diff --git a/src/duckdb/src/function/function_list.cpp b/src/duckdb/src/function/function_list.cpp index a61c9a55f..878d2b0be 100644 --- a/src/duckdb/src/function/function_list.cpp +++ b/src/duckdb/src/function/function_list.cpp @@ -123,7 +123,6 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION(TableFilterBloomFilterFun), DUCKDB_SCALAR_FUNCTION(TableFilterDynamicFun), DUCKDB_SCALAR_FUNCTION(TableFilterOptionalFun), - DUCKDB_SCALAR_FUNCTION(TableFilterPerfectHashJoinFun), DUCKDB_SCALAR_FUNCTION(TableFilterPrefixRangeFun), DUCKDB_SCALAR_FUNCTION(TableFilterSelectivityOptionalFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(AddFun), diff --git a/src/duckdb/src/function/scalar/system/aggregate_export.cpp b/src/duckdb/src/function/scalar/system/aggregate_export.cpp index d3f95073a..5b5aca876 100644 --- a/src/duckdb/src/function/scalar/system/aggregate_export.cpp +++ b/src/duckdb/src/function/scalar/system/aggregate_export.cpp @@ -185,12 +185,43 @@ static void SerializeField(const LogicalType &type, const AggregateStateField &f // 
linked list field: build the result LIST vector from each state's linked list // an empty linked list is exported as NULL, matching the finalize semantics of list aggregates D_ASSERT(type.id() == LogicalTypeId::LIST); + D_ASSERT(field.children.size() == 1); vector linked_lists; linked_lists.reserve(count); for (idx_t i = 0; i < count; i++) { linked_lists.push_back(Load(addresses[i] + base + field.field_offset)); } - field.list_functions.BuildLists(linked_lists, result, 0); + const auto &element = field.children[0]; + if (element.kind != AggregateFieldKind::SORT_KEY) { + // elements are stored directly - build the result LIST vector from each state's linked list + field.list_functions.BuildLists(linked_lists, result, 0); + break; + } + // the elements are sort keys: build the physically stored (BLOB) elements into a temporary LIST vector, then + // decode each sort key into the result child while rebuilding the result's list entries + Vector physical_list(LogicalType::LIST(LogicalType::BLOB), count); + field.list_functions.BuildLists(linked_lists, physical_list, 0); + + ListVector::Reserve(result, ListVector::GetListSize(physical_list)); + auto &result_child = ListVector::GetChildMutable(result); + auto result_entries = FlatVector::GetDataMutable(result); + const OrderModifiers modifiers(element.sort_key_order, OrderByNullType::NULLS_LAST); + + idx_t child_offset = 0; + for (const auto list_entry : physical_list.Values>()) { + const auto row = list_entry.GetIndex(); + if (!list_entry.IsValid()) { + // an empty linked list is exported as NULL, matching the finalize semantics of list aggregates + FlatVector::SetNull(result, row, true); + result_entries[row] = {child_offset, 0}; + continue; + } + result_entries[row] = {child_offset, list_entry.GetListLength()}; + for (const auto sort_key : list_entry.GetChildValues()) { + CreateSortKeyHelpers::DecodeSortKey(sort_key.GetValueUnsafe(), result_child, child_offset++, modifiers); + } + } + ListVector::SetListSize(result, 
child_offset); break; } } @@ -247,11 +278,26 @@ static void DeserializeField(const LogicalType &type, const AggregateStateField case AggregateFieldKind::LIST: { // linked list field: append each row of the input LIST vector into the state's linked list D_ASSERT(type.id() == LogicalTypeId::LIST); - // the child data is appended through the ListSegmentFunctions API, which takes a RecursiveUnifiedVectorFormat - RecursiveUnifiedVectorFormat child_data; - Vector::RecursiveToUnifiedFormat(ListVector::GetChild(input_vec), child_data); + D_ASSERT(field.children.size() == 1); + const auto values = input_vec.Values(); + const auto &element = field.children[0]; + const auto &logical_child = ListVector::GetChild(input_vec); + + // the child is appended through the ListSegmentFunctions API, which physically stores the element type - + // sort-key elements are first re-encoded from the logical child into a temporary BLOB child vector + optional_ptr physical_child = logical_child; + unique_ptr encoded_child; + if (element.kind == AggregateFieldKind::SORT_KEY) { + const auto child_count = ListVector::GetListSize(input_vec); + const OrderModifiers modifiers(element.sort_key_order, OrderByNullType::NULLS_LAST); + // the result must be sized for the full (possibly larger than standard) child up front + encoded_child = make_uniq(LogicalType::BLOB, MaxValue(child_count, 1)); + CreateSortKeyHelpers::CreateSortKey(logical_child, child_count, modifiers, *encoded_child); + physical_child = *encoded_child; + } - auto values = input_vec.Values(); + RecursiveUnifiedVectorFormat child_data; + Vector::RecursiveToUnifiedFormat(*physical_child, child_data); for (idx_t i = 0; i < count; i++) { LinkedList linked_list; const auto entry = values[i]; @@ -479,9 +525,10 @@ void ParseStateParameters(const Value &parameters, vector &argument } argument_types.push_back(TypeValue::GetType(children[0])); if (!children[1].IsNull()) { - // the parameter is bound to a constant - decode it and cast it back to the
declared argument type - constant_parameters.emplace(arg_idx, - VariantValue::GetValue(children[1]).DefaultCastAs(argument_types.back())); + // the parameter is bound to a constant - decode it as-is, without casting it to the declared + // argument type, so that re-binding sees the same (pre-cast) constant as the original bind + // (e.g. a DECIMAL quantile parameter must stay DECIMAL even though the signature says DOUBLE) + constant_parameters.emplace(arg_idx, VariantValue::GetValue(children[1])); } continue; } @@ -827,19 +874,14 @@ ExportAggregateFunction::Bind(unique_ptr child_aggrega if (!bound_function.HasStateCombineCallback()) { throw BinderException("Cannot use EXPORT_STATE for non-combinable function %s", bound_function.GetName()); } - if (bound_function.HasStateDestructorCallback()) { - throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom destructors"); - } - // this should be required - D_ASSERT(bound_function.HasStateSizeCallback()); - D_ASSERT(bound_function.HasStateFinalizeCallback()); - - D_ASSERT(child_aggregate->Function().GetReturnType().id() != LogicalTypeId::INVALID); if (!bound_function.HasGetStateTypeCallback()) { throw NotImplementedException( "Aggregate function \"%s\" does not have a state type callback defined - cannot export state", bound_function.GetName()); } + D_ASSERT(bound_function.HasStateSizeCallback()); + D_ASSERT(bound_function.HasStateFinalizeCallback()); + D_ASSERT(child_aggregate->Function().GetReturnType().id() != LogicalTypeId::INVALID); SetStateExport(*child_aggregate, CreateAggregateStateType(bound_function, child_aggregate->BindInfo().get())); return child_aggregate; } diff --git a/src/duckdb/src/function/table/table_scan.cpp b/src/duckdb/src/function/table/table_scan.cpp index b480e98dc..00a6bba52 100644 --- a/src/duckdb/src/function/table/table_scan.cpp +++ b/src/duckdb/src/function/table/table_scan.cpp @@ -304,7 +304,8 @@ class DuckTableScanState : public TableScanGlobalState { 
make_uniq(*bind_data.order_options, TransactionData(tx)); } - l_state->scan_state.Initialize(std::move(storage_ids), context.client, input.filters, input.sample_options); + l_state->scan_state.Initialize(std::move(storage_ids), context.client, input.filters, input.sample_options, + total_rows); l_state->rows_in_current_row_group = storage.NextParallelScan(context.client, state, l_state->scan_state); if (l_state->rows_in_current_row_group > 0) { diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index 594b1c55b..f273972f6 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev8694" +#define DUCKDB_PATCH_VERSION "0-dev8815" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 6 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.6.0-dev8694" +#define DUCKDB_VERSION "v1.6.0-dev8815" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "72e5a0f30c" +#define DUCKDB_SOURCE_ID "68d73b7c3a" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" diff --git a/src/duckdb/src/include/duckdb/common/enum_util.hpp b/src/duckdb/src/include/duckdb/common/enum_util.hpp index 9326eea29..549cd9347 100644 --- a/src/duckdb/src/include/duckdb/common/enum_util.hpp +++ b/src/duckdb/src/include/duckdb/common/enum_util.hpp @@ -184,6 +184,8 @@ enum class DeprecatedUsingKeySyntax : uint8_t; enum class DestroyBufferUpon : uint8_t; +enum class DialectCompatibilityMode : uint8_t; + enum class DistinctCountSource : uint8_t; enum class DistinctType : uint8_t; @@ -785,6 +787,9 @@ const char* EnumUtil::ToChars(DeprecatedUsingKeySyntax template<> const char* EnumUtil::ToChars(DestroyBufferUpon value); +template<> +const char* 
EnumUtil::ToChars(DialectCompatibilityMode value); + template<> const char* EnumUtil::ToChars(DistinctCountSource value); @@ -1572,6 +1577,9 @@ DeprecatedUsingKeySyntax EnumUtil::FromString(const ch template<> DestroyBufferUpon EnumUtil::FromString(const char *value); +template<> +DialectCompatibilityMode EnumUtil::FromString(const char *value); + template<> DistinctCountSource EnumUtil::FromString(const char *value); diff --git a/src/duckdb/src/include/duckdb/common/enums/dialect_compatibility_mode.hpp b/src/duckdb/src/include/duckdb/common/enums/dialect_compatibility_mode.hpp new file mode 100644 index 000000000..5834bea25 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/enums/dialect_compatibility_mode.hpp @@ -0,0 +1,17 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/enums/dialect_compatibility_mode.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/constants.hpp" + +namespace duckdb { + +enum class DialectCompatibilityMode : uint8_t { NONE = 0, SPARK = 1 }; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp b/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp index 1470a430e..d82cd438f 100644 --- a/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp +++ b/src/duckdb/src/include/duckdb/common/enums/optimizer_type.hpp @@ -54,6 +54,7 @@ enum class OptimizerType : uint32_t { PARTITIONED_EXECUTION = 37, PARTIAL_AGGREGATE_PUSHDOWN = 38, REMOTE_PUSHDOWN = 39, + GROUPING_SETS = 40, }; string OptimizerTypeToString(OptimizerType type); diff --git a/src/duckdb/src/include/duckdb/common/file_system.hpp b/src/duckdb/src/include/duckdb/common/file_system.hpp index 9d1df1ebe..64a57d1bc 100644 --- a/src/duckdb/src/include/duckdb/common/file_system.hpp +++ b/src/duckdb/src/include/duckdb/common/file_system.hpp @@ -110,6 +110,7 @@ struct FileHandle 
{ DUCKDB_API virtual FileCompressionType GetFileCompressionType(); DUCKDB_API bool CanSeek(); + DUCKDB_API bool SupportsPositionalWrites(); DUCKDB_API bool IsPipe(); DUCKDB_API bool OnDiskFile(); //! Try to obtain a network throughput estimate (Local files return false). @@ -307,6 +308,8 @@ class FileSystem { //! If FS was manually set by the user DUCKDB_API virtual bool IsManuallySet(); + //! Whether positional writes to this handle can be issued independently and out of order + DUCKDB_API virtual bool SupportsPositionalWrites(FileHandle &handle); //! Whether or not we can seek into the file DUCKDB_API virtual bool CanSeek(); //! Whether or not the FS handles plain files on disk. This is relevant for certain optimizations, as random reads diff --git a/src/duckdb/src/include/duckdb/common/local_file_system.hpp b/src/duckdb/src/include/duckdb/common/local_file_system.hpp index 70543ae53..acfbe92b6 100644 --- a/src/duckdb/src/include/duckdb/common/local_file_system.hpp +++ b/src/duckdb/src/include/duckdb/common/local_file_system.hpp @@ -33,6 +33,7 @@ class LocalFileSystem : public FileSystem { int64_t Read(FileHandle &handle, void *buffer, int64_t nr_bytes) override; //! Write nr_bytes from the buffer into the file, moving the file pointer forward by nr_bytes. int64_t Write(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + bool SupportsPositionalWrites(FileHandle &handle) override; //! Excise a range of the file. The file-system is free to deallocate this //! range (sparse file support). Reads to the range will succeed but will return //! undefined data. 
diff --git a/src/duckdb/src/include/duckdb/common/primitive_dictionary.hpp b/src/duckdb/src/include/duckdb/common/primitive_dictionary.hpp index bafcb4023..94dbec0a0 100644 --- a/src/duckdb/src/include/duckdb/common/primitive_dictionary.hpp +++ b/src/duckdb/src/include/duckdb/common/primitive_dictionary.hpp @@ -16,6 +16,11 @@ namespace duckdb { struct ParquetOperatorPageState; +struct PrimitiveDictionaryTargetData { + AllocatedData data; + idx_t size = 0; +}; + struct PrimitiveCastOperator { template static TGT Operation(SRC input, ParquetOperatorPageState *state) { @@ -47,7 +52,7 @@ class PrimitiveDictionary { public: static constexpr uint32_t MAXIMUM_POSSIBLE_SIZE = INVALID_INDEX - 1; - static constexpr idx_t INITIAL_TARGET_CAPACITY = 1048576; + static constexpr idx_t INITIAL_TARGET_CAPACITY = 262144; //! PrimitiveDictionary is a fixed-size linear probing hash table for primitive types //! It is used to dictionary-encode data in, e.g., Parquet files @@ -131,6 +136,15 @@ class PrimitiveDictionary { return result; } + //! 
Take ownership of the target written values + PrimitiveDictionaryTargetData TakeTargetData() { + PrimitiveDictionaryTargetData result; + result.size = target_stream.GetPosition(); + result.data = std::move(allocated_target); + target_stream = MemoryStream(nullptr, 0); + return result; + } + void Reset() { allocated_dictionary.Reset(); allocated_target.Reset(); diff --git a/src/duckdb/src/include/duckdb/common/serializer/async_file_writer.hpp b/src/duckdb/src/include/duckdb/common/serializer/async_file_writer.hpp new file mode 100644 index 000000000..12d289e10 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/async_file_writer.hpp @@ -0,0 +1,141 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/async_file_writer.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/file_system.hpp" +#include "duckdb/common/optional_ptr.hpp" +#include "duckdb/common/serializer/async_write_queue.hpp" +#include "duckdb/common/serializer/write_stream.hpp" +#include "duckdb/main/query_context.hpp" + +namespace duckdb { + +class ClientContext; +class CopiedAsyncWriteBuffer; + +//! WriteStream implementation that registers writes cheaply and drains them on the async task scheduler. +//! This is a logical stream writer: offsets are assigned when writes are registered via GetTotalWritten(). +//! Physical writes may complete out of order when positional writes are supported; WaitAll/Close complete the file. +//! Calls into this writer must be externally serialized; internal locking only coordinates with async drain tasks. +class AsyncFileWriter : public WriteStream, private ManagedAsyncWriteStreamTarget { +public: + //! RAII handle that batches write registration. Finish() must be called on the normal path to leave the batch and + //! start draining; scope exit only leaves the batch as exception cleanup. 
+ class BatchGuard { + public: + BatchGuard(const BatchGuard &) = delete; + BatchGuard &operator=(const BatchGuard &) = delete; + DUCKDB_API BatchGuard(BatchGuard &&other) noexcept; + BatchGuard &operator=(BatchGuard &&other) = delete; + DUCKDB_API ~BatchGuard(); + + public: + //! Leave the batch and apply the writer's normal post-batch scheduling/backpressure policy. + DUCKDB_API void Finish(); + + private: + friend class AsyncFileWriter; + + DUCKDB_API explicit BatchGuard(AsyncFileWriter &writer); + + private: + optional_ptr writer; + }; + + //! Default file-open behavior for creating a write-locked output file. + static constexpr FileOpenFlags DEFAULT_OPEN_FLAGS = FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE; + +public: + DUCKDB_API AsyncFileWriter(QueryContext context, FileSystem &fs, const string &path, + FileOpenFlags open_flags = DEFAULT_OPEN_FLAGS); + DUCKDB_API ~AsyncFileWriter() override; + using WriteStream::Write; + +public: + //! Copy the provided bytes into owned storage and register them for asynchronous writing. + DUCKDB_API void WriteData(const_data_ptr_t buffer, idx_t write_size) override; + //! Transfer ownership of an existing write buffer and register it without copying. + DUCKDB_API void WriteData(unique_ptr buffer); + + //! Delay async task scheduling while the returned guard is alive. + DUCKDB_API BatchGuard StartBatch(); + //! Flush this WriteStream by waiting until all registered writes have reached the file handle. + DUCKDB_API void Flush(); + //! Wait until all registered writes have reached the file handle, and rethrow any async write error. + DUCKDB_API void WaitAll(); + //! Help drain async writes when pending bytes exceed the current memory budget. No-op while a batch is open. + DUCKDB_API void ApplyBackpressure(); + //! Wait for all writes, then close the file handle. + DUCKDB_API void Close(); + //! Wait for all writes, then fsync the file handle. + DUCKDB_API void Sync(); + //! 
Wait for all writes, then truncate the file to the requested logical size. + DUCKDB_API void Truncate(idx_t size); + + //! Return the logical file size, including writes that have been registered but not drained yet. + DUCKDB_API idx_t GetFileSize(); + //! Return the logical number of bytes written, including writes that are still pending. + DUCKDB_API idx_t GetTotalWritten() const; + +private: + using BatchDrainMode = ManagedAsyncWriteStreamQueue::BatchDrainMode; + using ScheduleMode = ManagedAsyncWriteStreamQueue::ScheduleMode; + using SchedulePolicy = ManagedAsyncWriteStreamQueue::SchedulePolicy; + + //! Register an owned buffer for writing, using the configured synchronous/asynchronous mode. + void RegisterWrite(unique_ptr buffer, ScheduleMode schedule_mode = ScheduleMode::ALLOW); + //! Register an owned buffer whose bytes were already counted in total_written. + void RegisterStagedWrite(unique_ptr buffer, idx_t offset, + ScheduleMode schedule_mode = ScheduleMode::ALLOW); + //! Add a buffer with its assigned file offset to the configured sync/async write path. + void RegisterWriteInternal(unique_ptr buffer, idx_t offset, ScheduleMode schedule_mode); + //! Write caller-owned bytes through the local staging buffer when async draining is disabled. + void WriteDataSynchronously(data_ptr_t buffer, idx_t write_size); + //! Move any staged copied bytes into the pending write queue. + void SealCopiedBuffer(ScheduleMode schedule_mode = ScheduleMode::ALLOW); + //! Seal copied bytes, then schedule as many drain tasks as the pending queue allows. + void SchedulePendingWrites(SchedulePolicy policy = SchedulePolicy::THRESHOLD); + //! Enter a registration batch, delaying async draining until the batch is left. + void BeginBatch(); + //! Leave a registration batch without scheduling, blocking, or throwing. + void LeaveBatch() noexcept; + + //! Return whether the file handle supports independent positional writes. + bool SupportsPositionalWrites() override; + //! 
Return whether this writer targets a local file-like handle. + bool IsLocalFile() override; + //! Write bytes to the underlying file handle at the assigned logical stream offset. + void Write(data_ptr_t buffer, idx_t size, idx_t offset) override; + //! Write bytes to the underlying file handle's current position. + void Write(data_ptr_t buffer, idx_t size) override; + //! Surface an error thrown by an async drain task. + void RethrowTaskError(); + //! Wait for scheduled writes, optionally restoring an active registration batch afterwards. + void WaitAllInternal(BatchDrainMode batch_drain_mode); + +private: + QueryContext context; + ClientContext &client_context; + FileSystem &fs; + string path; + unique_ptr handle; + //! Managed queue that owns stream scheduling, backpressure, and write coalescing. + unique_ptr write_queue; + + //! Copy staging buffer for small transient WriteData inputs. Only accessed by the registering thread. + unique_ptr copied_buffer; + //! Logical file offset of the first byte in copied_buffer. + idx_t copied_buffer_offset = 0; + //! Logical stream position, including copied/staged/pending bytes. Updated by the registering thread. + idx_t total_written = 0; + //! Set once the handle has been closed or detached. 
+ bool closed = false; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/serializer/async_write_queue.hpp b/src/duckdb/src/include/duckdb/common/serializer/async_write_queue.hpp new file mode 100644 index 000000000..cd7720fb3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/serializer/async_write_queue.hpp @@ -0,0 +1,468 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/serializer/async_write_queue.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/deque.hpp" +#include "duckdb/common/error_data.hpp" +#include "duckdb/common/mutex.hpp" +#include "duckdb/common/optional_ptr.hpp" + +#include + +namespace duckdb { + +class ClientContext; +class TaskExecutor; +class TemporaryMemoryState; + +//! Compile-time policy used by the async write layers. +struct AsyncWriteConfig { + //! Capacity of the staging buffer used for small transient stream writes. + static constexpr idx_t COPIED_BUFFER_CAPACITY = 4096; + //! Maximum bytes a single low-level async task should drain before yielding scheduler capacity. + static constexpr idx_t TASK_BYTE_BUDGET = 4ULL * 1024ULL * 1024ULL; + //! Local file systems are cheap to call, so only coalesce up to the buffered writer page size. + static constexpr idx_t LOCAL_COALESCE_THRESHOLD = 4096; + //! Remote file systems benefit from fewer round trips, so coalesce contiguous small buffers more aggressively. + static constexpr idx_t REMOTE_COALESCE_THRESHOLD = 8ULL * 1024ULL * 1024ULL; + //! Maximum queued async bytes retained per regular execution thread. + static constexpr idx_t MAX_PENDING_BYTES_PER_THREAD = 64ULL * 1024ULL * 1024ULL; + //! Minimum async write reservation requested per regular execution thread. + static constexpr idx_t MIN_PENDING_BYTES_PER_THREAD = 8ULL * 1024ULL * 1024ULL; + //! 
Maximum bytes a single managed stream request should submit before yielding scheduler capacity. + static constexpr idx_t DRAIN_TASK_BYTE_BUDGET = 16ULL * 1024ULL * 1024ULL; +}; + +//! Owned payload that can be handed to an async write queue. +class AsyncWritePayload { +public: + virtual ~AsyncWritePayload() = default; + + //! Pointer to the bytes to write. The buffer must remain valid for the lifetime of this object. + virtual data_ptr_t Ptr() = 0; + //! Number of bytes exposed by Ptr(). + virtual idx_t Size() const = 0; +}; + +//! Compatibility name for existing stream-oriented callers. +using AsyncWriteBuffer = AsyncWritePayload; + +//! Completion callback for one physical async write request. The error is set when the write failed. +using AsyncWriteCompletionCallback = std::function error)>; + +//! One positional physical write request. +class AsyncWriteRequest { +public: + AsyncWriteRequest() = default; + AsyncWriteRequest(unique_ptr payload, idx_t offset, + AsyncWriteCompletionCallback completion = nullptr); + + idx_t Size() const; + + unique_ptr payload; + idx_t offset = 0; + AsyncWriteCompletionCallback completion; +}; + +//! Positional physical write target used by the low-level AsyncWriteQueue. +class AsyncWriteTarget { +public: + virtual ~AsyncWriteTarget() = default; + + //! Write a specific byte range using the target's positional write path. + virtual void Write(data_ptr_t buffer, idx_t size, idx_t offset) = 0; +}; + +//! Minimal positional async write scheduler. +//! Requests are independent positional writes; stream ordering, coalescing, and memory policy live in wrappers. +class AsyncWriteQueue { + friend class AsyncWriteQueueTask; + friend class AsyncWriteQueueTaskGuard; + +public: + DUCKDB_API AsyncWriteQueue(ClientContext &client_context, AsyncWriteTarget &target); + DUCKDB_API ~AsyncWriteQueue(); + + AsyncWriteQueue(const AsyncWriteQueue &) = delete; + AsyncWriteQueue &operator=(const AsyncWriteQueue &) = delete; + +public: + //! 
Return whether writes are drained by async scheduler tasks. If false, Submit writes synchronously. + DUCKDB_API bool IsAsync() const; + //! Return whether the async task executor has captured an error. + DUCKDB_API bool HasError(); + //! Submit one owned positional request to the configured sync/async write path. + DUCKDB_API void Submit(AsyncWriteRequest request); + //! Return queued/in-flight bytes that have not reached the target yet. + DUCKDB_API idx_t PendingBytes(); + //! Execute one queued task owned by this queue, or yield if the tasks are already running. + DUCKDB_API void WorkOnPendingTask(); + //! Wait until all submitted writes have reached the target. + DUCKDB_API void Flush(); + //! Wait for all writes and close the queue. + DUCKDB_API void Close(); + //! Surface an error thrown by an async drain task. + DUCKDB_API void RethrowTaskError(); + +private: + struct PendingRequest { + PendingRequest() = default; + explicit PendingRequest(AsyncWriteRequest request); + + idx_t Size() const; + + AsyncWriteRequest request; + idx_t size; + }; + +private: + //! Schedule pending requests until max_active_tasks is reached. + void ScheduleTasksInternal(bool force = false); + //! Return the byte budget for one task. + idx_t TaskByteBudget() const; + //! Return how many bytes one task should reserve after skipping already scheduled bytes. + idx_t SelectPendingRequestBytes(idx_t skip_bytes) const; + //! Move one reserved prefix of pending requests into a write task. Caller owns a scheduled task slot. + idx_t TakeRequests(deque &requests); + //! Release one scheduled/running task slot and its in-flight byte accounting. + void FinishTask(idx_t task_size); + //! Release a task slot for a scheduled task that never entered the queue because another task failed. + void CancelScheduledTask(); + //! Release multiple reserved task slots that were never scheduled. + void CancelScheduledTasks(idx_t task_count); + + //! 
Async task entry point that drains a bounded batch of positional requests. + void DrainRequests(); + //! Write request bytes to the target and invoke its completion callback. + void WriteRequest(AsyncWriteRequest request); + //! Invoke a completion callback outside the queue lock. + void CompleteRequest(AsyncWriteRequest &request, idx_t size, optional_ptr error); + //! Write bytes to the target at the assigned physical offset. + void WriteBuffer(data_ptr_t buffer, idx_t size, idx_t offset); + //! Throw if a mutating API is used after Close(). + void VerifyOpen() const; + //! Throw if the queue still owns registered or scheduled write work. + void VerifyDrained() const; + //! Fail and discard queued requests after an async write failure once all scheduled tasks have stopped. + void CancelPendingRequestsAfterFailure(const ErrorData &error) noexcept; + +private: + ClientContext &client_context; + AsyncWriteTarget &target; + //! Maximum scheduled/running write tasks for this queue. + idx_t max_active_tasks = 1; + //! Maximum bytes a single async task should drain. + idx_t task_byte_budget = AsyncWriteConfig::TASK_BYTE_BUDGET; + + //! Protects state shared between the submitting thread and async write tasks. + mutex lock; + //! Positional requests waiting for an async task. + deque pending_requests; + //! Bytes queued in pending_requests that have not been taken by a task yet. + idx_t pending_bytes = 0; + //! Bytes owned by write tasks that have not reached the target yet. + idx_t in_flight_bytes = 0; + //! Bytes in pending_requests already reserved by scheduled-but-not-started tasks. + idx_t scheduled_pending_bytes = 0; + //! Scheduled or running write tasks for this queue. + idx_t active_tasks = 0; + //! Scheduled write tasks that have not yet claimed a request. + idx_t pending_tasks = 0; + //! Per-task byte reservations for scheduled tasks that have not yet claimed their requests. + deque pending_task_bytes; + //! Set after Close() has drained the queue.
Further submissions are rejected. + bool closed = false; + + //! Async task executor. If absent, writes are performed synchronously on submission. + //! Keep this after task-accounting fields so queued task destructors can still release slots. + unique_ptr executor; +}; + +//! Stream target used by ManagedAsyncWriteStreamQueue. Sequential fallback is handled here, not in AsyncWriteQueue. +class ManagedAsyncWriteStreamTarget { +public: + virtual ~ManagedAsyncWriteStreamTarget() = default; + + //! Whether contiguous registered writes can safely drain concurrently through positional writes. + virtual bool SupportsPositionalWrites() = 0; + //! Whether this target is a local file-like target. Remote targets use larger coalesced writes. + virtual bool IsLocalFile() = 0; + //! Write a specific byte range using the target's positional write path. + virtual void Write(data_ptr_t buffer, idx_t size, idx_t offset) = 0; + //! Write bytes using the target's sequential write path. + virtual void Write(data_ptr_t buffer, idx_t size) = 0; +}; + +//! Managed positional write queue built on top of AsyncWriteQueue. +//! Requests may target independent offsets; stream ordering and coalescing live in ManagedAsyncWriteStreamQueue. +class ManagedAsyncWriteQueue : private AsyncWriteTarget { + friend class ManagedAsyncWriteStreamQueue; + +public: + //! Whether registering a payload may schedule an async drain request immediately. + enum class ScheduleMode : uint8_t { ALLOW, DEFER }; + //! Whether to force a TemporaryMemoryState growth check instead of relying on coarse growth. + enum class MemoryUpdateMode : uint8_t { COARSE, FORCE }; + //! Whether to schedule only enough request capacity for normal overlap, or force all pending bytes to drain. 
+ enum class SchedulePolicy : uint8_t { THRESHOLD, FORCE }; + +public: + DUCKDB_API ManagedAsyncWriteQueue(ClientContext &client_context, AsyncWriteTarget &target); + DUCKDB_API ~ManagedAsyncWriteQueue() override; + + ManagedAsyncWriteQueue(const ManagedAsyncWriteQueue &) = delete; + ManagedAsyncWriteQueue &operator=(const ManagedAsyncWriteQueue &) = delete; + +public: + //! Return whether writes are drained by async scheduler tasks. If false, RegisterWrite writes synchronously. + DUCKDB_API bool IsAsync() const; + //! Return whether the async task executor has captured an error. + DUCKDB_API bool HasError(); + //! Add an owned positional payload to the configured sync/async write path. + DUCKDB_API void RegisterWrite(unique_ptr payload, idx_t offset, + ScheduleMode schedule_mode = ScheduleMode::ALLOW); + //! Add one positional request to the configured sync/async write path. + DUCKDB_API void RegisterWrite(AsyncWriteRequest request, ScheduleMode schedule_mode = ScheduleMode::ALLOW); + //! Schedule as many drain requests as the pending queue allows. + DUCKDB_API void SchedulePendingWrites(SchedulePolicy policy = SchedulePolicy::THRESHOLD); + //! Help drain async writes when pending bytes exceed the current memory budget. + DUCKDB_API void ApplyBackpressure(); + //! Wait until all registered writes have reached the target. + DUCKDB_API void WaitAll(); + //! Drain all writes, close the queue, and release the TemporaryMemoryState reservation. + DUCKDB_API void Close(); + //! Release the queue's TemporaryMemoryState reservation. + DUCKDB_API void ReleaseMemoryReservation(); + //! Surface an error thrown by an async drain task. + DUCKDB_API void RethrowTaskError(); + +private: + struct PendingWrite { + explicit PendingWrite(AsyncWriteRequest request); + + idx_t Size() const; + + AsyncWriteRequest request; + idx_t size; + }; + +private: + //! Add one positional request whose bytes are already tracked as external pending bytes. 
+ void RegisterAccountedWrite(AsyncWriteRequest request, ScheduleMode schedule_mode = ScheduleMode::ALLOW); + //! Track bytes held by a wrapper before they become positional requests. + void AddExternalPendingBytes(idx_t bytes, bool update_memory = true); + //! Stop tracking wrapper-held bytes that will never become positional requests. + void DiscardExternalPendingBytes(idx_t bytes) noexcept; + //! Add one request to the managed queue. Caller may mark bytes already tracked as external. + void RegisterWriteInternal(AsyncWriteRequest request, idx_t accounted_external_bytes, ScheduleMode schedule_mode); + //! Schedule drain requests from already registered pending writes. + void SchedulePendingWritesInternal(SchedulePolicy policy = SchedulePolicy::THRESHOLD); + //! Grow the TemporaryMemoryState reservation coarsely; it is released only when the queue closes. + void UpdateMemoryState(MemoryUpdateMode mode = MemoryUpdateMode::COARSE); + + //! Return the current async backlog budget after applying the fixed queue cap. + idx_t BackpressureBudget(); + //! Effective byte budget for one managed drain request. + idx_t DrainTaskByteBudget() const; + //! Return queued/submitted/external bytes that have not reached the target yet. Caller must hold lock. + idx_t TotalPendingBytes() const; + //! Return how many physical bytes can be submitted to the low-level queue before refilling should pause. + idx_t SubmittedByteWindow() const; + + //! Move one pending positional write into a physical async request. + bool TakePendingWriteRequest(AsyncWriteRequest &request, SchedulePolicy policy); + //! Wrap a request callback so submitted-byte accounting is released before user callbacks run. + void AddCompletionAccounting(AsyncWriteRequest &request); + //! Release byte accounting for one submitted physical request. + void CompleteSubmittedWrite(idx_t offset, idx_t size, optional_ptr error); + + //! Throw if a mutating API is used after Close(). + void VerifyOpen() const; + //! 
Throw if the queue still owns registered or scheduled write work. + void VerifyDrained() const; + //! Fail and discard queued writes after an async write failure once all submitted writes have stopped. + void CancelPendingWritesAfterFailure(const ErrorData &error) noexcept; + + //! Write bytes to the managed target at the assigned physical offset. + void Write(data_ptr_t buffer, idx_t size, idx_t offset) override; + +private: + ClientContext &client_context; + AsyncWriteTarget ⌖ + + //! Low-level positional request scheduler. + unique_ptr write_queue; + //! Temporary memory reservation state used to limit queued async write data. + unique_ptr memory_state; + //! Last remaining-size request sent to TemporaryMemoryManager. Grows monotonically until close. + idx_t memory_request_bytes = 0; + //! Maximum number of submitted/running drain requests for this queue. + idx_t max_active_drain_tasks = 1; + //! Minimum TemporaryMemoryManager reservation while writes are outstanding. + idx_t min_pending_bytes = 0; + //! Hard cap over the TemporaryMemoryState reservation. + idx_t max_pending_bytes = 0; + //! Maximum bytes one managed async request should submit before yielding scheduler capacity. + idx_t drain_task_byte_budget = AsyncWriteConfig::DRAIN_TASK_BYTE_BUDGET; + + //! Protects state shared between registering threads and async completion callbacks. + mutex lock; + //! Positional payloads queued for submission to AsyncWriteQueue. + deque pending_writes; + //! Bytes queued in pending_writes that have not been submitted to AsyncWriteQueue yet. + idx_t pending_bytes = 0; + //! Bytes tracked by a wrapper before they become positional requests. + idx_t external_pending_bytes = 0; + //! Bytes submitted to AsyncWriteQueue that have not completed yet. + idx_t submitted_bytes = 0; + //! Submitted physical requests that have not completed yet. + idx_t submitted_requests = 0; + //! Set after Close() has drained the queue. Further write registration is rejected. 
+ bool closed = false; +}; + +//! Managed stream-oriented write queue built on top of ManagedAsyncWriteQueue. +//! V1 is a contiguous logical write queue: each RegisterWrite offset must match the next expected offset. +//! Callers are responsible for assigning offsets and externally serializing RegisterWrite calls. +class ManagedAsyncWriteStreamQueue : private AsyncWriteTarget { +public: + //! Whether registering a payload may schedule an async drain request immediately. + enum class ScheduleMode : uint8_t { ALLOW, DEFER }; + //! Whether to schedule only enough request capacity for normal overlap, or force all pending bytes to drain. + enum class SchedulePolicy : uint8_t { THRESHOLD, FORCE }; + //! Whether async requests can write independent target ranges concurrently. + enum class DrainMode : uint8_t { SEQUENTIAL, POSITIONAL }; + //! Whether waiting for scheduled writes should preserve an open registration batch. + enum class BatchDrainMode : uint8_t { PRESERVE_BATCH, FORCE_CLOSE_BATCH }; + +public: + DUCKDB_API ManagedAsyncWriteStreamQueue(ClientContext &client_context, ManagedAsyncWriteStreamTarget &target); + DUCKDB_API ~ManagedAsyncWriteStreamQueue() override; + + ManagedAsyncWriteStreamQueue(const ManagedAsyncWriteStreamQueue &) = delete; + ManagedAsyncWriteStreamQueue &operator=(const ManagedAsyncWriteStreamQueue &) = delete; + +public: + //! Return whether writes are drained by async scheduler tasks. If false, RegisterWrite writes synchronously. + DUCKDB_API bool IsAsync() const; + //! Return whether the async task executor has captured an error. + DUCKDB_API bool HasError(); + //! Add an owned payload at the next contiguous logical offset to the configured sync/async write path. + DUCKDB_API void RegisterWrite(unique_ptr payload, idx_t offset, + ScheduleMode schedule_mode = ScheduleMode::ALLOW); + //! Enter a registration batch, delaying async draining until the batch is left. + DUCKDB_API void BeginBatch(); + //! 
Leave a registration batch without scheduling, blocking, or throwing. + DUCKDB_API void LeaveBatch() noexcept; + //! Return whether a registration batch is currently open. + DUCKDB_API bool HasOpenBatch(); + //! Schedule as many drain requests as the pending queue allows. + DUCKDB_API void SchedulePendingWrites(SchedulePolicy policy = SchedulePolicy::THRESHOLD); + //! Help drain async writes when pending bytes exceed the current memory budget. No-op while a batch is open. + DUCKDB_API void ApplyBackpressure(); + //! Wait for scheduled writes, optionally restoring an active registration batch afterwards. + DUCKDB_API void WaitAll(BatchDrainMode batch_drain_mode = BatchDrainMode::PRESERVE_BATCH); + //! Drain all writes, close any open registration batch, and release the TemporaryMemoryState reservation. + DUCKDB_API void Close(); + //! Reset the next expected contiguous offset after all registered writes have drained. + DUCKDB_API void ResetNextOffset(idx_t offset); + //! Release the queue's TemporaryMemoryState reservation. + DUCKDB_API void ReleaseMemoryReservation(); + //! Surface an error thrown by an async drain task. + DUCKDB_API void RethrowTaskError(); + +private: + struct PendingWrite { + PendingWrite(unique_ptr payload, idx_t offset); + + idx_t Size() const; + + unique_ptr payload; + idx_t offset; + }; + + class CoalescedWritePayload; + +private: + //! Schedule drain requests from already registered pending writes. + void SchedulePendingWritesInternal(SchedulePolicy policy = SchedulePolicy::THRESHOLD); + + //! Effective byte budget for one managed drain request, never smaller than the coalescing threshold. + idx_t DrainTaskByteBudget() const; + //! Return queued/submitted bytes that have not reached the target yet. Caller must hold lock. + idx_t TotalPendingBytes() const; + //! Select the pending write range one primitive task would claim. Caller must hold lock. + idx_t SelectPendingWriteEnd(idx_t start, idx_t &selected_bytes) const; + //! 
Select the pending write range for the next physical write request. Caller must hold lock. + idx_t SelectPhysicalWriteEnd(idx_t start, idx_t &selected_bytes) const; + //! Return how many physical bytes can be submitted to the low-level queue before refilling should pause. + idx_t SubmittedByteWindow() const; + + //! Move one byte-budgeted prefix of pending writes into a physical async request. + bool TakePendingWriteRequest(AsyncWriteRequest &request, SchedulePolicy policy); + //! Convert one or more contiguous pending writes into a lazily materialized payload. + unique_ptr CreatePayload(deque writes, idx_t size); + //! Release byte accounting for one submitted physical request. + void CompleteSubmittedWrite(idx_t offset, idx_t size, optional_ptr error); + + //! Validate a new registration against the contiguous offset contract. + idx_t ValidateRegistrationOffset(idx_t offset, idx_t write_size) const; + //! Throw if a mutating API is used after Close(). + void VerifyOpen() const; + //! Validate a pending write before coalescing it with its predecessor. + void VerifyContiguousWrite(const PendingWrite &write, idx_t expected_offset) const; + //! Return offset + write_size, throwing if it overflows idx_t. + idx_t NextWriteOffset(idx_t offset, idx_t write_size) const; + //! Throw if the queue still owns registered or scheduled write work. + void VerifyDrained() const; + //! Discard queued writes after an async write failure once all submitted writes have stopped. + void CancelPendingWritesAfterFailure() noexcept; + + //! Write bytes to the managed target at the assigned logical offset. + void Write(data_ptr_t buffer, idx_t size, idx_t offset) override; + +private: + ClientContext &client_context; + ManagedAsyncWriteStreamTarget ⌖ + + //! Positional managed queue that owns TMM reservation, backpressure, and task scheduling. + unique_ptr write_queue; + //! Whether async requests may drain independent ranges concurrently using positional writes. 
+ DrainMode drain_mode = DrainMode::SEQUENTIAL; + //! Maximum number of submitted/running drain requests for this queue. + idx_t max_active_drain_tasks = 1; + //! Size below which adjacent writes are coalesced before reaching the target. + idx_t coalesce_threshold = 0; + //! Minimum queued bytes before threshold scheduling starts the first async task. + idx_t first_task_schedule_threshold = 0; + //! Maximum bytes one stream request should hand to the positional managed queue. + idx_t drain_task_byte_budget = 0; + //! Stop each local coalesced write at coalesce_threshold. + bool limit_coalesced_write_size = false; + + //! Protects state shared between the registering thread and async completion callbacks. + mutex lock; + //! Pending payloads in registration order with pre-assigned logical offsets. + deque pending_writes; + //! Bytes queued in pending_writes that have not been submitted to AsyncWriteQueue yet. + idx_t pending_bytes = 0; + //! Bytes submitted to AsyncWriteQueue that have not completed yet. + idx_t submitted_bytes = 0; + //! Submitted physical requests that have not completed yet. + idx_t submitted_requests = 0; + //! Nested batch depth. While non-zero, async draining and backpressure are delayed. + idx_t batch_depth = 0; + //! Next logical offset expected by RegisterWrite. Enforces v1 contiguous-registration semantics. + idx_t next_registration_offset = 0; + //! Set after Close() has drained the queue. Further write registration is rejected. 
+ bool closed = false; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/vector/shredded_vector.hpp b/src/duckdb/src/include/duckdb/common/vector/shredded_vector.hpp index 03ef7b140..43e0272ec 100644 --- a/src/duckdb/src/include/duckdb/common/vector/shredded_vector.hpp +++ b/src/duckdb/src/include/duckdb/common/vector/shredded_vector.hpp @@ -30,6 +30,7 @@ class ShreddedVectorBuffer : public VectorBuffer { idx_t GetAllocationSize() const override; string ToString(const LogicalType &type, idx_t count) const override; Value GetValue(const LogicalType &type, idx_t index) const override; + void SetVectorType(VectorType new_vector_type) override; protected: buffer_ptr FlattenSliceInternal(const LogicalType &type, const SelectionVector &sel, diff --git a/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp b/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp index 1f2ba57fb..17aa86418 100644 --- a/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp +++ b/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp @@ -558,6 +558,7 @@ class JoinHashTable { return should_build_prefix_range_filter && prefix_range_filter; } + void BuildPrefixRangeFilter(); unique_ptr InitializePrefixRangeBuildState(); void InsertPrefixRangeChunk(TupleDataChunkState &chunk_state, idx_t count, PrefixRangeFilter::BuildState &state); void MergePrefixRangeBuildState(PrefixRangeFilter::BuildState &state); diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_streaming_sample.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_streaming_sample.hpp index 1b5165440..8d50f2630 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_streaming_sample.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_streaming_sample.hpp @@ -24,6 +24,8 @@ class PhysicalStreamingSample : public PhysicalOperator { unique_ptr sample_options; double percentage; + double system_sample_phase; 
+ idx_t rows; public: // Operator interface @@ -36,7 +38,8 @@ class PhysicalStreamingSample : public PhysicalOperator { InsertionOrderPreservingMap ParamsToString() const override; private: - void SystemSample(DataChunk &input, DataChunk &result, OperatorState &state) const; + void SystemSamplePercent(DataChunk &input, DataChunk &result, OperatorState &state) const; + void SystemSampleRows(DataChunk &input, DataChunk &result, OperatorState &state) const; void BernoulliSample(DataChunk &input, DataChunk &result, OperatorState &state) const; }; diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/join_filter_pushdown.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/join_filter_pushdown.hpp index dcb6b9dd3..da6cba504 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/join_filter_pushdown.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/join_filter_pushdown.hpp @@ -21,14 +21,13 @@ class DynamicTableFilterSet; class LogicalGet; class JoinHashTable; class PhysicalComparisonJoin; -class PerfectHashJoinExecutor; struct GlobalUngroupedAggregateState; struct LocalUngroupedAggregateState; enum class JoinFilterPushdownMode : uint8_t { - //! The pushed expression can be reconstructed on top of the raw scan value for BF/PRF/PHJ runtime filters + //! The pushed expression can be reconstructed on top of the raw scan value for BF/PRF runtime filters RECONSTRUCT_EXPRESSION, - //! Only storage-domain filters are safe; BF/PRF/PHJ reconstruction on raw scan values is not + //! Only storage-domain filters are safe; BF/PRF reconstruction on raw scan values is not STORAGE_ONLY }; @@ -40,7 +39,7 @@ struct JoinFilterPushdownColumn { //! Whether runtime filters can reconstruct the pushed expression, or whether only storage-domain filters are safe JoinFilterPushdownMode mode = JoinFilterPushdownMode::RECONSTRUCT_EXPRESSION; //! The original type of the pushed probe expression before rewriting to the LogicalGet storage column. 
Only used - //! when the mode allows reconstruction of the probe expression for BF/PRF/PHJ runtime filters. + //! when the mode allows reconstruction of the probe expression for BF/PRF runtime filters. LogicalType runtime_filter_type; }; @@ -96,33 +95,29 @@ struct JoinFilterPushdownInfo { void Sink(DataChunk &chunk, JoinFilterLocalState &lstate) const; void Combine(JoinFilterGlobalState &gstate, JoinFilterLocalState &lstate) const; unique_ptr Finalize(ClientContext &context, JoinFilterGlobalState &gstate, - const PhysicalComparisonJoin &op, optional_ptr ht = nullptr, - optional_ptr perfect_hash_join_executor = nullptr) const; + const PhysicalComparisonJoin &op, optional_ptr ht = nullptr) const; unique_ptr FinalizeMinMax(JoinFilterGlobalState &gstate) const; unique_ptr FinalizeFilters(ClientContext &context, const PhysicalComparisonJoin &op, unique_ptr final_min_max, optional_ptr ht = nullptr, - optional_ptr perfect_join_executor = nullptr) const; + bool allow_bloom_filters = true, + bool allow_prefix_range_filters = true) const; private: - void PushInFilter(const JoinFilterPushdownFilter &info, JoinHashTable &ht, const PhysicalOperator &op, + bool PushInFilter(const JoinFilterPushdownFilter &info, JoinHashTable &ht, const PhysicalOperator &op, idx_t filter_idx, ProjectionIndex filter_col_idx) const; void PushBloomFilter(ClientContext &context, const PhysicalOperator &op, JoinHashTable &ht, const JoinFilterPushdownFilter &info, idx_t filter_idx, ProjectionIndex filter_col_idx) const; - void PushPerfectHashJoinFilter(ClientContext &context, const PhysicalOperator &op, - PerfectHashJoinExecutor &perfect_join_executor, const JoinFilterPushdownFilter &info, - idx_t filter_idx, ProjectionIndex filter_col_idx) const; - void RegisterPrefixRangeFilter(const JoinFilterPushdownFilter &info, ClientContext &context, JoinHashTable &ht, - const PhysicalOperator &op, idx_t filter_idx, ProjectionIndex filter_col_idx, - const Value &min_val, const Value &max_val) const; + bool 
TryRegisterPrefixRangeFilter(const JoinFilterPushdownFilter &info, ClientContext &context, JoinHashTable &ht, + const PhysicalOperator &op, idx_t filter_idx, ProjectionIndex filter_col_idx, + const Value &min_val, const Value &max_val, idx_t max_bits) const; bool CanUseInFilter(const ClientContext &context, optional_ptr ht, const ExpressionType &cmp) const; bool CanUseBloomFilter(const ClientContext &context, const PhysicalComparisonJoin &op, const ExpressionType &cmp, optional_ptr ht = nullptr) const; - bool CanUsePrefixRangeFilter(ClientContext &context, optional_ptr ht, - const PhysicalComparisonJoin &op, const ExpressionType &cmp, const Value &min, - const Value &max) const; + bool CanUsePrefixRangeFilter(const ClientContext &context, const PhysicalComparisonJoin &op, + optional_ptr ht, const ExpressionType &cmp) const; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/aggregate/list_aggregate.hpp b/src/duckdb/src/include/duckdb/function/aggregate/list_aggregate.hpp index 8dd438fb4..f8f6fdfc6 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate/list_aggregate.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate/list_aggregate.hpp @@ -16,9 +16,9 @@ namespace duckdb { //! 
The state of the "list" aggregate - shared by aggregates that buffer their input in a linked list struct ListAggState { - LinkedList linked_list; - using STATE_TYPE = StateListType; + + LinkedList linked_list; }; struct ListFunction { diff --git a/src/duckdb/src/include/duckdb/function/aggregate_function.hpp b/src/duckdb/src/include/duckdb/function/aggregate_function.hpp index db1eea887..55100de2a 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate_function.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate_function.hpp @@ -765,7 +765,13 @@ inline LogicalType AggregateFunction::BuildStateLogical(const BoundAggregateFunc // the runtime types of the bound function - used to resolve StateReturnType/StateInputType sources StateLayoutTypeInfo info {bound_function.GetReturnType(), bound_function.GetArguments()}; if constexpr (IsStateListType::value) { - return ResolveStateSourceType(info); + using SRC = typename ST::SOURCE_TYPE; + if constexpr (IsStateTypeSource::value) { + return ResolveStateSourceType(info); + } else { + // the element is described by a nested field descriptor (e.g. a sort key) + return LogicalType::LIST(FieldToLogicalType(info)); + } } else if constexpr (IsOptionalStateType::value) { using V = typename ST::value_type; if constexpr (IsStructStateType::value) { diff --git a/src/duckdb/src/include/duckdb/function/aggregate_state_layout.hpp b/src/duckdb/src/include/duckdb/function/aggregate_state_layout.hpp index 840a2db34..e7c5a137d 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate_state_layout.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate_state_layout.hpp @@ -52,6 +52,12 @@ struct IsStateInputType : std::false_type {}; template struct IsStateInputType> : std::true_type {}; +//! Detection trait: true when T is a type source marker (StateReturnType or StateInputType) - as opposed to a +//! field descriptor (StateSortKey, StateTypedValue, ...) that describes a physical field rather than just a type. 
+template +struct IsStateTypeSource + : std::integral_constant::value || IsStateInputType::value> {}; + //! The runtime types of a bound aggregate function - used to resolve the logical types of state fields that are //! only known after binding (see StateReturnType / StateInputType). struct StateLayoutTypeInfo { @@ -76,7 +82,7 @@ LogicalType ResolveStateSourceType(const StateLayoutTypeInfo &info) { //! Signals that the field stores a binary sort key (string_t) that must be decoded/encoded //! via CreateSortKeyHelpers when exporting/importing aggregate state. //! SOURCE describes where the decoded logical type comes from; ORDER is the ordering used when creating the sort key. -template +template struct StateSortKey { using SOURCE_TYPE = SOURCE; static constexpr OrderType order_type = ORDER; @@ -111,6 +117,8 @@ using StateString = StateTypedValue; //! Signals that the state is a LinkedList (see list_segment.hpp) holding the rows of a LIST value. //! Export reads the linked list into a LIST vector; import appends the LIST value's rows back into a linked list. //! SOURCE describes where the list's logical type comes from. +//! SOURCE may be a StateSortKey, signalling that the linked list physically stores binary +//! sort keys (string_t): export decodes each element via CreateSortKeyHelpers, import re-encodes them. template struct StateListType { using SOURCE_TYPE = SOURCE; @@ -185,7 +193,15 @@ template LogicalType FieldToLogicalType(const StateLayoutTypeInfo &info) { if constexpr (IsOptionalStateType::value) { return FieldToLogicalType(info); - } else if constexpr (IsStateSortKeyType::value || IsStateTypedValueType::value || IsStateListType::value) { + } else if constexpr (IsStateListType::value) { + using SRC = typename T::SOURCE_TYPE; + if constexpr (IsStateTypeSource::value) { + return ResolveStateSourceType(info); + } else { + // the element is described by a nested field descriptor (e.g. 
a sort key) + return LogicalType::LIST(FieldToLogicalType(info)); + } + } else if constexpr (IsStateSortKeyType::value || IsStateTypedValueType::value) { return ResolveStateSourceType(info); } else if constexpr (HasStructStateType::value) { return T::STATE_TYPE::GetLogicalType(T::STATE_NAMES, info); @@ -226,6 +242,8 @@ struct AggregateStateField { idx_t field_alignment = 0; AggregateFieldKind kind = AggregateFieldKind::PRIMITIVE; OrderType sort_key_order = OrderType::ASCENDING; // only meaningful when kind == SORT_KEY + //! For LIST: always holds a single element descriptor (children[0]). A SORT_KEY element means the linked list + //! physically stores binary sort keys; any other kind means the elements are stored directly. vector children; //! The segment functions used to read/write the linked list - only set when kind is LIST //! (populated by PopulateListFunctions, which requires the resolved logical type) @@ -259,10 +277,17 @@ struct AggregateStateField { //! alongside the fields. Called once when the layout is created. static void PopulateListFunctions(const LogicalType &type, AggregateStateField &field) { switch (field.kind) { - case AggregateFieldKind::LIST: + case AggregateFieldKind::LIST: { D_ASSERT(type.id() == LogicalTypeId::LIST); - GetSegmentDataFunctions(field.list_functions, ListType::GetChildType(type)); + D_ASSERT(field.children.size() == 1); + // sort-key elements are physically stored as BLOB, all other elements as their logical child type + const auto child_type = ListType::GetChildType(type); + const auto stored_type = + field.children[0].kind == AggregateFieldKind::SORT_KEY ? 
LogicalType::BLOB : child_type; + GetSegmentDataFunctions(field.list_functions, stored_type); + PopulateListFunctions(child_type, field.children[0]); break; + } case AggregateFieldKind::OPTIONAL_VALUE: D_ASSERT(field.children.size() == 1); PopulateListFunctions(type, field.children[0]); @@ -421,6 +446,14 @@ AggregateStateField BuildStateField() { field.kind = AggregateFieldKind::LIST; field.field_size = sizeof(LinkedList); field.field_alignment = alignof(LinkedList); + using SRC = typename T::SOURCE_TYPE; + if constexpr (IsStateTypeSource::value) { + // the elements are stored directly - the (PRIMITIVE) element field resolves its type from the logical child + field.children.emplace_back(); + } else { + // the source is a field descriptor (e.g. a sort key) describing a per-element transform + field.children.push_back(BuildStateField()); + } } else if constexpr (IsStructStateType::value) { // T is StructStateType — the phantom descriptor type itself field.kind = AggregateFieldKind::STRUCT; diff --git a/src/duckdb/src/include/duckdb/function/scalar/tablefilter_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/tablefilter_functions.hpp index 65f83baa3..6d1a27029 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/tablefilter_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/tablefilter_functions.hpp @@ -45,16 +45,6 @@ struct TableFilterOptionalFun { static ScalarFunction GetFunction(); }; -struct TableFilterPerfectHashJoinFun { - static constexpr const char *Name = "__internal_tablefilter_perfect_hash_join"; - static constexpr const char *Parameters = ""; - static constexpr const char *Description = ""; - static constexpr const char *Example = ""; - static constexpr const char *Categories = ""; - - static ScalarFunction GetFunction(); -}; - struct TableFilterPrefixRangeFun { static constexpr const char *Name = "__internal_tablefilter_prefix_range"; static constexpr const char *Parameters = ""; diff --git 
a/src/duckdb/src/include/duckdb/main/settings.hpp b/src/duckdb/src/include/duckdb/main/settings.hpp index 167a62fe9..bb269660f 100644 --- a/src/duckdb/src/include/duckdb/main/settings.hpp +++ b/src/duckdb/src/include/duckdb/main/settings.hpp @@ -781,6 +781,18 @@ struct DeprecatedUsingKeySyntaxSetting { static void OnSet(SettingCallbackInfo &info, Value &input); }; +struct DialectCompatibilityModeSetting { + using RETURN_TYPE = DialectCompatibilityMode; + static constexpr const char *Name = "dialect_compatibility_mode"; + static constexpr const char *Description = + "Enable SQL dialect compatibility for a certain engine (e.g. `SET dialect_compatibility_mode='spark'`)"; + static constexpr const char *InputType = "VARCHAR"; + static constexpr const char *DefaultValue = "NONE"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::LOCAL_DEFAULT; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); + static void OnSet(SettingCallbackInfo &info, Value &input); +}; + struct DisableDatabaseInvalidationSetting { using RETURN_TYPE = bool; static constexpr const char *Name = "disable_database_invalidation"; @@ -1158,6 +1170,17 @@ struct ForceMbedtlsUnsafeSetting { static Value GetSetting(const ClientContext &context); }; +struct ForceUpdateToDelAndInsertSetting { + using RETURN_TYPE = bool; + static constexpr const char *Name = "force_update_to_del_and_insert"; + static constexpr const char *Description = + "DEBUG SETTING: forces all updates to use the delete + insert code path instead of in-place updates"; + static constexpr const char *InputType = "BOOLEAN"; + static constexpr const char *DefaultValue = "false"; + static constexpr SettingScopeTarget Scope = SettingScopeTarget::LOCAL_DEFAULT; + static constexpr idx_t SettingIndex = NEXT_SETTING_INDEX(); +}; + struct ForceVariantShredding { using RETURN_TYPE = string; static constexpr const char *Name = "force_variant_shredding"; diff --git 
a/src/duckdb/src/include/duckdb/optimizer/grouping_sets_optimizer.hpp b/src/duckdb/src/include/duckdb/optimizer/grouping_sets_optimizer.hpp new file mode 100644 index 000000000..fb96eb8d3 --- /dev/null +++ b/src/duckdb/src/include/duckdb/optimizer/grouping_sets_optimizer.hpp @@ -0,0 +1,35 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/optimizer/grouping_sets_optimizer.hpp +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/planner/column_binding_map.hpp" +#include "duckdb/planner/logical_operator_visitor.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" + +namespace duckdb { +class Optimizer; + +//! The GroupingSetsOptimizer rewrites aggregates over multiple grouping sets (ROLLUP/CUBE/GROUPING SETS) into a +//! cascade of aggregations connected through materialized CTEs. The finest grouping set is computed over the base +//! data with the aggregate states exported, after which coarser grouping sets are computed by combining the states +//! of a finer grouping set - instead of re-scanning and re-aggregating the base data for every grouping set. 
+class GroupingSetsOptimizer : public LogicalOperatorVisitor { +public: + explicit GroupingSetsOptimizer(Optimizer &optimizer); + + void VisitOperator(unique_ptr &op) override; + +private: + bool TryRewriteGroupingSets(unique_ptr &op); + unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override; + +private: + Optimizer &optimizer; + column_binding_map_t replacement_map; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/optimizer/sampling_pushdown.hpp b/src/duckdb/src/include/duckdb/optimizer/sampling_pushdown.hpp index 78c67a199..34952ee11 100644 --- a/src/duckdb/src/include/duckdb/optimizer/sampling_pushdown.hpp +++ b/src/duckdb/src/include/duckdb/optimizer/sampling_pushdown.hpp @@ -8,18 +8,22 @@ #pragma once -#include "duckdb/common/constants.hpp" #include "duckdb/planner/logical_operator.hpp" #include "duckdb/common/unique_ptr.hpp" namespace duckdb { -class LocigalOperator; +class ClientContext; class Optimizer; class SamplingPushdown { public: + explicit SamplingPushdown(ClientContext &context) : context(context) { + } //! Optimize SYSTEM SAMPLING + SCAN to SAMPLE SCAN unique_ptr Optimize(unique_ptr op); + +private: + ClientContext &context; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parallel/task_scheduler.hpp b/src/duckdb/src/include/duckdb/parallel/task_scheduler.hpp index 834ed8d2a..2de1f7cbb 100644 --- a/src/duckdb/src/include/duckdb/parallel/task_scheduler.hpp +++ b/src/duckdb/src/include/duckdb/parallel/task_scheduler.hpp @@ -57,6 +57,8 @@ class TaskScheduler { unique_ptr CreateProducer(); //! Returns the number of threads DUCKDB_API int32_t NumberOfThreads(); + //! 
Returns the number of async threads + DUCKDB_API int32_t NumberOfAsyncThreads(); idx_t GetNumberOfTasks() const; idx_t GetProducerCount() const; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp index b9cb0841c..c7cda9302 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/sample_options.hpp @@ -33,6 +33,7 @@ class SampleOptions { SampleMethod method; optional_idx seed = optional_idx::Invalid(); bool repeatable; + double sample_rate = -1.0; unique_ptr Copy(); void SetSeed(idx_t new_seed); diff --git a/src/duckdb/src/include/duckdb/planner/filter/expression_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/expression_filter.hpp index 29d96feea..58b67c366 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/expression_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/expression_filter.hpp @@ -61,6 +61,8 @@ class ExpressionFilter : public TableFilter { static bool IsOptionalFilter(const TableFilter &filter); //! Check if the root of a table filter tree is an optional filter wrapper static bool IsRootOptionalFilter(const TableFilter &filter); + //! Check if the root of a table filter tree is a non-selectivity optional filter wrapper + static bool IsRootNonSelectivityOptionalFilter(const TableFilter &filter); //! If this is an optional/selectivity-optional wrapper around a root dynamic filter, //! return the shared dynamic filter state. 
static shared_ptr GetRootOptionalDynamicFilterData(const TableFilter &filter); diff --git a/src/duckdb/src/include/duckdb/planner/filter/perfect_hash_join_filter.hpp b/src/duckdb/src/include/duckdb/planner/filter/perfect_hash_join_filter.hpp deleted file mode 100644 index bc9b60f2a..000000000 --- a/src/duckdb/src/include/duckdb/planner/filter/perfect_hash_join_filter.hpp +++ /dev/null @@ -1,40 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/filter/perfect_hash_join_filter.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/planner/table_filter.hpp" - -namespace duckdb { - -class PerfectHashJoinExecutor; - -//! DEPRECATED - only preserved for backwards-compatible expression conversion -class LegacyPerfectHashJoinFilter final : public TableFilter { -public: - static constexpr auto TYPE = TableFilterType::LEGACY_PERFECT_HASH_JOIN_FILTER; - -public: - LegacyPerfectHashJoinFilter(optional_ptr perfect_join_executor, - const string &key_column_name, const LogicalType &key_type_p); - -private: - unique_ptr ToExpression(const Expression &column) const override; - - void Serialize(Serializer &serializer) const override; - static unique_ptr Deserialize(Deserializer &deserializer); - -private: - optional_ptr perfect_join_executor; - const string key_column_name; - const LogicalType key_type; -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/filter/table_filter_functions.hpp b/src/duckdb/src/include/duckdb/planner/filter/table_filter_functions.hpp index 299ebac71..b5e55d42b 100644 --- a/src/duckdb/src/include/duckdb/planner/filter/table_filter_functions.hpp +++ b/src/duckdb/src/include/duckdb/planner/filter/table_filter_functions.hpp @@ -21,7 +21,6 @@ namespace duckdb { class BaseStatistics; 
class Expression; -class PerfectHashJoinExecutor; class PrefixRangeFilter; struct DynamicFilterData; @@ -60,7 +59,7 @@ struct SelectivityOptionalFilterState final : public TableFilterState { } }; -enum class SelectivityOptionalFilterType : uint8_t { MIN_MAX, BF, PHJ, PRF }; +enum class SelectivityOptionalFilterType : uint8_t { MIN_MAX = 0, BF = 1, PRF = 3 }; void GetThresholdAndVectorsToCheck(SelectivityOptionalFilterType type, float &selectivity_threshold, idx_t &n_vectors_to_check); @@ -100,6 +99,8 @@ class BloomFilter { return initialized; } + static idx_t GetNumberOfSectors(idx_t number_of_rows); + private: idx_t num_sectors; uint64_t bitmask; // num_sectors - 1 -> used to get the sector offset @@ -126,20 +127,6 @@ struct BloomFilterFunctionData : public FunctionData { bool Equals(const FunctionData &other) const override; }; -//! FunctionData for perfect hash join internal function -struct PerfectHashJoinFunctionData : public FunctionData { - PerfectHashJoinFunctionData(optional_ptr executor_p, const string &key_column_name_p, - float selectivity_threshold_p, idx_t n_vectors_to_check_p); - - optional_ptr executor; - string key_column_name; - float selectivity_threshold; - idx_t n_vectors_to_check; - - unique_ptr Copy() const override; - bool Equals(const FunctionData &other) const override; -}; - //! Runtime prefix-range filter state used by join pushdown and internal tablefilter functions. 
class PrefixRangeFilter { public: @@ -159,7 +146,7 @@ class PrefixRangeFilter { }; virtual ~PrefixRangeFilter() = default; - virtual void Initialize(ClientContext &context, idx_t number_of_rows, Value min, Value max) = 0; + virtual void Initialize(ClientContext &context, idx_t number_of_rows, Value min, Value max, idx_t max_bits) = 0; virtual unique_ptr InitializeBuildState(ClientContext &context) const = 0; virtual void InsertKeys(Vector &keys, BuildState &state) const = 0; virtual void MergeBuildState(BuildState &state) = 0; @@ -249,15 +236,6 @@ struct BloomFilterScalarFun : public TableFilterBloomFilterFun { static string ToString(const string &column_name, const string &key_column_name); }; -//! Factory for perfect hash join internal function -struct PerfectHashJoinScalarFun : public TableFilterPerfectHashJoinFun { - using TableFilterPerfectHashJoinFun::GetFunction; - static constexpr const char *NAME = TableFilterPerfectHashJoinFun::Name; - static ScalarFunction GetFunction(const LogicalType &input_type); - static FilterPropagateResult FilterPrune(const FunctionStatisticsPruneInput &input); - static string ToString(const string &column_name, const string &key_column_name); -}; - //! 
Factory for prefix range internal function struct PrefixRangeScalarFun : public TableFilterPrefixRangeFun { using TableFilterPrefixRangeFun::GetFunction; diff --git a/src/duckdb/src/include/duckdb/planner/subquery/delim_join_cte_rewriter.hpp b/src/duckdb/src/include/duckdb/planner/subquery/delim_join_cte_rewriter.hpp index a97b7fea0..1d5dc411f 100644 --- a/src/duckdb/src/include/duckdb/planner/subquery/delim_join_cte_rewriter.hpp +++ b/src/duckdb/src/include/duckdb/planner/subquery/delim_join_cte_rewriter.hpp @@ -23,9 +23,9 @@ class DelimJoinCTERewriter { void Rewrite(unique_ptr &plan); void RewriteDelimJoinsToCTEs(unique_ptr &plan, LogicalOperator &rewrite_root, - bool null_rejecting_filter_above = false); + bool null_rejecting_filter_above = false, bool preserve_evidence_side = false); void MaterializeDelimJoinAsCTE(unique_ptr &plan, LogicalOperator &rewrite_root, - bool null_rejecting_filter_above); + bool null_rejecting_filter_above, bool preserve_evidence_side); private: Binder &binder; diff --git a/src/duckdb/src/include/duckdb/planner/table_filter.hpp b/src/duckdb/src/include/duckdb/planner/table_filter.hpp index dffb5e577..e327111d8 100644 --- a/src/duckdb/src/include/duckdb/planner/table_filter.hpp +++ b/src/duckdb/src/include/duckdb/planner/table_filter.hpp @@ -20,19 +20,18 @@ class PhysicalOperator; class PhysicalTableScan; enum class TableFilterType : uint8_t { - LEGACY_CONSTANT_COMPARISON = 0, // constant comparison (e.g. =C, >C, >=C, C, >=C, { idx_t NextEvictionSequenceNumber() { return ++eviction_seq_num; } + //! Returns true, if the block has a live (not yet dead-counted) entry in the eviction queue. + bool HasLiveQueueEntry(BlockLock &l) const { + VerifyMutex(l); + return has_queue_entry; + } + //! Lock-free overload of HasLiveQueueEntry. Only safe for callers with exclusive ownership of the + //! block memory (i.e., the destructor). + bool HasLiveQueueEntry() const { + return has_queue_entry; + } + //! 
Marks whether the block has a live entry in the eviction queue. Requires the block lock. + void SetHasLiveQueueEntry(BlockLock &l, bool has_queue_entry_p) { + VerifyMutex(l); + has_queue_entry = has_queue_entry_p; + } //! Get the LRU timestamp. int64_t GetLRUTimestamp() const { return lru_timestamp_msec; @@ -211,8 +226,13 @@ class BlockMemory : public enable_shared_from_this { const FileBufferType buffer_type; //! A pointer to the loaded data, if any. unique_ptr buffer; - //! The internal eviction sequence number. + //! The internal eviction sequence number. Monotonic: it is never reset, so an eviction queue + //! entry is stale if and only if its sequence number differs from this one. atomic eviction_seq_num; + //! Whether the block has a live entry in the eviction queue, i.e., an entry whose sequence + //! number matches eviction_seq_num and which has not been counted as a dead node. + //! Guarded by the block lock (read without it only by the destructor, which has exclusive ownership). + bool has_queue_entry; //! The LRU timestamp for age-based eviction. atomic lru_timestamp_msec; //! When to destroy the data buffer. diff --git a/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp b/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp index 93ad8dfcc..ce70f3116 100644 --- a/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp +++ b/src/duckdb/src/include/duckdb/storage/buffer/buffer_pool.hpp @@ -32,8 +32,6 @@ struct BufferEvictionNode { weak_ptr memory_p; idx_t handle_sequence_number; - bool CanUnload(BlockMemory &memory); - shared_ptr TryGetBlockMemory(); bool IsDeadNode(optional_idx debug_sleep_micros = optional_idx()); }; @@ -106,8 +104,8 @@ class BufferPool { //! Garbage collect dead nodes in the eviction queue. void PurgeQueue(const BlockHandle &handle); //! Add a buffer handle to the eviction queue. Returns true, if the queue is - //! ready to be purged, and false otherwise. - bool AddToEvictionQueue(shared_ptr &handle); + //! 
ready to be purged, and false otherwise. Requires the handle's block lock. + bool AddToEvictionQueue(BlockLock &lock, shared_ptr &handle); //! Gets the eviction queue for the specified type EvictionQueue &GetEvictionQueueForBlockMemory(const BlockMemory &memory); //! Increments the dead nodes for the queue with specified type diff --git a/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp b/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp index 9e6a34790..53141d792 100644 --- a/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/scan_state.hpp @@ -277,8 +277,14 @@ class CollectionScanState { struct ScanSamplingInfo { //! Whether or not to do a system sample during scanning bool do_system_sample = false; - //! The sampling rate to use + //! The sampling rate to use (for percentage-based sampling) double sample_rate; + //! The seeded phase used for row-count based systematic sampling + double sample_phase = 0; + //! Whether the sampling is row-count based or percentage-based + bool is_percentage = false; + //! Target number of rows to sample (for row-count based sampling) + idx_t target_sample_rows = 0; }; struct TableScanOptions { @@ -316,7 +322,7 @@ class TableScanState { public: void Initialize(vector column_ids, optional_ptr context = nullptr, optional_ptr table_filters = nullptr, - optional_ptr table_sampling = nullptr); + optional_ptr table_sampling = nullptr, idx_t estimated_table_row_count = 0); const vector &GetColumnIds(); diff --git a/src/duckdb/src/include/duckdb/storage/temporary_memory_manager.hpp b/src/duckdb/src/include/duckdb/storage/temporary_memory_manager.hpp index 111ef1a56..a746441ff 100644 --- a/src/duckdb/src/include/duckdb/storage/temporary_memory_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/temporary_memory_manager.hpp @@ -102,6 +102,10 @@ class TemporaryMemoryManager { private: //! 
Get the default minimum reservation idx_t DefaultMinimumReservation() const DUCKDB_REQUIRES(lock); + //! Cap a reservation by the current memory limit managed by TMM + idx_t CapReservation(idx_t reservation) const DUCKDB_REQUIRES(lock); + //! Get the effective minimum reservation for a state after applying TMM limits + idx_t MinimumReservation(const TemporaryMemoryState &temporary_memory_state) const DUCKDB_REQUIRES(lock); //! Unregister a TemporaryMemoryState (called by the destructor of TemporaryMemoryState) void Unregister(TemporaryMemoryState &temporary_memory_state); //! Update memory_limit, has_temporary_directory, and num_threads (must hold the lock) @@ -114,6 +118,8 @@ class TemporaryMemoryManager { void SetReservation(TemporaryMemoryState &temporary_memory_state, idx_t new_reservation) DUCKDB_REQUIRES(lock); //! Computes optimal reservation of a TemporaryMemoryState based on a cost function idx_t ComputeReservation(const TemporaryMemoryState &temporary_memory_state) const DUCKDB_REQUIRES(lock); + //! Compute initial reservation for use in ComputeReservation + idx_t ComputeInitialReservation(const TemporaryMemoryState &temporary_memory_state) const DUCKDB_REQUIRES(lock); //! 
Verify internal counts (must hold the lock) void Verify() const DUCKDB_REQUIRES(lock); diff --git a/src/duckdb/src/main/config.cpp b/src/duckdb/src/main/config.cpp index 2f5ace035..86d5b59da 100644 --- a/src/duckdb/src/main/config.cpp +++ b/src/duckdb/src/main/config.cpp @@ -127,6 +127,7 @@ static const ConfigurationOption internal_options[] = { DUCKDB_SETTING_CALLBACK(DefaultTransactionInvalidationPolicySetting), DUCKDB_SETTING(DelimJoinAsCteSetting), DUCKDB_SETTING_CALLBACK(DeprecatedUsingKeySyntaxSetting), + DUCKDB_SETTING_CALLBACK(DialectCompatibilityModeSetting), DUCKDB_SETTING_CALLBACK(DisableDatabaseInvalidationSetting), DUCKDB_SETTING(DisableTimestamptzCastsSetting), DUCKDB_GLOBAL(DisabledCompressionMethodsSetting), @@ -162,6 +163,7 @@ static const ConfigurationOption internal_options[] = { DUCKDB_SETTING(ForceColumnMetadataReuseSetting), DUCKDB_SETTING_CALLBACK(ForceCompressionSetting), DUCKDB_GLOBAL(ForceMbedtlsUnsafeSetting), + DUCKDB_SETTING(ForceUpdateToDelAndInsertSetting), DUCKDB_GLOBAL(ForceVariantShredding), DUCKDB_SETTING(GeometryMinimumShreddingSize), DUCKDB_SETTING_CALLBACK(HomeDirectorySetting), @@ -238,12 +240,12 @@ static const ConfigurationOption internal_options[] = { static const ConfigurationAlias setting_aliases[] = {DUCKDB_SETTING_ALIAS("configure_metrics", 28), DUCKDB_SETTING_ALIAS("custom_profiling_settings", 28), - DUCKDB_SETTING_ALIAS("memory_limit", 120), + DUCKDB_SETTING_ALIAS("memory_limit", 122), DUCKDB_SETTING_ALIAS("null_order", 55), - DUCKDB_SETTING_ALIAS("profile_output", 143), - DUCKDB_SETTING_ALIAS("user", 159), + DUCKDB_SETTING_ALIAS("profile_output", 145), + DUCKDB_SETTING_ALIAS("user", 161), DUCKDB_SETTING_ALIAS("wal_autocheckpoint", 27), - DUCKDB_SETTING_ALIAS("worker_threads", 157), + DUCKDB_SETTING_ALIAS("worker_threads", 159), FINAL_ALIAS}; vector DBConfig::GetOptions() { diff --git a/src/duckdb/src/main/settings/autogenerated_settings.cpp b/src/duckdb/src/main/settings/autogenerated_settings.cpp index 
9438c93de..ddf55ee9d 100644 --- a/src/duckdb/src/main/settings/autogenerated_settings.cpp +++ b/src/duckdb/src/main/settings/autogenerated_settings.cpp @@ -126,6 +126,13 @@ void DeprecatedUsingKeySyntaxSetting::OnSet(SettingCallbackInfo &info, Value &pa EnumUtil::FromString(StringValue::Get(parameter)); } +//===----------------------------------------------------------------------===// +// Dialect Compatibility Mode +//===----------------------------------------------------------------------===// +void DialectCompatibilityModeSetting::OnSet(SettingCallbackInfo &info, Value ¶meter) { + EnumUtil::FromString(StringValue::Get(parameter)); +} + //===----------------------------------------------------------------------===// // Enable Progress Bar //===----------------------------------------------------------------------===// diff --git a/src/duckdb/src/optimizer/build_probe_side_optimizer.cpp b/src/duckdb/src/optimizer/build_probe_side_optimizer.cpp index d2cb6e05a..5151b635f 100644 --- a/src/duckdb/src/optimizer/build_probe_side_optimizer.cpp +++ b/src/duckdb/src/optimizer/build_probe_side_optimizer.cpp @@ -24,6 +24,7 @@ namespace duckdb { struct JoinFilterBuildSideHeuristics { static constexpr idx_t MIN_FILTER_TARGET_CARDINALITY = 1000000; static constexpr idx_t MAX_BUILD_TO_TARGET_RATIO = 64; + static constexpr idx_t SEMI_JOIN_FILTER_TARGET_RATIO = 3; }; static void GetRowidBindings(LogicalOperator &op, vector &bindings) { @@ -152,6 +153,19 @@ static double DynamicFilterBuildBonus(LogicalComparisonJoin &join, const idx_t p return static_cast(max_target_cardinality) / static_cast(MaxValue(build_cardinality, 1)); } +static idx_t MaxDynamicFilterTargetCardinality(LogicalComparisonJoin &join, const idx_t probe_idx) { + idx_t max_target_cardinality = 0; + for (auto &cond : join.conditions) { + if (!cond.IsComparison() || cond.GetComparisonType() != ExpressionType::COMPARE_EQUAL) { + continue; + } + auto &probe_expr = probe_idx == 0 ? 
cond.GetLHS() : cond.GetRHS(); + max_target_cardinality = + MaxValue(max_target_cardinality, MaxDynamicFilterTargetCardinality(*join.children[probe_idx], probe_expr)); + } + return max_target_cardinality; +} + BuildSize BuildProbeSideOptimizer::GetBuildSizes(const LogicalOperator &op, const idx_t lhs_cardinality, const idx_t rhs_cardinality) { BuildSize ret; @@ -249,6 +263,16 @@ bool BuildProbeSideOptimizer::TryFlipJoinChildren(LogicalOperator &op) const { JoinFilterPushdownUtil::JoinTypeIsSupported(InverseJoinType(join.join_type))) { left_side_build_cost /= DynamicFilterBuildBonus(join, 1, 0, lhs_cardinality); } + if (join.join_type == JoinType::SEMI && JoinFilterPushdownOptimizer::IsFiltering(join.children[0])) { + // SEMI joins often have a filtered domain on the LHS and a larger RHS with residual filters. If flipping + // lets that domain generate a runtime filter for a much larger RHS scan, prefer the domain build side. + auto right_filter_target = MaxDynamicFilterTargetCardinality(join, 1); + if (right_filter_target >= JoinFilterBuildSideHeuristics::MIN_FILTER_TARGET_CARDINALITY && + right_filter_target / JoinFilterBuildSideHeuristics::SEMI_JOIN_FILTER_TARGET_RATIO > lhs_cardinality && + right_filter_target / JoinFilterBuildSideHeuristics::SEMI_JOIN_FILTER_TARGET_RATIO > rhs_cardinality) { + swap = true; + } + } } idx_t left_child_joins = ChildHasJoins(*op.children[0]); diff --git a/src/duckdb/src/optimizer/grouping_sets_optimizer.cpp b/src/duckdb/src/optimizer/grouping_sets_optimizer.cpp new file mode 100644 index 000000000..ead270483 --- /dev/null +++ b/src/duckdb/src/optimizer/grouping_sets_optimizer.cpp @@ -0,0 +1,326 @@ +#include "duckdb/optimizer/grouping_sets_optimizer.hpp" + +#include "duckdb/common/algorithm.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/function/aggregate/distributive_functions.hpp" +#include "duckdb/function/function_binder.hpp" +#include "duckdb/function/scalar/generic_common.hpp" +#include 
"duckdb/optimizer/optimizer.hpp" +#include "duckdb/planner/binder.hpp" +#include "duckdb/planner/expression/bound_aggregate_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/operator/logical_aggregate.hpp" +#include "duckdb/planner/operator/logical_cteref.hpp" +#include "duckdb/planner/operator/logical_materialized_cte.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/operator/logical_set_operation.hpp" + +namespace duckdb { + +GroupingSetsOptimizer::GroupingSetsOptimizer(Optimizer &optimizer_p) : optimizer(optimizer_p) { +} + +namespace { + +//! A single grouping set in the cascade, in the order in which it is computed +struct GroupingSetLevel { + explicit GroupingSetLevel(idx_t grouping_set_idx) : grouping_set_idx(grouping_set_idx) { + } + + //! The index of the grouping set (within LogicalAggregate::grouping_sets) computed by this level + idx_t grouping_set_idx; + //! The level this level aggregates over (invalid for the finest level, which aggregates the base data) + optional_idx source_level; + //! Whether this level is the source of another level (and must be materialized as a CTE) + bool materialized = false; + //! The table index of the materialized CTE (only set if materialized) + TableIndex cte_index; + //! The output types of this level: the group columns (in ascending group order), then the aggregate states + vector output_types; + //! The output names of this level + vector output_names; + //! The position of each group within the level output + map group_positions; + //! 
The aggregate computing this level + unique_ptr aggregate; +}; + +} // namespace + +static bool CanRewriteAggregate(const BoundAggregateExpression &aggregate) { + if (aggregate.IsDistinct() || aggregate.GetOrderBys()) { + // DISTINCT / ORDER BY aggregates cannot be computed by combining states of a finer aggregation + return false; + } + if (aggregate.StateExportMode() != AggregateStateExportMode::NONE) { + // the aggregate already exports its state + return false; + } + // mirror the requirements of ExportAggregateFunction::Bind so that binding the export cannot fail + auto &function = aggregate.Function(); + if (!function.HasStateCombineCallback() || function.HasStateDestructorCallback() || + !function.HasStateSizeCallback() || !function.HasStateFinalizeCallback() || + !function.HasGetStateTypeCallback()) { + return false; + } + return true; +} + +//! Order the grouping sets so that each set can be computed by re-aggregating an already-computed superset +static bool FindCascade(const vector &grouping_sets, vector &levels) { + // order the grouping sets by size (descending) - supersets must be computed before their subsets + for (idx_t set_idx = 0; set_idx < grouping_sets.size(); set_idx++) { + levels.emplace_back(set_idx); + } + std::stable_sort(levels.begin(), levels.end(), [&](const GroupingSetLevel &a, const GroupingSetLevel &b) { + return grouping_sets[a.grouping_set_idx].size() > grouping_sets[b.grouping_set_idx].size(); + }); + // each level is computed from the smallest already-computed level that is a superset of it + // for ROLLUP this forms a chain, for CUBE a lattice rooted in the complete grouping set + for (idx_t level_idx = 1; level_idx < levels.size(); level_idx++) { + auto &grouping_set = grouping_sets[levels[level_idx].grouping_set_idx]; + for (idx_t source_idx = level_idx; source_idx > 0; source_idx--) { + auto &source_set = grouping_sets[levels[source_idx - 1].grouping_set_idx]; + if (std::includes(source_set.begin(), source_set.end(), 
grouping_set.begin(), grouping_set.end())) { + levels[level_idx].source_level = source_idx - 1; + levels[source_idx - 1].materialized = true; + break; + } + } + if (!levels[level_idx].source_level.IsValid()) { + // no computed superset to compute this grouping set from - we cannot cascade + return false; + } + } + return true; +} + +bool GroupingSetsOptimizer::TryRewriteGroupingSets(unique_ptr &op) { + if (op->type != LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY || op->children.size() != 1) { + return false; + } + auto &aggr = op->Cast(); + if (aggr.grouping_sets.size() < 2 || aggr.expressions.empty()) { + // the rewrite is only beneficial when multiple grouping sets are computed + return false; + } + for (auto &expr : aggr.expressions) { + if (expr->GetExpressionClass() != ExpressionClass::BOUND_AGGREGATE) { + return false; + } + if (!CanRewriteAggregate(expr->Cast())) { + return false; + } + } + vector levels; + if (!FindCascade(aggr.grouping_sets, levels)) { + return false; + } + + const idx_t group_count = aggr.groups.size(); + const idx_t aggregate_count = aggr.expressions.size(); + + // build the per-level aggregates + auto combine_function = CombineAggrFun::GetFunction(); + FunctionBinder function_binder(optimizer.context); + vector state_types; + for (idx_t level_idx = 0; level_idx < levels.size(); level_idx++) { + auto &level = levels[level_idx]; + auto &grouping_set = aggr.grouping_sets[level.grouping_set_idx]; + + vector> level_groups; + vector> level_aggregates; + unique_ptr level_child; + if (level_idx == 0) { + // the finest level aggregates the base data and exports the aggregate states + for (auto &group_idx : grouping_set) { + level_groups.push_back(aggr.groups[group_idx]->Copy()); + } + // bind the aggregates + for (auto &expr : aggr.expressions) { + auto aggregate_copy = unique_ptr_cast(expr->Copy()); + auto export_aggregate = ExportAggregateFunction::Bind(std::move(aggregate_copy)); + if 
(!export_aggregate->GetReturnType().IsAggregateState()) { + return false; + } + state_types.push_back(export_aggregate->GetReturnType()); + level_aggregates.push_back(std::move(export_aggregate)); + } + } else { + // coarser levels combine the aggregate states of their source level + auto &source = levels[level.source_level.GetIndex()]; + auto &source_set = aggr.grouping_sets[source.grouping_set_idx]; + const auto cte_ref_index = optimizer.binder.GenerateTableIndex(); + level_child = + make_uniq(cte_ref_index, source.cte_index, source.output_types, source.output_names); + + for (auto &group_idx : grouping_set) { + const auto group_pos = source.group_positions[group_idx]; + level_groups.push_back(make_uniq( + source.output_types[group_pos], ColumnBinding(cte_ref_index, ProjectionIndex(group_pos)))); + } + for (idx_t aggr_idx = 0; aggr_idx < aggregate_count; aggr_idx++) { + const auto state_pos = source_set.size() + aggr_idx; + vector> arguments; + arguments.push_back(make_uniq( + state_types[aggr_idx], ColumnBinding(cte_ref_index, ProjectionIndex(state_pos)))); + auto combine_aggregate = function_binder.BindAggregateFunction(combine_function, std::move(arguments)); + if (combine_aggregate->GetReturnType() != state_types[aggr_idx]) { + return false; + } + level_aggregates.push_back(std::move(combine_aggregate)); + } + } + + // fill in the output layout of this level: the group columns, followed by the aggregate states + for (auto &group_idx : grouping_set) { + level.group_positions[group_idx] = level.output_types.size(); + level.output_types.push_back(aggr.groups[group_idx]->GetReturnType()); + level.output_names.push_back(Identifier(StringUtil::Format("group_%llu", group_idx.GetIndex()))); + } + for (idx_t aggr_idx = 0; aggr_idx < aggregate_count; aggr_idx++) { + level.output_types.push_back(state_types[aggr_idx]); + level.output_names.push_back(Identifier(StringUtil::Format("state_%llu", aggr_idx))); + } + + level.aggregate = make_uniq( + 
optimizer.binder.GenerateTableIndex(), optimizer.binder.GenerateTableIndex(), std::move(level_aggregates)); + level.aggregate->groups = std::move(level_groups); + if (aggr.has_estimated_cardinality) { + level.aggregate->SetEstimatedCardinality(aggr.estimated_cardinality); + } + if (level_child) { + level.aggregate->children.push_back(std::move(level_child)); + } + if (level.materialized) { + level.cte_index = optimizer.binder.GenerateTableIndex(); + } + } + + // build the union branches - one branch per grouping set, in their original order + vector> branches(levels.size()); + for (auto &level : levels) { + auto &grouping_set = aggr.grouping_sets[level.grouping_set_idx]; + + // the branch reads from a reference to the CTE if the level is materialized, + // otherwise the level aggregate is inlined into the branch directly + unique_ptr branch_child; + TableIndex cte_ref_index; + if (level.materialized) { + cte_ref_index = optimizer.binder.GenerateTableIndex(); + branch_child = + make_uniq(cte_ref_index, level.cte_index, level.output_types, level.output_names); + } else { + branch_child = std::move(level.aggregate); + } + auto GetBranchBinding = [&](idx_t output_pos) { + if (level.materialized) { + return ColumnBinding(cte_ref_index, ProjectionIndex(output_pos)); + } + auto &branch_aggr = branch_child->Cast(); + if (output_pos < grouping_set.size()) { + return ColumnBinding(branch_aggr.group_index, ProjectionIndex(output_pos)); + } + return ColumnBinding(branch_aggr.aggregate_index, ProjectionIndex(output_pos - grouping_set.size())); + }; + + vector> proj_exprs; + for (idx_t group_idx = 0; group_idx < group_count; group_idx++) { + auto &group_type = aggr.groups[group_idx]->GetReturnType(); + auto entry = level.group_positions.find(ProjectionIndex(group_idx)); + if (entry != level.group_positions.end()) { + proj_exprs.push_back(make_uniq(group_type, GetBranchBinding(entry->second))); + } else { + // this group is not part of the grouping set: emit NULL + 
proj_exprs.push_back(make_uniq(Value(group_type))); + } + } + for (idx_t aggr_idx = 0; aggr_idx < aggregate_count; aggr_idx++) { + auto state_ref = make_uniq(state_types[aggr_idx], + GetBranchBinding(grouping_set.size() + aggr_idx)); + auto finalize_expr = optimizer.BindScalarFunction("finalize", std::move(state_ref)); + if (finalize_expr->GetReturnType() != aggr.expressions[aggr_idx]->GetReturnType()) { + return false; + } + proj_exprs.push_back(std::move(finalize_expr)); + } + // GROUPING() function calls are constant within a grouping set (see RadixPartitionedHashTable) + for (auto &grouping_function : aggr.grouping_functions) { + int64_t grouping_value = 0; + for (idx_t i = 0; i < grouping_function.size(); i++) { + if (grouping_set.find(grouping_function[i]) == grouping_set.end()) { + // we do not group on this column in this grouping set + grouping_value += 1LL << (grouping_function.size() - (i + 1)); + } + } + proj_exprs.push_back(make_uniq(Value::BIGINT(grouping_value))); + } + + auto branch = make_uniq(optimizer.binder.GenerateTableIndex(), std::move(proj_exprs)); + branch->children.push_back(std::move(branch_child)); + branches[level.grouping_set_idx] = std::move(branch); + } + + // from here on the rewrite can no longer fail - we can start modifying the original plan + // attach the base input to the finest level + levels[0].aggregate->children.push_back(std::move(aggr.children[0])); + + // union the branches together + const idx_t column_count = group_count + aggregate_count + aggr.grouping_functions.size(); + const auto union_index = optimizer.binder.GenerateTableIndex(); + unique_ptr result = make_uniq(union_index, column_count, std::move(branches), + LogicalOperatorType::LOGICAL_UNION, true); + if (aggr.has_estimated_cardinality) { + result->SetEstimatedCardinality(aggr.estimated_cardinality); + } + + // wrap the result in the materialized CTEs, finest level outermost so that coarser levels can reference it + for (idx_t level_idx = levels.size(); 
level_idx > 0; level_idx--) { + auto &level = levels[level_idx - 1]; + if (!level.materialized) { + continue; + } + auto cte_name = Identifier(StringUtil::Format("__grouping_sets_cte_%llu", level.cte_index.index)); + auto cte = make_uniq(std::move(cte_name), level.cte_index, level.output_types.size(), + std::move(level.aggregate), std::move(result), + CTEMaterialize::CTE_MATERIALIZE_DEFAULT); + if (aggr.has_estimated_cardinality) { + cte->SetEstimatedCardinality(aggr.estimated_cardinality); + } + result = std::move(cte); + } + + // replace the bindings of the original aggregate with the union output + for (idx_t group_idx = 0; group_idx < group_count; group_idx++) { + replacement_map[ColumnBinding(aggr.group_index, ProjectionIndex(group_idx))] = + ColumnBinding(union_index, ProjectionIndex(group_idx)); + } + for (idx_t aggr_idx = 0; aggr_idx < aggregate_count; aggr_idx++) { + replacement_map[ColumnBinding(aggr.aggregate_index, ProjectionIndex(aggr_idx))] = + ColumnBinding(union_index, ProjectionIndex(group_count + aggr_idx)); + } + for (idx_t grouping_idx = 0; grouping_idx < aggr.grouping_functions.size(); grouping_idx++) { + replacement_map[ColumnBinding(aggr.groupings_index, ProjectionIndex(grouping_idx))] = + ColumnBinding(union_index, ProjectionIndex(group_count + aggregate_count + grouping_idx)); + } + + result->ResolveOperatorTypes(); + op = std::move(result); + return true; +} + +void GroupingSetsOptimizer::VisitOperator(unique_ptr &op) { + LogicalOperatorVisitor::VisitOperator(op); + TryRewriteGroupingSets(op); +} + +unique_ptr GroupingSetsOptimizer::VisitReplace(BoundColumnRefExpression &expr, + unique_ptr *expr_ptr) { + auto entry = replacement_map.find(expr.Binding()); + if (entry != replacement_map.end()) { + expr.BindingMutable() = entry->second; + } + return nullptr; +} + +} // namespace duckdb diff --git a/src/duckdb/src/optimizer/late_materialization.cpp b/src/duckdb/src/optimizer/late_materialization.cpp index cc6c6cc81..782e67d86 100644 --- 
a/src/duckdb/src/optimizer/late_materialization.cpp +++ b/src/duckdb/src/optimizer/late_materialization.cpp @@ -228,6 +228,11 @@ bool LateMaterialization::TryLateMaterialization(unique_ptr &op // this function does not support late materialization return false; } + if (get.extra_info.sample_options && !get.extra_info.sample_options->is_percentage) { + // we should not apply late materialization when row-count sampling is pushed down + // the sample scan is already fast and creating a semi-join would duplicate the full table scan + return false; + } if (!get.function.get_row_id_columns) { throw InternalException("Function supports late materialization but not get_row_id_columns"); } diff --git a/src/duckdb/src/optimizer/optimizer.cpp b/src/duckdb/src/optimizer/optimizer.cpp index 844ce38ae..b7f7234cc 100644 --- a/src/duckdb/src/optimizer/optimizer.cpp +++ b/src/duckdb/src/optimizer/optimizer.cpp @@ -16,6 +16,7 @@ #include "duckdb/optimizer/expression_heuristics.hpp" #include "duckdb/optimizer/filter_pullup.hpp" #include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/optimizer/grouping_sets_optimizer.hpp" #include "duckdb/optimizer/in_clause_rewriter.hpp" #include "duckdb/optimizer/join_elimination.hpp" #include "duckdb/optimizer/join_filter_pushdown_optimizer.hpp" @@ -236,6 +237,12 @@ void Optimizer::RunBuiltInOptimizers() { plan = deliminator.Optimize(std::move(plan)); }); + // rewrite aggregates over multiple grouping sets (ROLLUP/CUBE/GROUPING SETS) into a cascade of aggregations + RunOptimizer(OptimizerType::GROUPING_SETS, [&]() { + GroupingSetsOptimizer grouping_sets_optimizer(*this); + grouping_sets_optimizer.VisitOperator(plan); + }); + // try to inline CTEs instead of materialization RunOptimizer(OptimizerType::CTE_INLINING, [&]() { CTEInlining cte_inlining(*this); @@ -345,7 +352,7 @@ void Optimizer::RunBuiltInOptimizers() { // perform sampling pushdown RunOptimizer(OptimizerType::SAMPLING_PUSHDOWN, [&]() { - SamplingPushdown sampling_pushdown; + 
SamplingPushdown sampling_pushdown(context); plan = sampling_pushdown.Optimize(std::move(plan)); }); diff --git a/src/duckdb/src/optimizer/partial_aggregate_pushdown.cpp b/src/duckdb/src/optimizer/partial_aggregate_pushdown.cpp index 1190dceb1..edbac27a2 100644 --- a/src/duckdb/src/optimizer/partial_aggregate_pushdown.cpp +++ b/src/duckdb/src/optimizer/partial_aggregate_pushdown.cpp @@ -62,6 +62,11 @@ static bool IsSupportedAggregate(const BoundAggregateExpression &expr) { if (expr.IsDistinct() || expr.GetFilter() || expr.GetOrderBys()) { return false; } + if (expr.StateExportMode() != AggregateStateExportMode::NONE) { + // the aggregate already exports its state - we cannot push it down again (and finalizing the + // re-exported state would not round-trip back to the original return type) + return false; + } if (expr.GetChildren().size() != 1) { return false; } diff --git a/src/duckdb/src/optimizer/pushdown/pushdown_mark_join.cpp b/src/duckdb/src/optimizer/pushdown/pushdown_mark_join.cpp index fc8c34cf3..47f84726a 100644 --- a/src/duckdb/src/optimizer/pushdown/pushdown_mark_join.cpp +++ b/src/duckdb/src/optimizer/pushdown/pushdown_mark_join.cpp @@ -1,11 +1,149 @@ #include "duckdb/optimizer/filter_pushdown.hpp" +#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_comparison_expression.hpp" +#include "duckdb/planner/expression/bound_conjunction_expression.hpp" #include "duckdb/planner/expression/bound_operator_expression.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" +#include "duckdb/planner/operator/logical_filter.hpp" +#include "duckdb/planner/operator/logical_get.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" +#include "duckdb/planner/filter/expression_filter.hpp" +#include "duckdb/storage/statistics/base_statistics.hpp" namespace duckdb { using Filter = FilterPushdown::Filter; +static bool 
FilterNullRejectsExpression(const Expression &filter, const Expression &expr) { + if (filter.GetExpressionType() == ExpressionType::CONJUNCTION_AND) { + auto &conjunction = filter.Cast(); + for (auto &child : conjunction.GetChildren()) { + if (FilterNullRejectsExpression(*child, expr)) { + return true; + } + } + return false; + } + if (filter.GetExpressionType() == ExpressionType::CONJUNCTION_OR) { + auto &conjunction = filter.Cast(); + if (conjunction.GetChildren().empty()) { + return false; + } + for (auto &child : conjunction.GetChildren()) { + if (!FilterNullRejectsExpression(*child, expr)) { + return false; + } + } + return true; + } + if (filter.GetExpressionType() == ExpressionType::OPERATOR_IS_NOT_NULL) { + auto &op = filter.Cast(); + return !op.GetChildren().empty() && Expression::Equals(*op.GetChildren()[0], expr); + } + if (!BoundComparisonExpression::IsComparison(filter)) { + return false; + } + if (filter.GetExpressionType() == ExpressionType::COMPARE_DISTINCT_FROM || + filter.GetExpressionType() == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + return false; + } + auto &comparison = filter.Cast(); + return Expression::Equals(BoundComparisonExpression::Left(comparison), expr) || + Expression::Equals(BoundComparisonExpression::Right(comparison), expr); +} + +static bool GetColumnRefBinding(const Expression &expr, ColumnBinding &binding) { + if (expr.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return false; + } + auto &colref = expr.Cast(); + if (colref.Depth() != 0) { + return false; + } + binding = colref.Binding(); + return true; +} + +static bool ExpressionIsNotNull(ClientContext &context, LogicalOperator &op, const Expression &expr) { + switch (op.type) { + case LogicalOperatorType::LOGICAL_PROJECTION: { + auto &projection = op.Cast(); + if (projection.children.size() != 1) { + return false; + } + ColumnBinding binding; + if (!GetColumnRefBinding(expr, binding)) { + return false; + } + auto projection_bindings = 
projection.GetColumnBindings(); + for (idx_t idx = 0; idx < projection_bindings.size(); idx++) { + if (projection_bindings[idx] == binding) { + return ExpressionIsNotNull(context, *projection.children[0], *projection.expressions[idx]); + } + } + return false; + } + case LogicalOperatorType::LOGICAL_FILTER: { + auto &filter = op.Cast(); + for (auto &filter_expr : filter.expressions) { + if (FilterNullRejectsExpression(*filter_expr, expr)) { + return true; + } + } + return filter.children.size() == 1 && ExpressionIsNotNull(context, *filter.children[0], expr); + } + case LogicalOperatorType::LOGICAL_GET: { + ColumnBinding binding; + if (!GetColumnRefBinding(expr, binding)) { + return false; + } + auto &get = op.Cast(); + if (binding.table_index != get.table_index) { + return false; + } + if (get.table_filters.HasFilter(binding.column_index)) { + auto column_expr = make_uniq(expr.GetReturnType(), binding); + auto filter_expr = + get.table_filters.GetFilterByColumnIndex(binding.column_index).ToExpression(*column_expr); + if (FilterNullRejectsExpression(*filter_expr, expr)) { + return true; + } + } + auto table = get.GetTable(); + if (!table) { + return false; + } + auto &column_index = get.GetColumnIndex(binding); + if (!column_index.HasPrimaryIndex() || column_index.HasChildren() || + column_index.GetPrimaryIndex() == DConstants::INVALID_INDEX) { + return false; + } + auto stats = table->GetStatistics(context, column_index.GetPrimaryIndex()); + return stats && !stats->CanHaveNull(); + } + default: + return false; + } +} + +static void SimplifyNullSafeSemiJoinConditions(ClientContext &context, LogicalComparisonJoin &join) { + D_ASSERT(join.join_type == JoinType::SEMI); + for (auto &cond : join.conditions) { + if (!cond.IsComparison() || cond.GetComparisonType() != ExpressionType::COMPARE_NOT_DISTINCT_FROM) { + continue; + } + // Once a MARK join is reduced to SEMI, a null-safe equality is equivalent to regular equality if either + // join key is known not to be NULL. 
Regular equality unlocks the existing runtime-filter infrastructure. + if (!ExpressionIsNotNull(context, *join.children[0], cond.GetLHS()) && + !ExpressionIsNotNull(context, *join.children[1], cond.GetRHS())) { + continue; + } + cond = + JoinCondition(cond.LeftReference()->Copy(), cond.RightReference()->Copy(), ExpressionType::COMPARE_EQUAL); + } +} + unique_ptr FilterPushdown::PushdownMarkJoin(unique_ptr op, unordered_set &left_bindings, unordered_set &right_bindings) { @@ -82,6 +220,9 @@ unique_ptr FilterPushdown::PushdownMarkJoin(unique_ptrchildren[0] = left_pushdown.Rewrite(std::move(op->children[0])); op->children[1] = right_pushdown.Rewrite(std::move(op->children[1])); + if (join.join_type == JoinType::SEMI) { + SimplifyNullSafeSemiJoinConditions(GetContext(), comp_join); + } return PushFinalFilters(std::move(op)); } diff --git a/src/duckdb/src/optimizer/sampling_pushdown.cpp b/src/duckdb/src/optimizer/sampling_pushdown.cpp index ca805e64e..0c934242b 100644 --- a/src/duckdb/src/optimizer/sampling_pushdown.cpp +++ b/src/duckdb/src/optimizer/sampling_pushdown.cpp @@ -1,19 +1,53 @@ #include "duckdb/optimizer/sampling_pushdown.hpp" +#include "duckdb/common/random_engine.hpp" #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/operator/logical_sample.hpp" -#include "duckdb/common/types/value.hpp" +#include "duckdb/planner/operator/logical_limit.hpp" + namespace duckdb { unique_ptr SamplingPushdown::Optimize(unique_ptr op) { - if (op->type == LogicalOperatorType::LOGICAL_SAMPLE && - op->Cast().sample_options->method == SampleMethod::SYSTEM_SAMPLE && - op->Cast().sample_options->is_percentage && !op->children.empty() && - op->children[0]->type == LogicalOperatorType::LOGICAL_GET && - op->children[0]->Cast().function.sampling_pushdown && op->children[0]->children.empty()) { + if (op->type == LogicalOperatorType::LOGICAL_SAMPLE && !op->children.empty() && + op->children[0]->type == LogicalOperatorType::LOGICAL_GET && 
op->children[0]->children.empty()) { + auto &sample_op = op->Cast(); auto &get = op->children[0]->Cast(); - // set sampling option - get.extra_info.sample_options = std::move(op->Cast().sample_options); - op = std::move(op->children[0]); + const auto &sample_options = *sample_op.sample_options; + const bool has_filters = get.table_filters.HasFilters() || get.dynamic_filters; + const bool can_push_system_sample = + get.function.sampling_pushdown && sample_options.method == SampleMethod::SYSTEM_SAMPLE && !has_filters; + + if (can_push_system_sample) { + const bool is_row_count_sampling = !sample_options.is_percentage; + int64_t row_limit = 0; + if (is_row_count_sampling) { + if (!sample_op.sample_options->seed.IsValid()) { + auto &random_engine = RandomEngine::Get(context); + sample_op.sample_options->SetSeed(random_engine.NextRandomInteger()); + } + // For row-count sampling, calculate the sampling rate based on estimated cardinality. + // Use EstimateCardinality which can query table function stats if has_estimated_cardinality is not set. + row_limit = sample_options.sample_size.GetValue(); + const idx_t estimated_card = get.EstimateCardinality(context); + if (estimated_card > 0) { + sample_op.sample_options->sample_rate = + static_cast(row_limit) / static_cast(estimated_card); + } else { + sample_op.sample_options->sample_rate = 1.0; + } + } + + get.extra_info.sample_options = std::move(sample_op.sample_options); + op = std::move(op->children[0]); + + if (is_row_count_sampling) { + // Wrap with LIMIT to ensure exact row count and enable early stopping. + // The pushdown sampling may oversample due to the distributed chunk-based + // approach, so LIMIT ensures we stop as soon as the target is reached. 
+ auto limit = make_uniq(BoundLimitNode::ConstantValue(row_limit), BoundLimitNode()); + limit->children.push_back(std::move(op)); + op = std::move(limit); + } + } } for (auto &child : op->children) { child = Optimize(std::move(child)); diff --git a/src/duckdb/src/optimizer/unnest_rewriter.cpp b/src/duckdb/src/optimizer/unnest_rewriter.cpp index c375c96de..d39753c30 100644 --- a/src/duckdb/src/optimizer/unnest_rewriter.cpp +++ b/src/duckdb/src/optimizer/unnest_rewriter.cpp @@ -36,6 +36,16 @@ static optional_idx FindBindingIndex(const vector &bindings, cons return NumericCast(entry - bindings.begin()); } +static idx_t CountUniqueBindings(const vector &bindings) { + vector unique_bindings; + for (auto &binding : bindings) { + if (!FindBindingIndex(unique_bindings, binding).IsValid()) { + unique_bindings.push_back(binding); + } + } + return unique_bindings.size(); +} + static idx_t CountCTERefs(LogicalOperator &op, TableIndex cte_index) { idx_t result = 0; if (op.type == LogicalOperatorType::LOGICAL_CTE_REF && op.Cast().cte_index == cte_index) { @@ -454,10 +464,12 @@ bool UnnestRewriter::RewriteInlineCTEDedupCandidate(unique_ptr for (idx_t binding_idx = 0; binding_idx < dedup_bindings.size(); binding_idx++) { domain_ref_replacer.replacement_bindings.emplace_back(dedup_bindings[binding_idx], delim_columns[binding_idx]); } - LogicalOperatorVisitor::EnumerateExpressions( - topmost_op, [&](unique_ptr *expr) { domain_ref_replacer.VisitExpression(expr); }); + domain_ref_replacer.VisitOperator(topmost_op); overwritten_tbl_idx = dedup_bindings[0].table_index; - distinct_unnest_count = dedup_bindings.size(); + // Inline CTE dedup inputs can contain repeated source columns, e.g. table-in-out UNNEST can project an input + // column and carry it again as a delimiter column. The RHS projection path only has the unique delimiter + // columns at its tail. 
+ distinct_unnest_count = CountUniqueBindings(delim_columns); unnest.children[0] = std::move(domain_cte.children[0]); if (path_to_unnest.empty()) { @@ -720,31 +732,21 @@ void UnnestRewriter::UpdateRHSBindings(unique_ptr &plan, unique // update all bindings coming from the LHS to RHS bindings D_ASSERT(topmost_op.children[0]->type == LogicalOperatorType::LOGICAL_PROJECTION); auto &top_proj = topmost_op.children[0]->Cast(); + vector source_lhs_bindings; + source_lhs_bindings.reserve(lhs_bindings.size()); for (idx_t i = 0; i < lhs_bindings.size(); i++) { - ReplaceBinding replace_binding(lhs_bindings[i].binding, - ColumnBinding(top_proj.table_index, ProjectionIndex(i))); - updater.replace_bindings.push_back(replace_binding); - } - - // temporarily remove the BOUND_UNNESTs and the child of the LOGICAL_UNNEST from the plan - D_ASSERT(curr_op.get()->type == LogicalOperatorType::LOGICAL_UNNEST); - auto &unnest = curr_op.get()->Cast(); - vector> temp_bound_unnests; - for (auto &temp_bound_unnest : unnest.expressions) { - temp_bound_unnests.push_back(std::move(temp_bound_unnest)); + source_lhs_bindings.push_back(lhs_bindings[i].binding); } - D_ASSERT(unnest.children.size() == 1); - auto temp_unnest_child = std::move(unnest.children[0]); - unnest.expressions.clear(); - unnest.children.clear(); - // update the bindings of the plan - updater.VisitOperator(*plan); - updater.replace_bindings.clear(); - // add the children again - for (auto &temp_bound_unnest : temp_bound_unnests) { - unnest.expressions.push_back(std::move(temp_bound_unnest)); + // References above the RHS projection path should point at the LHS columns exposed by the top projection. + // Expressions inside the path itself still need to bind against their child; those are repaired below as each + // projection gets the LHS columns prepended. 
+ ColumnBindingReplacer lhs_replacer; + lhs_replacer.stop_operator = topmost_op.children[0]; + for (idx_t i = 0; i < lhs_bindings.size(); i++) { + lhs_replacer.replacement_bindings.emplace_back( + lhs_bindings[i].binding, ColumnBinding(top_proj.table_index, ProjectionIndex(i)), lhs_bindings[i].type); } - unnest.children.push_back(std::move(temp_unnest_child)); + lhs_replacer.VisitOperator(*plan); // add the LHS expressions to each LOGICAL_PROJECTION for (idx_t i = path_to_unnest.size(); i > 0; i--) { @@ -755,6 +757,18 @@ void UnnestRewriter::UpdateRHSBindings(unique_ptr &plan, unique auto existing_expressions = std::move(proj.expressions); proj.expressions.clear(); + ColumnBindingReplacer child_lhs_replacer; + for (idx_t expr_idx = 0; expr_idx < lhs_bindings.size(); expr_idx++) { + if (source_lhs_bindings[expr_idx] == lhs_bindings[expr_idx].binding) { + continue; + } + child_lhs_replacer.replacement_bindings.emplace_back( + source_lhs_bindings[expr_idx], lhs_bindings[expr_idx].binding, lhs_bindings[expr_idx].type); + } + for (auto &expr : existing_expressions) { + child_lhs_replacer.VisitExpression(&expr); + } + // add the new expressions for (idx_t expr_idx = 0; expr_idx < lhs_bindings.size(); expr_idx++) { auto new_expr = make_uniq( diff --git a/src/duckdb/src/parallel/task_executor.cpp b/src/duckdb/src/parallel/task_executor.cpp index 11dca042d..d66d09117 100644 --- a/src/duckdb/src/parallel/task_executor.cpp +++ b/src/duckdb/src/parallel/task_executor.cpp @@ -32,7 +32,12 @@ void TaskExecutor::ThrowError() { void TaskExecutor::ScheduleTask(unique_ptr task) { ++total_tasks; - scheduler.ScheduleTask(*token, std::move(task), type); + try { + scheduler.ScheduleTask(*token, std::move(task), type); + } catch (...) 
{ + --total_tasks; + throw; + } } void TaskExecutor::FinishTask() { ++completed_tasks; diff --git a/src/duckdb/src/parallel/task_scheduler.cpp b/src/duckdb/src/parallel/task_scheduler.cpp index 1038ca382..29537e8e2 100644 --- a/src/duckdb/src/parallel/task_scheduler.cpp +++ b/src/duckdb/src/parallel/task_scheduler.cpp @@ -273,6 +273,10 @@ int32_t TaskScheduler::NumberOfThreads() { return GetPool(TaskSchedulerType::REGULAR).NumberOfThreads(); } +int32_t TaskScheduler::NumberOfAsyncThreads() { + return GetPool(TaskSchedulerType::ASYNC).NumberOfThreads(); +} + idx_t TaskScheduler::GetNumberOfTasks() const { idx_t num_tasks = 0; for (auto &queue : queues) { diff --git a/src/duckdb/src/parser/parsed_data/sample_options.cpp b/src/duckdb/src/parser/parsed_data/sample_options.cpp index 1dfcba72f..d55ee9464 100644 --- a/src/duckdb/src/parser/parsed_data/sample_options.cpp +++ b/src/duckdb/src/parser/parsed_data/sample_options.cpp @@ -26,6 +26,7 @@ unique_ptr SampleOptions::Copy() { result->method = method; result->seed = seed; result->repeatable = repeatable; + result->sample_rate = sample_rate; return result; } @@ -49,7 +50,7 @@ bool SampleOptions::Equals(SampleOptions *a, SampleOptions *b) { return true; } if (a->sample_size != b->sample_size || a->is_percentage != b->is_percentage || a->method != b->method || - a->seed.GetIndex() != b->seed.GetIndex()) { + a->sample_rate != b->sample_rate || a->seed.GetIndex() != b->seed.GetIndex()) { return false; } return true; diff --git a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp index 2a17bcf5c..5de015273 100644 --- a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp +++ b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp @@ -31,6 +31,8 @@ #include "duckdb/planner/expression_binder/where_binder.hpp" #include "duckdb/planner/query_node/bound_select_node.hpp" #include "duckdb/planner/operator/logical_sample.hpp" +#include 
"duckdb/common/enums/dialect_compatibility_mode.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -179,6 +181,25 @@ void Binder::PrepareModifiers(OrderBinder &order_binder, QueryNode &statement, B break; } } + // Spark Compatibility Mode: when ALL is not reserved, ORDER BY ALL is parsed as a column reference. + // If no column named "all" exists, treat as ORDER BY ALL. + if (Settings::Get(context) == DialectCompatibilityMode::SPARK && + order.orders.size() == 1 && + order.orders[0].expression->GetExpressionClass() == ExpressionClass::COLUMN_REF) { + auto &colref = order.orders[0].expression->Cast(); + if (colref.ColumnNames().size() == 1 && + StringUtil::CIEquals(colref.ColumnNames()[0].GetIdentifierName(), "all")) { + auto matching = bind_context.GetMatchingBindings("all"); + if (matching.empty()) { + auto order_type = config.ResolveOrder(context, order.orders[0].type); + auto null_order = config.ResolveNullOrder(context, order_type, order.orders[0].null_order); + auto constant_expr = make_uniq(Value("ALL")); + bound_order->orders.emplace_back(order_type, null_order, std::move(constant_expr)); + bound_modifier = std::move(bound_order); + break; + } + } + } #if 0 // When this verification is enabled, replace ORDER BY x, y with ORDER BY create_sort_key(x, y) // note that we don't enable this during actual verification since it doesn't always work @@ -493,6 +514,24 @@ BoundStatement Binder::BindSelectNode(SelectNode &statement, BoundStatement from } auto &group_expressions = statement.groups.group_expressions; + + // Spark Compatibility Mode: when ALL is not a reserved keyword, GROUP BY ALL is parsed as a column reference + // instead of the special GROUP BY ALL syntax. Detect this and convert to FORCE_AGGREGATES + // if no column named "all" actually exists in scope. 
+ if (Settings::Get(context) == DialectCompatibilityMode::SPARK && + group_expressions.size() == 1 && group_expressions[0]->GetExpressionClass() == ExpressionClass::COLUMN_REF) { + auto &colref = group_expressions[0]->Cast(); + if (colref.ColumnNames().size() == 1 && + StringUtil::CIEquals(colref.ColumnNames()[0].GetIdentifierName(), "all")) { + auto matching = bind_context.GetMatchingBindings("all"); + if (matching.empty()) { + statement.aggregate_handling = AggregateHandling::FORCE_AGGREGATES; + group_expressions.clear(); + statement.groups.grouping_sets.clear(); + } + } + } + if (!group_expressions.empty()) { // the statement has a GROUP BY clause, bind it GroupBinder group_binder(*this, context, result.group_index, bind_state); diff --git a/src/duckdb/src/planner/filter/expression_filter.cpp b/src/duckdb/src/planner/filter/expression_filter.cpp index e30c4219b..5ceafeed3 100644 --- a/src/duckdb/src/planner/filter/expression_filter.cpp +++ b/src/duckdb/src/planner/filter/expression_filter.cpp @@ -15,7 +15,6 @@ #include "duckdb/planner/filter/bloom_filter.hpp" #include "duckdb/planner/filter/dynamic_filter.hpp" #include "duckdb/planner/filter/optional_filter.hpp" -#include "duckdb/planner/filter/perfect_hash_join_filter.hpp" #include "duckdb/planner/filter/prefix_range_filter.hpp" #include "duckdb/planner/filter/selectivity_optional_filter.hpp" #include "duckdb/planner/filter/table_filter_functions.hpp" @@ -72,9 +71,16 @@ static bool IsOptionalInternalFunction(const BoundFunctionExpression &func) { func.Function().GetName() == SelectivityOptionalFilterScalarFun::NAME; } -static bool IsOptionalExpressionInternal(const Expression &expr, bool recurse_through_and) { +static bool IsNonSelectivityOptionalInternalFunction(const BoundFunctionExpression &func) { + return func.Function().GetName() == OptionalFilterScalarFun::NAME; +} + +static bool IsOptionalExpressionInternal(const Expression &expr, bool recurse_through_and, + bool include_selectivity_optional) { if 
(expr.GetExpressionClass() == ExpressionClass::BOUND_FUNCTION) { - return IsOptionalInternalFunction(expr.Cast()); + auto &func = expr.Cast(); + return include_selectivity_optional ? IsOptionalInternalFunction(func) + : IsNonSelectivityOptionalInternalFunction(func); } if (!recurse_through_and || expr.GetExpressionClass() != ExpressionClass::BOUND_CONJUNCTION || expr.GetExpressionType() != ExpressionType::CONJUNCTION_AND) { @@ -85,7 +91,7 @@ static bool IsOptionalExpressionInternal(const Expression &expr, bool recurse_th return false; } for (auto &child : conj.GetChildren()) { - if (!IsOptionalExpressionInternal(*child, true)) { + if (!IsOptionalExpressionInternal(*child, true, include_selectivity_optional)) { return false; } } @@ -488,11 +494,11 @@ bool ExpressionFilter::ContainsInternalFunction(const Expression &expr, const st } bool ExpressionFilter::IsOptionalExpression(const Expression &expr) { - return IsOptionalExpressionInternal(expr, true); + return IsOptionalExpressionInternal(expr, true, true); } bool ExpressionFilter::IsRootOptionalExpression(const Expression &expr) { - return IsOptionalExpressionInternal(expr, false); + return IsOptionalExpressionInternal(expr, false, true); } bool ExpressionFilter::IsOptionalFilter(const TableFilter &filter) { @@ -505,6 +511,11 @@ bool ExpressionFilter::IsRootOptionalFilter(const TableFilter &filter) { return IsRootOptionalExpression(*expr_filter.expr); } +bool ExpressionFilter::IsRootNonSelectivityOptionalFilter(const TableFilter &filter) { + auto &expr_filter = GetExpressionFilter(filter, "ExpressionFilter::IsRootNonSelectivityOptionalFilter"); + return IsOptionalExpressionInternal(*expr_filter.expr, false, false); +} + static shared_ptr TryGetRootDynamicFilterData(const Expression &expr) { if (expr.GetExpressionClass() != ExpressionClass::BOUND_FUNCTION) { return nullptr; @@ -552,9 +563,6 @@ string ExpressionFilter::InternalFunctionToString(const BoundFunctionExpression if (func_name == BloomFilterScalarFun::NAME) 
{ auto &data = func_expr.BindInfo()->Cast(); return BloomFilterScalarFun::ToString(column_name, data.key_column_name); - } else if (func_name == PerfectHashJoinScalarFun::NAME) { - auto &data = func_expr.BindInfo()->Cast(); - return PerfectHashJoinScalarFun::ToString(column_name, data.key_column_name); } else if (func_name == PrefixRangeScalarFun::NAME) { auto &data = func_expr.BindInfo()->Cast(); return PrefixRangeScalarFun::ToString(column_name, data.key_column_name); diff --git a/src/duckdb/src/planner/filter/perfect_hash_join_filter.cpp b/src/duckdb/src/planner/filter/perfect_hash_join_filter.cpp deleted file mode 100644 index 65860e65b..000000000 --- a/src/duckdb/src/planner/filter/perfect_hash_join_filter.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include "duckdb/planner/filter/perfect_hash_join_filter.hpp" - -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/filter/table_filter_functions.hpp" - -namespace duckdb { - -LegacyPerfectHashJoinFilter::LegacyPerfectHashJoinFilter( - optional_ptr perfect_join_executor_p, const string &key_column_name_p, - const LogicalType &key_type_p) - : TableFilter(TYPE), perfect_join_executor(perfect_join_executor_p), key_column_name(key_column_name_p), - key_type(key_type_p) { -} - -unique_ptr LegacyPerfectHashJoinFilter::ToExpression(const Expression &column) const { - auto function = PerfectHashJoinScalarFun::GetFunction(column.GetReturnType()); - auto bind_data = make_uniq(perfect_join_executor, key_column_name, 0.0f, idx_t(0)); - vector> arguments; - arguments.push_back(column.Copy()); - return make_uniq(BoundScalarFunction(function), std::move(arguments), - std::move(bind_data)); -} - -void LegacyPerfectHashJoinFilter::Serialize(Serializer &serializer) const { - TableFilter::Serialize(serializer); - serializer.WriteProperty(200, "key_column_name", key_column_name); - serializer.WriteProperty(201, "key_type", key_type); -} - -unique_ptr 
LegacyPerfectHashJoinFilter::Deserialize(Deserializer &deserializer) { - auto key_column_name = deserializer.ReadProperty(200, "key_column_name"); - auto key_type = deserializer.ReadProperty(201, "key_type"); - return make_uniq(nullptr, key_column_name, key_type); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/filter/table_filter_bloom_function.cpp b/src/duckdb/src/planner/filter/table_filter_bloom_function.cpp index d50a3f65a..b46827965 100644 --- a/src/duckdb/src/planner/filter/table_filter_bloom_function.cpp +++ b/src/duckdb/src/planner/filter/table_filter_bloom_function.cpp @@ -30,8 +30,7 @@ static constexpr idx_t N_BITS = 4; // the number of bits to void BloomFilter::Initialize(ClientContext &context_p, idx_t number_of_rows) { BufferManager &buffer_manager = BufferManager::GetBufferManager(context_p); - const idx_t min_bits = MaxValue(MIN_NUM_BITS, number_of_rows * MIN_NUM_BITS_PER_KEY); - num_sectors = MinValue(NextPowerOfTwo(min_bits) >> LOG_SECTOR_SIZE, MAX_NUM_SECTORS); + num_sectors = GetNumberOfSectors(number_of_rows); bitmask = num_sectors - 1; buf_ = buffer_manager.GetBufferAllocator().Allocate(64 + num_sectors * sizeof(uint64_t)); @@ -60,6 +59,11 @@ void BloomFilter::Reset() { bf = nullptr; } +idx_t BloomFilter::GetNumberOfSectors(idx_t number_of_rows) { + const idx_t min_bits = MaxValue(MIN_NUM_BITS, number_of_rows * MIN_NUM_BITS_PER_KEY); + return MinValue(NextPowerOfTwo(min_bits) >> LOG_SECTOR_SIZE, MAX_NUM_SECTORS); +} + inline uint64_t GetMask(const hash_t hash) { const uint64_t shifts = hash & SHIFT_MASK; const auto shifts_8 = reinterpret_cast(&shifts); diff --git a/src/duckdb/src/planner/filter/table_filter_functions.cpp b/src/duckdb/src/planner/filter/table_filter_functions.cpp index 436a8ba94..5fe8e6238 100644 --- a/src/duckdb/src/planner/filter/table_filter_functions.cpp +++ b/src/duckdb/src/planner/filter/table_filter_functions.cpp @@ -21,9 +21,9 @@ unique_ptr TableFilterFunctions::Bind(BindScalarFunctionInput &inp } bool 
TableFilterFunctions::IsTableFilterFunction(const Identifier &name) { - static const char *const TABLE_FILTER_FUNCTIONS[] = { - BloomFilterScalarFun::NAME, DynamicFilterScalarFun::NAME, OptionalFilterScalarFun::NAME, - PerfectHashJoinScalarFun::NAME, PrefixRangeScalarFun::NAME, SelectivityOptionalFilterScalarFun::NAME}; + static const char *const TABLE_FILTER_FUNCTIONS[] = {BloomFilterScalarFun::NAME, DynamicFilterScalarFun::NAME, + OptionalFilterScalarFun::NAME, PrefixRangeScalarFun::NAME, + SelectivityOptionalFilterScalarFun::NAME}; for (auto function_name : TABLE_FILTER_FUNCTIONS) { if (name == function_name) { return true; @@ -37,12 +37,10 @@ void GetThresholdAndVectorsToCheck(SelectivityOptionalFilterType type, float &se idx_t &n_vectors_to_check) { static constexpr float MIN_MAX_THRESHOLD = 0.9f; static constexpr float BF_THRESHOLD = 0.5f; - static constexpr float PHJ_THRESHOLD = 0.3f; static constexpr float PRF_THRESHOLD = 0.5f; static constexpr idx_t MIN_MAX_CHECK_N = 6; static constexpr idx_t BF_CHECK_N = 6; - static constexpr idx_t PHJ_CHECK_N = 6; static constexpr idx_t PRF_CHECK_N = 6; switch (type) { @@ -54,10 +52,6 @@ void GetThresholdAndVectorsToCheck(SelectivityOptionalFilterType type, float &se selectivity_threshold = BF_THRESHOLD; n_vectors_to_check = BF_CHECK_N; return; - case SelectivityOptionalFilterType::PHJ: - selectivity_threshold = PHJ_THRESHOLD; - n_vectors_to_check = PHJ_CHECK_N; - return; case SelectivityOptionalFilterType::PRF: selectivity_threshold = PRF_THRESHOLD; n_vectors_to_check = PRF_CHECK_N; @@ -110,9 +104,6 @@ unique_ptr TableFilterFunctionDeserialize(Deserializer &deserializ if (function.GetName() == BloomFilterScalarFun::NAME) { return make_uniq(nullptr, false, string(), key_type, 0.0f, idx_t(0)); } - if (function.GetName() == PerfectHashJoinScalarFun::NAME) { - return make_uniq(nullptr, string(), 0.0f, idx_t(0)); - } if (function.GetName() == PrefixRangeScalarFun::NAME) { return make_uniq(nullptr, string(), key_type, 0.0f, 
idx_t(0)); } diff --git a/src/duckdb/src/planner/filter/table_filter_perfect_hash_join_function.cpp b/src/duckdb/src/planner/filter/table_filter_perfect_hash_join_function.cpp deleted file mode 100644 index 9b5d45743..000000000 --- a/src/duckdb/src/planner/filter/table_filter_perfect_hash_join_function.cpp +++ /dev/null @@ -1,196 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/planner/filter/table_filter_perfect_hash_join_function.cpp -// -// -//===----------------------------------------------------------------------===// - -#include "duckdb/planner/filter/table_filter_functions.hpp" -#include "duckdb/planner/filter/table_filter_function_helpers.hpp" - -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/operator/subtract.hpp" -#include "duckdb/common/vector_size.hpp" -#include "duckdb/execution/expression_executor_state.hpp" -#include "duckdb/execution/operator/join/perfect_hash_join_executor.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/storage/statistics/numeric_stats.hpp" - -namespace duckdb { - -PerfectHashJoinFunctionData::PerfectHashJoinFunctionData(optional_ptr executor_p, - const string &key_column_name_p, float selectivity_threshold_p, - idx_t n_vectors_to_check_p) - : executor(executor_p), key_column_name(key_column_name_p), selectivity_threshold(selectivity_threshold_p), - n_vectors_to_check(n_vectors_to_check_p) { -} - -unique_ptr PerfectHashJoinFunctionData::Copy() const { - return make_uniq(executor, key_column_name, selectivity_threshold, n_vectors_to_check); -} - -bool PerfectHashJoinFunctionData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return executor.get() == other.executor.get() && key_column_name == other.key_column_name; -} - -static idx_t SelectPerfectHashJoin(Vector &input, const PerfectHashJoinFunctionData &func_data, - SelectionVector &result_sel, idx_t count) { - 
D_ASSERT(func_data.executor); - idx_t approved_count = 0; - func_data.executor->FillSelectionVectorSwitchProbe(input, count, result_sel, approved_count, nullptr); - return approved_count; -} - -static unique_ptr -PerfectHashJoinInitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data) { - auto &data = bind_data->Cast(); - if (!data.executor) { - return nullptr; - } - return InitSelectivityTrackingLocalState(data.n_vectors_to_check, data.selectivity_threshold); -} - -static void PerfectHashJoinFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &func_data = func_expr.BindInfo()->Cast(); - auto local_state_ptr = ExecuteFunctionState::GetFunctionState(state); - auto tracking_state = local_state_ptr ? &local_state_ptr->Cast() : nullptr; - - if (!func_data.executor) { - SetAllTrue(args, result); - return; - } - - ExecuteWithSelectivityTracking(args, result, tracking_state, [&] { - SelectionVector probe_sel(args.size()); - auto approved_count = SelectPerfectHashJoin(args.data[0], func_data, probe_sel, args.size()); - SelectionToBooleanResult(args.size(), probe_sel, approved_count, result); - return approved_count; - }); -} - -static idx_t PerfectHashJoinSelect(DataChunk &args, ExpressionState &state, optional_ptr sel, - optional_ptr true_sel, optional_ptr false_sel) { - auto &func_expr = state.expr.Cast(); - auto &func_data = func_expr.BindInfo()->Cast(); - auto local_state_ptr = ExecuteFunctionState::GetFunctionState(state); - auto tracking_state = local_state_ptr ? 
&local_state_ptr->Cast() : nullptr; - - auto count = args.size(); - if (!func_data.executor) { - return SetAllTrueSelection(count, sel, true_sel, false_sel); - } - if (tracking_state && !tracking_state->IsActive()) { - tracking_state->Update(0, 0); - return SetAllTrueSelection(count, sel, true_sel, false_sel); - } - - SelectionVector temp_true(count); - auto result_true_sel = (!true_sel || (sel && true_sel.get() == sel.get())) ? &temp_true : true_sel.get(); - auto approved_count = SelectPerfectHashJoin(args.data[0], func_data, *result_true_sel, count); - approved_count = TranslateSelection(count, sel, *result_true_sel, approved_count, true_sel, false_sel); - if (tracking_state) { - tracking_state->Update(approved_count, count); - } - return approved_count; -} - -template -static FilterPropagateResult TemplatedPerfectHashJoinPrune(const PerfectHashJoinExecutor &executor, - const BaseStatistics &stats) { - if (!NumericStats::HasMinMax(stats)) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - const auto min = NumericStats::GetMin(stats); - const auto max = NumericStats::GetMax(stats); - if (min > max) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - T range_typed; - idx_t range; - if (!TrySubtractOperator::Operation(max, min, range_typed) || !TryCast::Operation(range_typed, range) || - range >= DEFAULT_STANDARD_VECTOR_SIZE) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - Vector range_vec(stats.GetType(), DEFAULT_STANDARD_VECTOR_SIZE); - auto range_data = FlatVector::GetDataMutable(range_vec); - T val = min; - for (; val < max; val += 1) { - *range_data++ = val; - } - *range_data = val; - - const auto total_count = NumericCast(range_typed) + 1; - idx_t approved_tuple_count = 0; - SelectionVector probe_sel(total_count); - executor.FillSelectionVectorSwitchProbe(range_vec, total_count, probe_sel, approved_tuple_count, nullptr); - - if (approved_tuple_count == 0) { - return FilterPropagateResult::FILTER_ALWAYS_FALSE; - } - if 
(approved_tuple_count == total_count) { - return FilterPropagateResult::FILTER_ALWAYS_TRUE; - } - return FilterPropagateResult::NO_PRUNING_POSSIBLE; -} - -ScalarFunction PerfectHashJoinScalarFun::GetFunction(const LogicalType &input_type) { - ScalarFunction func(NAME, {input_type}, LogicalType::BOOLEAN, PerfectHashJoinFunction, TableFilterFunctions::Bind); - func.SetInitStateCallback(PerfectHashJoinInitLocalState); - func.SetSelectCallback(PerfectHashJoinSelect); - func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); - func.SetFilterPruneCallback(PerfectHashJoinScalarFun::FilterPrune); - func.SetSerializeCallback(TableFilterFunctionSerialize); - func.SetDeserializeCallback(TableFilterFunctionDeserialize); - return func; -} - -string PerfectHashJoinScalarFun::ToString(const string &column_name, const string &key_column_name) { - return column_name + " IN PHJ(" + key_column_name + ")"; -} - -FilterPropagateResult PerfectHashJoinScalarFun::FilterPrune(const FunctionStatisticsPruneInput &input) { - if (!input.bind_data) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - auto &data = input.bind_data->Cast(); - if (!data.executor) { - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - - switch (input.stats.GetType().InternalType()) { - case PhysicalType::UINT8: - return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); - case PhysicalType::UINT16: - return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); - case PhysicalType::UINT32: - return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); - case PhysicalType::UINT64: - return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); - case PhysicalType::UINT128: - return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); - case PhysicalType::INT8: - return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); - case PhysicalType::INT16: - return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); - case PhysicalType::INT32: - return 
TemplatedPerfectHashJoinPrune(*data.executor, input.stats); - case PhysicalType::INT64: - return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); - case PhysicalType::INT128: - return TemplatedPerfectHashJoinPrune(*data.executor, input.stats); - default: - return FilterPropagateResult::NO_PRUNING_POSSIBLE; - } -} - -ScalarFunction TableFilterPerfectHashJoinFun::GetFunction() { - return PerfectHashJoinScalarFun::GetFunction(LogicalType::ANY); -} - -} // namespace duckdb diff --git a/src/duckdb/src/planner/filter/table_filter_prefix_range_function.cpp b/src/duckdb/src/planner/filter/table_filter_prefix_range_function.cpp index dec6475df..488100e7e 100644 --- a/src/duckdb/src/planner/filter/table_filter_prefix_range_function.cpp +++ b/src/duckdb/src/planner/filter/table_filter_prefix_range_function.cpp @@ -60,14 +60,13 @@ struct PrefixRangeBitmapBuildState : public PrefixRangeFilter::BuildState { template class PrefixRangeBitmap { public: - void Initialize(ClientContext &context, U min_p, U span_p) { + void Initialize(ClientContext &context, U min_p, U span_p, idx_t max_bits) { min = min_p; span = span_p; shift = 0; - if (span >= CAP_BITS) { - const auto q = UnsafeNumericCast(span >> MAX_PREFIX_LENGTH); - shift = (q <= 1) ? 
0 : (64 - CountZeros::Leading(q - 1)); + while ((span >> shift) >= max_bits) { + shift++; } const idx_t buckets = UnsafeNumericCast((span >> shift) + 1); @@ -189,8 +188,6 @@ class PrefixRangeBitmap { } private: - static constexpr idx_t MAX_PREFIX_LENGTH = 20; - static constexpr idx_t CAP_BITS = 1ULL << MAX_PREFIX_LENGTH; static constexpr idx_t WORD_SHIFT = 6; static constexpr idx_t WORD_MASK = 63; @@ -244,12 +241,13 @@ class NumericPrefixRangeFilter : public PrefixRangeFilter { using Comparable = typename MakeUnsigned::type; public: - void Initialize(ClientContext &context, idx_t number_of_rows, Value min_val, Value max_val) override { + void Initialize(ClientContext &context, idx_t number_of_rows, Value min_val, Value max_val, + idx_t max_bits) override { D_ASSERT(min_val <= max_val); D_ASSERT(number_of_rows > 0); const auto min = NumericConverter::Convert(min_val.GetValueUnsafe()); const auto max = NumericConverter::Convert(max_val.GetValueUnsafe()); - bitmap.Initialize(context, min, max - min); + bitmap.Initialize(context, min, max - min, max_bits); } unique_ptr InitializeBuildState(ClientContext &context) const override { @@ -297,13 +295,14 @@ class NumericPrefixRangeFilter : public PrefixRangeFilter { class StringPrefixRangeFilter : public PrefixRangeFilter { public: - void Initialize(ClientContext &context, idx_t number_of_rows, Value min_val, Value max_val) override { + void Initialize(ClientContext &context, idx_t number_of_rows, Value min_val, Value max_val, + idx_t max_bits) override { D_ASSERT(min_val <= max_val); D_ASSERT(number_of_rows > 0); const auto min = StringPrefixConverter::Convert(min_val.GetValueUnsafe()); const auto max = StringPrefixConverter::Convert(max_val.GetValueUnsafe()); D_ASSERT(min <= max); - bitmap.Initialize(context, min, max - min); + bitmap.Initialize(context, min, max - min, max_bits); } unique_ptr InitializeBuildState(ClientContext &context) const override { diff --git a/src/duckdb/src/planner/operator/logical_get.cpp 
b/src/duckdb/src/planner/operator/logical_get.cpp index 6c3512ba4..60d79b63c 100644 --- a/src/duckdb/src/planner/operator/logical_get.cpp +++ b/src/duckdb/src/planner/operator/logical_get.cpp @@ -84,7 +84,11 @@ InsertionOrderPreservingMap LogicalGet::ParamsToString() const { result["Filters"] = filters_info; if (extra_info.sample_options) { - result["Sample Method"] = "System: " + extra_info.sample_options->sample_size.ToString() + "%"; + if (extra_info.sample_options->is_percentage) { + result["Sample Method"] = "System: " + extra_info.sample_options->sample_size.ToString() + "%"; + } else { + result["Sample Method"] = "System: " + extra_info.sample_options->sample_size.ToString() + " rows"; + } } if (!extra_info.file_filters.empty()) { diff --git a/src/duckdb/src/planner/subquery/delim_join_cte_rewriter.cpp b/src/duckdb/src/planner/subquery/delim_join_cte_rewriter.cpp index d487b8886..1364e04ab 100644 --- a/src/duckdb/src/planner/subquery/delim_join_cte_rewriter.cpp +++ b/src/duckdb/src/planner/subquery/delim_join_cte_rewriter.cpp @@ -241,6 +241,30 @@ static bool FilterNullRejectsDelimJoinRHS(LogicalFilter &filter, LogicalComparis return false; } +static bool ExpressionIsMarkerRef(Expression &expr, TableIndex mark_index) { + if (expr.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return false; + } + auto &colref = expr.Cast(); + return colref.Depth() == 0 && colref.Binding().table_index == mark_index; +} + +static bool FilterNegatesDelimJoinMarker(LogicalFilter &filter, LogicalComparisonJoin &delim_join) { + if (filter.HasProjectionMap() || delim_join.join_type != JoinType::MARK) { + return false; + } + for (auto &expr : filter.expressions) { + if (expr->GetExpressionType() != ExpressionType::OPERATOR_NOT) { + continue; + } + auto &op = expr->Cast(); + if (!op.GetChildren().empty() && ExpressionIsMarkerRef(*op.GetChildren()[0], delim_join.mark_index)) { + return true; + } + } + return false; +} + static bool 
PushEligibleFilterExpressionsIntoDelimJoinInputs(unique_ptr &plan) { auto &filter = plan->Cast(); if (filter.HasProjectionMap()) { @@ -350,14 +374,68 @@ static bool HasSelection(const LogicalOperator &op) { return false; } +static bool IsEvidenceSide(LogicalOperator &op, idx_t child_idx) { + if (op.type != LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + return false; + } + auto &join = op.Cast(); + switch (join.join_type) { + case JoinType::MARK: + case JoinType::ANTI: + return child_idx == 1; + case JoinType::RIGHT_ANTI: + return child_idx == 0; + default: + return false; + } +} + +static bool HasEvidenceSide(JoinType join_type) { + switch (join_type) { + case JoinType::MARK: + case JoinType::ANTI: + case JoinType::RIGHT_ANTI: + return true; + default: + return false; + } +} + +static bool ContainsSubqueryJoin(LogicalOperator &op) { + if (op.type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN || op.type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { + auto &join = op.Cast(); + switch (join.join_type) { + case JoinType::MARK: + case JoinType::SEMI: + case JoinType::ANTI: + case JoinType::RIGHT_SEMI: + case JoinType::RIGHT_ANTI: + case JoinType::SINGLE: + return true; + default: + break; + } + } + for (auto &child : op.children) { + if (ContainsSubqueryJoin(*child)) { + return true; + } + } + return false; +} + struct JoinWithGeneratedDedupRef { - JoinWithGeneratedDedupRef(unique_ptr &join_p, idx_t depth_p, bool filter_cross_product_p = false) - : join(join_p), depth(depth_p), filter_cross_product(filter_cross_product_p) { + JoinWithGeneratedDedupRef(unique_ptr &join_p, idx_t depth_p, bool filter_cross_product_p = false, + bool under_aggregate_p = false, bool under_evidence_side_p = false) + : join(join_p), depth(depth_p), filter_cross_product(filter_cross_product_p), + under_aggregate(under_aggregate_p), under_evidence_side(under_evidence_side_p) { } reference> join; idx_t depth; bool filter_cross_product; + bool under_aggregate; + bool under_evidence_side; }; 
struct GeneratedDedupRef { @@ -368,11 +446,8 @@ struct GeneratedDedupRef { bool has_projection = false; }; -static bool IsEqualityJoinCondition(const JoinCondition &cond) { - if (!cond.IsComparison()) { - return false; - } - switch (cond.GetComparisonType()) { +static bool IsEqualityComparison(ExpressionType comparison_type) { + switch (comparison_type) { case ExpressionType::COMPARE_EQUAL: case ExpressionType::COMPARE_NOT_DISTINCT_FROM: return true; @@ -381,6 +456,10 @@ static bool IsEqualityJoinCondition(const JoinCondition &cond) { } } +static bool IsEqualityJoinCondition(const JoinCondition &cond) { + return cond.IsComparison() && IsEqualityComparison(cond.GetComparisonType()); +} + static bool FindAndReplaceBindings(vector &traced_bindings, const vector> &expressions, const vector ¤t_bindings) { @@ -410,6 +489,20 @@ class ExpressionBindingReplacer : public LogicalOperatorVisitor { D_ASSERT(bindings.size() == expressions.size()); } + ExpressionBindingReplacer(const vector &bindings, const vector> &expressions, + optional_ptr stop_operator) + : bindings(bindings), expressions(expressions), stop_operator(stop_operator) { + D_ASSERT(bindings.size() == expressions.size()); + } + + void VisitOperator(LogicalOperator &op) override { + if (stop_operator && stop_operator.get() == &op) { + return; + } + VisitOperatorChildren(op); + VisitOperatorExpressions(op); + } + unique_ptr VisitReplace(BoundColumnRefExpression &expr, unique_ptr *expr_ptr) override { if (expr.Depth() != 0) { return nullptr; @@ -425,6 +518,7 @@ class ExpressionBindingReplacer : public LogicalOperatorVisitor { private: const vector &bindings; const vector> &expressions; + optional_ptr stop_operator; }; static void ReplaceExpressionBindings(unique_ptr &expr, const vector &bindings, @@ -436,7 +530,7 @@ static void ReplaceExpressionBindings(unique_ptr &expr, const vector replacer.VisitExpression(&expr); } -static bool GetBoundColumnRefBinding(Expression &expr, ColumnBinding &binding) { +static bool 
GetBoundColumnRefBinding(const Expression &expr, ColumnBinding &binding) { if (expr.GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { return false; } @@ -499,30 +593,43 @@ static void ReplaceOperatorBindings(LogicalOperator &op, const vector &bindings, + const vector> &expressions, + optional_ptr stop_operator) { + if (bindings.empty()) { + return; + } + ExpressionBindingReplacer replacer(bindings, expressions, stop_operator); + replacer.VisitOperator(op); +} + class GeneratedDedupRefEliminator { public: GeneratedDedupRefEliminator(LogicalComparisonJoin &delim_join, TableIndex dedup_cte_index, idx_t dedup_ref_count, - LogicalOperator &rewrite_root); + LogicalOperator &rewrite_root, bool preserve_evidence_side); idx_t Remove(); private: unique_ptr GetGeneratedDedupRef(LogicalOperator &op, bool collect_filters = false, bool allow_projection = false) const; - bool ExpressionReferencesGeneratedDedupRef(Expression &expr, const GeneratedDedupRef &dedup_ref) const; + bool ExpressionReferencesGeneratedDedupRef(const Expression &expr, const GeneratedDedupRef &dedup_ref) const; bool AddReplacement(vector &replacements, ColumnBinding old_binding, ColumnBinding new_binding) const; bool AddReplacement(vector &replacements, ColumnBinding old_binding, ColumnBinding new_binding, const LogicalType &new_type) const; bool CoversAllDedupColumns(const GeneratedDedupRef &dedup_ref, const vector &bindings) const; - optional_idx FindGeneratedOutputBinding(Expression &expr, const GeneratedDedupRef &dedup_ref) const; - bool ExpressionReferencesGeneratedSide(Expression &expr, const GeneratedDedupRef &dedup_ref) const; + optional_idx FindGeneratedOutputBinding(const Expression &expr, const GeneratedDedupRef &dedup_ref) const; + bool ExpressionReferencesGeneratedSide(const Expression &expr, const GeneratedDedupRef &dedup_ref) const; bool FilterIsGeneratedDedupCrossProduct(LogicalOperator &op) const; void FindJoinsWithGeneratedDedupRefs(unique_ptr &op, vector &joins, - idx_t depth = 0) 
const; + idx_t depth = 0, bool under_aggregate = false, + bool under_evidence_side = false) const; idx_t CountGeneratedDedupRefs(LogicalOperator &op) const; bool RemoveInequalityJoinConditions(LogicalOperator &target_op, const vector &join_conditions, idx_t dedup_idx); + bool PreserveJoinAsSemi(unique_ptr &join); + bool PreserveFilterCrossProductAsSemi(unique_ptr &filter_op); bool RemoveJoin(unique_ptr &join); bool RemoveFilterCrossProduct(unique_ptr &filter_op); @@ -531,12 +638,14 @@ class GeneratedDedupRefEliminator { TableIndex dedup_cte_index; idx_t dedup_ref_count; LogicalOperator &rewrite_root; + bool preserve_evidence_side; }; GeneratedDedupRefEliminator::GeneratedDedupRefEliminator(LogicalComparisonJoin &delim_join, TableIndex dedup_cte_index, - idx_t dedup_ref_count, LogicalOperator &rewrite_root) + idx_t dedup_ref_count, LogicalOperator &rewrite_root, + bool preserve_evidence_side) : delim_join(delim_join), dedup_cte_index(dedup_cte_index), dedup_ref_count(dedup_ref_count), - rewrite_root(rewrite_root) { + rewrite_root(rewrite_root), preserve_evidence_side(preserve_evidence_side) { } unique_ptr GeneratedDedupRefEliminator::GetGeneratedDedupRef(LogicalOperator &op, @@ -601,7 +710,7 @@ unique_ptr GeneratedDedupRefEliminator::GetGeneratedDedupRef( return nullptr; } -bool GeneratedDedupRefEliminator::ExpressionReferencesGeneratedDedupRef(Expression &expr, +bool GeneratedDedupRefEliminator::ExpressionReferencesGeneratedDedupRef(const Expression &expr, const GeneratedDedupRef &dedup_ref) const { bool found = false; ExpressionIterator::VisitExpression(expr, [&](const BoundColumnRefExpression &colref) { @@ -661,7 +770,7 @@ bool GeneratedDedupRefEliminator::CoversAllDedupColumns(const GeneratedDedupRef return true; } -optional_idx GeneratedDedupRefEliminator::FindGeneratedOutputBinding(Expression &expr, +optional_idx GeneratedDedupRefEliminator::FindGeneratedOutputBinding(const Expression &expr, const GeneratedDedupRef &dedup_ref) const { optional_idx result; bool 
unsupported = false; @@ -682,7 +791,7 @@ optional_idx GeneratedDedupRefEliminator::FindGeneratedOutputBinding(Expression return unsupported ? optional_idx() : result; } -bool GeneratedDedupRefEliminator::ExpressionReferencesGeneratedSide(Expression &expr, +bool GeneratedDedupRefEliminator::ExpressionReferencesGeneratedSide(const Expression &expr, const GeneratedDedupRef &dedup_ref) const { if (ExpressionReferencesGeneratedDedupRef(expr, dedup_ref)) { return true; @@ -713,24 +822,27 @@ bool GeneratedDedupRefEliminator::FilterIsGeneratedDedupCrossProduct(LogicalOper } void GeneratedDedupRefEliminator::FindJoinsWithGeneratedDedupRefs(unique_ptr &op, - vector &joins, - idx_t depth) const { + vector &joins, idx_t depth, + bool under_aggregate, + bool under_evidence_side) const { if (op->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { if (!op->children.empty()) { - FindJoinsWithGeneratedDedupRefs(op->children[0], joins, depth + 1); + FindJoinsWithGeneratedDedupRefs(op->children[0], joins, depth + 1, under_aggregate, under_evidence_side); } return; } - for (auto &child : op->children) { - FindJoinsWithGeneratedDedupRefs(child, joins, depth + 1); + auto child_under_aggregate = under_aggregate || op->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY; + for (idx_t child_idx = 0; child_idx < op->children.size(); child_idx++) { + FindJoinsWithGeneratedDedupRefs(op->children[child_idx], joins, depth + 1, child_under_aggregate, + under_evidence_side); } if (op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN && (GetGeneratedDedupRef(*op->children[0], false, true) || GetGeneratedDedupRef(*op->children[1], false, true))) { - joins.emplace_back(op, depth); + joins.emplace_back(op, depth, false, under_aggregate, under_evidence_side); } else if (FilterIsGeneratedDedupCrossProduct(*op)) { - joins.emplace_back(op, depth, true); + joins.emplace_back(op, depth, true, under_aggregate, under_evidence_side); } } @@ -865,6 +977,226 @@ bool 
GeneratedDedupRefEliminator::RemoveInequalityJoinConditions(LogicalOperator return found_all; } +bool GeneratedDedupRefEliminator::PreserveJoinAsSemi(unique_ptr &join) { + auto &comparison_join = join->Cast(); + if (comparison_join.join_type != JoinType::INNER && comparison_join.join_type != JoinType::SEMI) { + return false; + } + if (comparison_join.HasProjectionMap()) { + return false; + } + + auto left_is_generated = GetGeneratedDedupRef(*join->children[0], false, true) != nullptr; + auto right_is_generated = GetGeneratedDedupRef(*join->children[1], false, true) != nullptr; + if (left_is_generated == right_is_generated) { + return false; + } + const idx_t dedup_idx = left_is_generated ? 0 : 1; + if (comparison_join.join_type == JoinType::SEMI && dedup_idx == 0) { + return false; + } + + auto dedup_ref = GetGeneratedDedupRef(*join->children[dedup_idx], false, true); + if (!dedup_ref) { + return false; + } + + ColumnBindingReplacer replacer; + auto &replacement_bindings = replacer.replacement_bindings; + vector covered_dedup_bindings; + covered_dedup_bindings.reserve(comparison_join.conditions.size()); + vector base_replacement_bindings; + vector> base_replacement_expressions; + vector semi_conditions; + semi_conditions.reserve(comparison_join.conditions.size()); + + for (auto &cond : comparison_join.conditions) { + if (!cond.IsComparison() || !IsEqualityJoinCondition(cond)) { + return false; + } + + auto lhs_generated_idx = FindGeneratedOutputBinding(cond.GetLHS(), *dedup_ref); + auto rhs_generated_idx = FindGeneratedOutputBinding(cond.GetRHS(), *dedup_ref); + if (lhs_generated_idx.IsValid() == rhs_generated_idx.IsValid()) { + return false; + } + auto generated_idx = lhs_generated_idx.IsValid() ? lhs_generated_idx.GetIndex() : rhs_generated_idx.GetIndex(); + auto &generated_binding = dedup_ref->output_bindings[generated_idx]; + auto &generated_expression = *dedup_ref->output_expressions[generated_idx]; + auto &other_side = lhs_generated_idx.IsValid() ? 
cond.GetRHS() : cond.GetLHS(); + + ColumnBinding other_binding; + if (!GetBoundColumnRefBinding(other_side, other_binding)) { + return false; + } + if (!AddReplacement(replacement_bindings, generated_binding, other_binding, + generated_expression.GetReturnType())) { + return false; + } + + ColumnBinding base_binding; + if (GetBoundColumnRefBinding(generated_expression, base_binding) && + base_binding.table_index == dedup_ref->cte_ref->table_index) { + if (!AddReplacement(replacement_bindings, base_binding, other_binding)) { + return false; + } + covered_dedup_bindings.emplace_back(base_binding); + base_replacement_bindings.push_back(base_binding); + base_replacement_expressions.push_back(other_side.Copy()); + } + + auto generated_expr = lhs_generated_idx.IsValid() ? cond.GetLHS().Copy() : cond.GetRHS().Copy(); + auto other_expr = lhs_generated_idx.IsValid() ? cond.GetRHS().Copy() : cond.GetLHS().Copy(); + auto comparison_type = + lhs_generated_idx.IsValid() ? FlipComparisonExpression(cond.GetComparisonType()) : cond.GetComparisonType(); + semi_conditions.emplace_back(std::move(other_expr), std::move(generated_expr), comparison_type); + } + if (!CoversAllDedupColumns(*dedup_ref, covered_dedup_bindings)) { + return false; + } + + vector> generated_output_replacements; + generated_output_replacements.reserve(dedup_ref->output_expressions.size()); + for (auto &expr : dedup_ref->output_expressions) { + auto rewritten_expr = expr->Copy(); + ReplaceExpressionBindings(rewritten_expr, base_replacement_bindings, base_replacement_expressions); + if (ExpressionReferencesGeneratedSide(*rewritten_expr, *dedup_ref)) { + return false; + } + generated_output_replacements.push_back(std::move(rewritten_expr)); + } + + if (dedup_idx == 0) { + std::swap(comparison_join.children[0], comparison_join.children[1]); + } + comparison_join.join_type = JoinType::SEMI; + comparison_join.conditions = std::move(semi_conditions); + comparison_join.left_projection_map.clear(); + 
comparison_join.right_projection_map.clear(); + comparison_join.ResolveOperatorTypes(); + + replacer.stop_operator = join.get(); + replacer.VisitOperator(rewrite_root); + ReplaceOperatorBindings(rewrite_root, dedup_ref->output_bindings, generated_output_replacements, join.get()); + return true; +} + +bool GeneratedDedupRefEliminator::PreserveFilterCrossProductAsSemi(unique_ptr &filter_op) { + auto &filter = filter_op->Cast(); + if (filter.HasProjectionMap() || filter.children.size() != 1 || + filter.children[0]->type != LogicalOperatorType::LOGICAL_CROSS_PRODUCT) { + return false; + } + auto &cross_product = *filter.children[0]; + if (cross_product.children.size() != 2) { + return false; + } + + const idx_t dedup_idx = GetGeneratedDedupRef(*cross_product.children[0], false, true) ? 0 : 1; + auto dedup_ref = GetGeneratedDedupRef(*cross_product.children[dedup_idx], false, true); + if (!dedup_ref) { + return false; + } + + filter.SplitPredicates(); + vector consumed(filter.expressions.size(), false); + vector covered_dedup_bindings; + covered_dedup_bindings.reserve(dedup_ref->output_bindings.size()); + vector base_replacement_bindings; + vector> base_replacement_expressions; + vector semi_conditions; + ColumnBindingReplacer replacer; + auto &replacement_bindings = replacer.replacement_bindings; + + for (idx_t expr_idx = 0; expr_idx < filter.expressions.size(); expr_idx++) { + auto &expr = *filter.expressions[expr_idx]; + if (!BoundComparisonExpression::IsComparison(expr) || !IsEqualityComparison(expr.GetExpressionType())) { + continue; + } + auto &comparison = expr.Cast(); + auto &lhs = BoundComparisonExpression::Left(comparison); + auto &rhs = BoundComparisonExpression::Right(comparison); + + auto lhs_generated_idx = FindGeneratedOutputBinding(lhs, *dedup_ref); + auto rhs_generated_idx = FindGeneratedOutputBinding(rhs, *dedup_ref); + if (lhs_generated_idx.IsValid() == rhs_generated_idx.IsValid()) { + continue; + } + + auto generated_idx = lhs_generated_idx.IsValid() 
? lhs_generated_idx.GetIndex() : rhs_generated_idx.GetIndex(); + auto &generated_binding = dedup_ref->output_bindings[generated_idx]; + auto &generated_expression = *dedup_ref->output_expressions[generated_idx]; + auto &other_side = lhs_generated_idx.IsValid() ? rhs : lhs; + + ColumnBinding other_binding; + if (!GetBoundColumnRefBinding(other_side, other_binding)) { + return false; + } + if (!AddReplacement(replacement_bindings, generated_binding, other_binding, + generated_expression.GetReturnType())) { + return false; + } + + ColumnBinding base_binding; + if (GetBoundColumnRefBinding(generated_expression, base_binding) && + base_binding.table_index == dedup_ref->cte_ref->table_index) { + if (!AddReplacement(replacement_bindings, base_binding, other_binding)) { + return false; + } + covered_dedup_bindings.emplace_back(base_binding); + base_replacement_bindings.push_back(base_binding); + base_replacement_expressions.push_back(other_side.Copy()); + } + + auto generated_expr = lhs_generated_idx.IsValid() ? lhs.Copy() : rhs.Copy(); + auto other_expr = lhs_generated_idx.IsValid() ? rhs.Copy() : lhs.Copy(); + auto comparison_type = + lhs_generated_idx.IsValid() ? 
FlipComparisonExpression(expr.GetExpressionType()) : expr.GetExpressionType(); + semi_conditions.emplace_back(std::move(other_expr), std::move(generated_expr), comparison_type); + consumed[expr_idx] = true; + } + + if (semi_conditions.empty() || !CoversAllDedupColumns(*dedup_ref, covered_dedup_bindings)) { + return false; + } + for (idx_t expr_idx = 0; expr_idx < filter.expressions.size(); expr_idx++) { + if (!consumed[expr_idx] && ExpressionReferencesGeneratedSide(*filter.expressions[expr_idx], *dedup_ref)) { + return false; + } + } + + vector> generated_output_replacements; + generated_output_replacements.reserve(dedup_ref->output_expressions.size()); + for (auto &expr : dedup_ref->output_expressions) { + auto rewritten_expr = expr->Copy(); + ReplaceExpressionBindings(rewritten_expr, base_replacement_bindings, base_replacement_expressions); + if (ExpressionReferencesGeneratedSide(*rewritten_expr, *dedup_ref)) { + return false; + } + generated_output_replacements.push_back(std::move(rewritten_expr)); + } + + auto semi_join = make_uniq(JoinType::SEMI); + semi_join->conditions = std::move(semi_conditions); + semi_join->children.push_back(std::move(cross_product.children[1 - dedup_idx])); + semi_join->children.push_back(std::move(cross_product.children[dedup_idx])); + semi_join->ResolveOperatorTypes(); + + unique_ptr replacement_op = std::move(semi_join); + for (idx_t expr_idx = 0; expr_idx < filter.expressions.size(); expr_idx++) { + if (consumed[expr_idx]) { + continue; + } + AddFilterToOperator(replacement_op, std::move(filter.expressions[expr_idx])); + } + + filter_op = std::move(replacement_op); + replacer.stop_operator = filter_op.get(); + replacer.VisitOperator(rewrite_root); + ReplaceOperatorBindings(rewrite_root, dedup_ref->output_bindings, generated_output_replacements, filter_op.get()); + return true; +} + bool GeneratedDedupRefEliminator::RemoveJoin(unique_ptr &join) { auto &comparison_join = join->Cast(); if (comparison_join.join_type != JoinType::INNER 
&& comparison_join.join_type != JoinType::SEMI) { @@ -1115,7 +1447,11 @@ bool GeneratedDedupRefEliminator::RemoveFilterCrossProduct(unique_ptr joins; - FindJoinsWithGeneratedDedupRefs(delim_join.children[1], joins); + auto preserve_selected_domain = HasSelection(*delim_join.children[0]); + auto selected_evidence_side = preserve_selected_domain && preserve_evidence_side && + HasEvidenceSide(delim_join.join_type) && + !ContainsSubqueryJoin(*delim_join.children[0]); + FindJoinsWithGeneratedDedupRefs(delim_join.children[1], joins, 0, false, selected_evidence_side); if (joins.empty()) { return dedup_ref_count; } @@ -1125,11 +1461,17 @@ idx_t GeneratedDedupRefEliminator::Remove() { return lhs.depth > rhs.depth; }); - if (!joins.empty() && HasSelection(*delim_join.children[0])) { - joins.erase(joins.begin()); - } - for (auto &join : joins) { + if (preserve_selected_domain && (join.under_aggregate || join.under_evidence_side)) { + // This join is a semijoin reduction for a grouped RHS or an existence-check evidence side. Removing it is + // valid, but can turn a selective correlated subquery into a global aggregate or full evidence scan. 
+ if (join.filter_cross_product) { + PreserveFilterCrossProductAsSemi(join.join.get()); + } else { + PreserveJoinAsSemi(join.join.get()); + } + continue; + } if (join.filter_cross_product) { RemoveFilterCrossProduct(join.join.get()); } else { @@ -1146,6 +1488,7 @@ struct GeneratedDomainRef { vector output_bindings; vector> output_expressions; vector> filters; + bool has_selection = false; }; class GeneratedDomainJoinEliminator { @@ -1158,21 +1501,25 @@ class GeneratedDomainJoinEliminator { private: void CollectCTEs(LogicalOperator &op); optional_ptr FindCTE(TableIndex cte_index) const; - bool TryRewriteOnce(unique_ptr &op); + bool TryRewriteOnce(unique_ptr &op, bool under_aggregate = false, bool under_evidence_side = false, + bool negated_marker_filter_above = false); unique_ptr GetGeneratedDedupRef(LogicalOperator &op, bool collect_filters = false, bool allow_projection = false) const; unique_ptr GetGeneratedDomainDefinition(LogicalOperator &op) const; unique_ptr GetGeneratedDomainRef(LogicalOperator &op, bool collect_filters = false, bool allow_projection = false) const; + bool CTEHasSelection(TableIndex cte_index, vector &seen_ctes) const; + bool OperatorHasSelection(LogicalOperator &op, vector &seen_ctes) const; + bool GeneratedDedupRefHasSelection(const GeneratedDedupRef &dedup_ref) const; optional_idx FindOutputBinding(Expression &expr, const vector &bindings) const; bool ContainsRecursiveCTERef(LogicalOperator &op) const; bool AddReplacement(vector &replacements, ColumnBinding old_binding, ColumnBinding new_binding) const; - bool RemoveGeneratedDedupJoin(unique_ptr &join); - bool RemoveGeneratedDomainJoin(unique_ptr &join); + bool RemoveGeneratedDedupJoin(unique_ptr &join, bool under_aggregate, bool under_evidence_side); + bool RemoveGeneratedDomainJoin(unique_ptr &join, bool under_aggregate, bool under_evidence_side); private: unique_ptr &rewrite_root; @@ -1204,6 +1551,41 @@ optional_ptr GeneratedDomainJoinEliminator::FindCTE(TableIndex cte_i return 
nullptr; } +bool GeneratedDomainJoinEliminator::CTEHasSelection(TableIndex cte_index, vector &seen_ctes) const { + if (std::find(seen_ctes.begin(), seen_ctes.end(), cte_index) != seen_ctes.end()) { + return false; + } + seen_ctes.push_back(cte_index); + auto cte = FindCTE(cte_index); + if (!cte || cte->children.empty()) { + return false; + } + return OperatorHasSelection(*cte->children[0], seen_ctes); +} + +bool GeneratedDomainJoinEliminator::OperatorHasSelection(LogicalOperator &op, vector &seen_ctes) const { + if (HasSelection(op)) { + return true; + } + if (op.type == LogicalOperatorType::LOGICAL_CTE_REF) { + return CTEHasSelection(op.Cast().cte_index, seen_ctes); + } + for (auto &child : op.children) { + if (OperatorHasSelection(*child, seen_ctes)) { + return true; + } + } + return false; +} + +bool GeneratedDomainJoinEliminator::GeneratedDedupRefHasSelection(const GeneratedDedupRef &dedup_ref) const { + if (!dedup_ref.cte_ref) { + return false; + } + vector seen_ctes; + return CTEHasSelection(dedup_ref.cte_ref->cte_index, seen_ctes); +} + unique_ptr GeneratedDomainJoinEliminator::GetGeneratedDedupRef(LogicalOperator &op, bool collect_filters, bool allow_projection) const { @@ -1351,6 +1733,8 @@ unique_ptr GeneratedDomainJoinEliminator::GetGeneratedDomain result->cte_ref = cteref; result->output_bindings = cteref.GetColumnBindings(); + vector seen_ctes; + result->has_selection = CTEHasSelection(cteref.cte_index, seen_ctes); return result; } if (op.type == LogicalOperatorType::LOGICAL_FILTER) { @@ -1448,7 +1832,8 @@ bool GeneratedDomainJoinEliminator::AddReplacement(vector &r return true; } -bool GeneratedDomainJoinEliminator::RemoveGeneratedDedupJoin(unique_ptr &join) { +bool GeneratedDomainJoinEliminator::RemoveGeneratedDedupJoin(unique_ptr &join, bool under_aggregate, + bool under_evidence_side) { auto &comparison_join = join->Cast(); if (comparison_join.join_type != JoinType::INNER && (comparison_join.join_type != JoinType::SEMI || 
!GetGeneratedDedupRef(*join->children[1], false, true))) { @@ -1468,6 +1853,11 @@ bool GeneratedDomainJoinEliminator::RemoveGeneratedDedupJoin(unique_ptr covered_dedup_bindings; vector base_replacement_bindings; @@ -1562,7 +1952,8 @@ bool GeneratedDomainJoinEliminator::RemoveGeneratedDedupJoin(unique_ptr &join) { +bool GeneratedDomainJoinEliminator::RemoveGeneratedDomainJoin(unique_ptr &join, bool under_aggregate, + bool under_evidence_side) { auto &comparison_join = join->Cast(); if (comparison_join.join_type != JoinType::INNER) { return false; @@ -1590,6 +1981,11 @@ bool GeneratedDomainJoinEliminator::RemoveGeneratedDomainJoin(unique_ptrhas_selection || !domain_ref->filters.empty())) { + // Same invariant as above: selected domains below aggregates or existence checks are part of the physical + // reduction. + return false; + } vector source_replacement_bindings; vector> source_replacement_expressions; @@ -1674,16 +2070,40 @@ bool GeneratedDomainJoinEliminator::RemoveGeneratedDomainJoin(unique_ptr &op) { - for (auto &child : op->children) { - if (TryRewriteOnce(child)) { +bool GeneratedDomainJoinEliminator::TryRewriteOnce(unique_ptr &op, bool under_aggregate, + bool under_evidence_side, bool negated_marker_filter_above) { + auto child_under_aggregate = under_aggregate || op->type == LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY; + for (idx_t child_idx = 0; child_idx < op->children.size(); child_idx++) { + auto &child = op->children[child_idx]; + auto child_under_evidence_side = under_evidence_side; + if (negated_marker_filter_above && IsEvidenceSide(*op, child_idx)) { + child_under_evidence_side = true; + } else if (!under_evidence_side && op->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + auto &join = op->Cast(); + if ((join.join_type == JoinType::ANTI || join.join_type == JoinType::RIGHT_ANTI) && + IsEvidenceSide(*op, child_idx)) { + child_under_evidence_side = true; + } + } + + bool child_negated_marker_filter_above = false; + if (op->type == 
LogicalOperatorType::LOGICAL_FILTER && child_idx == 0 && + child->type == LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { + auto &filter = op->Cast(); + auto &join = child->Cast(); + child_negated_marker_filter_above = FilterNegatesDelimJoinMarker(filter, join); + } + + if (TryRewriteOnce(child, child_under_aggregate, child_under_evidence_side, + child_negated_marker_filter_above)) { return true; } } if (op->type != LogicalOperatorType::LOGICAL_COMPARISON_JOIN) { return false; } - return RemoveGeneratedDomainJoin(op) || RemoveGeneratedDedupJoin(op); + return RemoveGeneratedDomainJoin(op, under_aggregate, under_evidence_side) || + RemoveGeneratedDedupJoin(op, under_aggregate, under_evidence_side); } bool GeneratedDomainJoinEliminator::Rewrite() { @@ -1764,25 +2184,21 @@ DelimJoinCTERewriter::DelimJoinCTERewriter(Binder &binder) : binder(binder) { } void DelimJoinCTERewriter::MaterializeDelimJoinAsCTE(unique_ptr &plan, LogicalOperator &rewrite_root, - bool null_rejecting_filter_above) { + bool null_rejecting_filter_above, bool preserve_evidence_side) { auto &join = plan->Cast(); if (join.delim_flipped) { throw InternalException("Flatten dependent joins - flipped delim join CTE rewrite not supported"); } plan->type = LogicalOperatorType::LOGICAL_COMPARISON_JOIN; - if (join.join_type == JoinType::MARK) { - // Match the LOGICAL_DELIM_JOIN filter-pushdown semantics: the mark column can still be required above - // this join, so pushing NOT(mark) into the join must not drop it by rewriting to ANTI. 
- join.convert_mark_to_semi = false; - } auto dedup_cte_index = binder.GenerateTableIndex(); auto dedup_ref_count = RewriteDelimScanReferences(plan->children[1], dedup_cte_index); if (cte_deliminator_enabled) { auto cte_deliminator_timer = QueryProfiler::Get(binder.context).StartTimerInternal(CTE_DELIMINATOR_PROFILER_KEY); - GeneratedDedupRefEliminator eliminator(join, dedup_cte_index, dedup_ref_count, rewrite_root); + GeneratedDedupRefEliminator eliminator(join, dedup_cte_index, dedup_ref_count, rewrite_root, + preserve_evidence_side); dedup_ref_count = eliminator.Remove(); if (SingleJoinRHSIsDeduplicated(join)) { join.join_type = null_rejecting_filter_above ? JoinType::INNER : JoinType::LEFT; @@ -1904,21 +2320,23 @@ void DelimJoinCTERewriter::MaterializeDelimJoinAsCTE(unique_ptr } void DelimJoinCTERewriter::RewriteDelimJoinsToCTEs(unique_ptr &plan, LogicalOperator &rewrite_root, - bool null_rejecting_filter_above) { + bool null_rejecting_filter_above, bool preserve_evidence_side) { for (auto &child : plan->children) { auto old_child_bindings = child->GetColumnBindings(); bool child_null_rejecting_filter_above = false; + bool child_preserve_evidence_side = false; if (plan->type == LogicalOperatorType::LOGICAL_FILTER && child->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { auto &filter = plan->Cast(); auto &delim_join = child->Cast(); child_null_rejecting_filter_above = FilterNullRejectsDelimJoinRHS(filter, delim_join); + child_preserve_evidence_side = FilterNegatesDelimJoinMarker(filter, delim_join); } - RewriteDelimJoinsToCTEs(child, rewrite_root, child_null_rejecting_filter_above); + RewriteDelimJoinsToCTEs(child, rewrite_root, child_null_rejecting_filter_above, child_preserve_evidence_side); RewriteChangedChildBindings(*plan, *child, old_child_bindings); } if (plan->type == LogicalOperatorType::LOGICAL_DELIM_JOIN) { - MaterializeDelimJoinAsCTE(plan, rewrite_root, null_rejecting_filter_above); + MaterializeDelimJoinAsCTE(plan, rewrite_root, 
null_rejecting_filter_above, preserve_evidence_side); } } diff --git a/src/duckdb/src/planner/table_filter_set.cpp b/src/duckdb/src/planner/table_filter_set.cpp index bdb3a32d2..e14bb4337 100644 --- a/src/duckdb/src/planner/table_filter_set.cpp +++ b/src/duckdb/src/planner/table_filter_set.cpp @@ -246,8 +246,7 @@ static unique_ptr SerializeInternalFunctionToLegacyFilter(const Bou auto &data = func_expr.BindInfo()->Cast(); return make_uniq(data.filter_data); } - if (func_name == BloomFilterScalarFun::NAME || func_name == PerfectHashJoinScalarFun::NAME || - func_name == PrefixRangeScalarFun::NAME) { + if (func_name == BloomFilterScalarFun::NAME || func_name == PrefixRangeScalarFun::NAME) { return make_uniq(); } throw SerializationException("Unsupported internal tablefilter function \"%s\" during serialization", func_name); diff --git a/src/duckdb/src/storage/buffer/block_handle.cpp b/src/duckdb/src/storage/buffer/block_handle.cpp index c05ab53ae..78c987a55 100644 --- a/src/duckdb/src/storage/buffer/block_handle.cpp +++ b/src/duckdb/src/storage/buffer/block_handle.cpp @@ -13,8 +13,8 @@ namespace duckdb { BlockMemory::BlockMemory(BufferManager &buffer_manager, block_id_t block_id_p, MemoryTag tag_p, idx_t block_alloc_size_p) : buffer_manager(buffer_manager), block_id(block_id_p), state(BlockState::BLOCK_UNLOADED), readers(0), tag(tag_p), - buffer_type(FileBufferType::BLOCK), buffer(nullptr), eviction_seq_num(0), lru_timestamp_msec(), - destroy_buffer_upon(DestroyBufferUpon::BLOCK), memory_usage(block_alloc_size_p), + buffer_type(FileBufferType::BLOCK), buffer(nullptr), eviction_seq_num(0), has_queue_entry(false), + lru_timestamp_msec(), destroy_buffer_upon(DestroyBufferUpon::BLOCK), memory_usage(block_alloc_size_p), memory_charge(tag, buffer_manager.GetBufferPool()), unswizzled(nullptr), eviction_queue_idx(DConstants::INVALID_INDEX) { } @@ -23,8 +23,8 @@ BlockMemory::BlockMemory(BufferManager &buffer_manager, block_id_t block_id_p, M unique_ptr buffer_p, 
DestroyBufferUpon destroy_buffer_upon_p, idx_t size_p, BufferPoolReservation &&reservation) : buffer_manager(buffer_manager), block_id(block_id_p), state(BlockState::BLOCK_LOADED), readers(0), tag(tag_p), - buffer_type(buffer_p->GetBufferType()), buffer(std::move(buffer_p)), eviction_seq_num(0), lru_timestamp_msec(), - destroy_buffer_upon(destroy_buffer_upon_p), memory_usage(size_p), + buffer_type(buffer_p->GetBufferType()), buffer(std::move(buffer_p)), eviction_seq_num(0), has_queue_entry(false), + lru_timestamp_msec(), destroy_buffer_upon(destroy_buffer_upon_p), memory_usage(size_p), memory_charge(tag, buffer_manager.GetBufferPool()), unswizzled(nullptr), eviction_queue_idx(DConstants::INVALID_INDEX) { memory_charge = std::move(reservation); // Moved to constructor body due to tidy check. @@ -34,9 +34,12 @@ BlockMemory::~BlockMemory() { // NOLINT: allow internal exceptions // The block memory is being destroyed, meaning that any unswizzled pointers are now binary junk. SetSwizzling(nullptr); D_ASSERT(!GetBuffer() || GetBuffer()->GetBufferType() == GetBufferType()); - if (GetEvictionSequenceNumber() > 0 && GetBufferType() != FileBufferType::TINY_BUFFER) { - // eviction_seq_num > 0 means there is a live queue entry for this block (it's reset - // to 0 on unload/evict). That entry is now dead — account for it. + if (HasLiveQueueEntry() && GetBufferType() != FileBufferType::TINY_BUFFER) { + // The block still has a live entry in the eviction queue. That entry is now dead; + // account for it. (No lock needed: the destructor has exclusive ownership.) + // Note: the weak pointer in the queue entry can become unlockable before this + // destructor body runs, so a queue consumer can briefly decrement before this increment. + // This increment repairs the final count for that expired live entry. 
GetBufferManager().GetBufferPool().IncrementDeadNodes(*this); } @@ -136,7 +139,6 @@ unique_ptr BlockMemory::UnloadAndTakeBlock(BlockLock &l) { // Thus, we write to it to a temporary file. buffer_manager.WriteTemporaryBuffer(GetMemoryTag(), BlockId(), *GetBuffer()); } - eviction_seq_num = 0; memory_charge.Resize(0); SetState(BlockState::BLOCK_UNLOADED); return std::move(GetBuffer()); diff --git a/src/duckdb/src/storage/buffer/block_manager.cpp b/src/duckdb/src/storage/buffer/block_manager.cpp index 0ac33b53e..44abcf0a5 100644 --- a/src/duckdb/src/storage/buffer/block_manager.cpp +++ b/src/duckdb/src/storage/buffer/block_manager.cpp @@ -105,7 +105,12 @@ shared_ptr BlockManager::ConvertToPersistent(QueryContext context, old_block.reset(); // potentially purge the queue - auto purge_queue = buffer_manager.GetBufferPool().AddToEvictionQueue(new_block); + bool purge_queue; + { + // AddToEvictionQueue requires the block lock. new_block was just created here, so this is uncontended. + auto new_lock = new_block->GetMemory().GetLock(); + purge_queue = buffer_manager.GetBufferPool().AddToEvictionQueue(new_lock, new_block); + } if (purge_queue) { buffer_manager.GetBufferPool().PurgeQueue(*new_block); } diff --git a/src/duckdb/src/storage/buffer/buffer_pool.cpp b/src/duckdb/src/storage/buffer/buffer_pool.cpp index fb4ad93c8..a563ef21f 100644 --- a/src/duckdb/src/storage/buffer/buffer_pool.cpp +++ b/src/duckdb/src/storage/buffer/buffer_pool.cpp @@ -44,28 +44,6 @@ BufferEvictionNode::BufferEvictionNode(weak_ptr block_memory_p, idx D_ASSERT(!memory_p.expired()); } -bool BufferEvictionNode::CanUnload(BlockMemory &memory) { - if (handle_sequence_number != memory.GetEvictionSequenceNumber()) { - // handle was used in between - return false; - } - return memory.CanUnload(); -} - -shared_ptr BufferEvictionNode::TryGetBlockMemory() { - auto shared_memory_p = memory_p.lock(); - if (!shared_memory_p) { - // The block memory has been destroyed. 
- return nullptr; - } - if (!CanUnload(*shared_memory_p)) { - // The memory handle was used in between. - return nullptr; - } - // The node is the latest node in the queue with this memory. - return shared_memory_p; -} - bool BufferEvictionNode::IsDeadNode(optional_idx debug_sleep_micros) { auto shared_memory_p = memory_p.lock(); if (debug_sleep_micros.IsValid()) { @@ -290,24 +268,29 @@ BufferPool::BufferPool(BlockAllocator &block_allocator, idx_t maximum_memory, bo BufferPool::~BufferPool() { } -bool BufferPool::AddToEvictionQueue(shared_ptr &handle) { +bool BufferPool::AddToEvictionQueue(BlockLock &lock, shared_ptr &handle) { auto &memory = handle->GetMemory(); + // Verify the caller passed this block's lock before we mutate any of its state. + // The block lock is held throughout: Unpin holds it; ConvertToPersistent acquires the + // (uncontended) lock of the freshly created block before calling. + memory.VerifyMutex(lock); auto &queue = GetEvictionQueueForBlockMemory(memory); - // The block handle is locked during this operation (Unpin), - // or the block handle is still a local variable (ConvertToPersistent) D_ASSERT(memory.GetReaders() == 0); + if (memory.HasLiveQueueEntry(lock)) { + // Count the previous live entry before bumping the sequence number. PurgeIteration + // reads sequence numbers without the block lock; bumping first could let it see the + // previous entry as stale and decrement dead_nodes before this matching increment. 
+ queue.IncrementDeadNodes(); + } + auto ts = memory.NextEvictionSequenceNumber(); if (track_eviction_timestamps) { memory.SetLRUTimestamp(std::chrono::time_point_cast(std::chrono::steady_clock::now()) .time_since_epoch() .count()); } - - if (ts != 1) { - // we add a newer version, i.e., we kill exactly one previous version - queue.IncrementDeadNodes(); - } + memory.SetHasLiveQueueEntry(lock, true); // Get the eviction queue for the block and add it BufferEvictionNode node(handle->GetMemoryWeak(), ts); @@ -492,7 +475,7 @@ void EvictionQueue::IterateUnloadableBlocks(FN fn) { } // get a reference to the underlying block pointer - auto handle = node.TryGetBlockMemory(); + auto handle = node.memory_p.lock(); if (debug_sleep_micros > 0) { // Debug race conditions regarding the ownership of the BlockMemory. // Note that for this to trigger we need at least one purge iteration with the setting active. @@ -505,11 +488,19 @@ void EvictionQueue::IterateUnloadableBlocks(FN fn) { // we might be able to free this block: grab the mutex and check if we can free it auto lock = handle->GetLock(); - if (!node.CanUnload(*handle)) { - // something changed in the mean-time, bail out + if (node.handle_sequence_number != handle->GetEvictionSequenceNumber()) { + // A newer entry superseded this node: it was counted as dead when that entry was added. DecrementDeadNodes(); continue; } + // This node is the block's live queue entry, and we just dequeued it: the block no longer + // has an entry in the queue. Live entries are never counted as dead, so no decrement. + handle->SetHasLiveQueueEntry(lock, false); + if (!handle->CanUnload()) { + // The block cannot be unloaded right now (e.g. it is pinned). It gets a new queue + // entry when it is unpinned again. 
+ continue; + } if (!fn(node, handle, lock)) { break; diff --git a/src/duckdb/src/storage/compression/numeric_constant.cpp b/src/duckdb/src/storage/compression/numeric_constant.cpp index 854059058..0ff573aea 100644 --- a/src/duckdb/src/storage/compression/numeric_constant.cpp +++ b/src/duckdb/src/storage/compression/numeric_constant.cpp @@ -156,17 +156,6 @@ static bool TryExpressionFiltersNullValues(const Expression &expression, bool &f } return TryExpressionFiltersNullValues(*data.child_filter_expr, filters_nulls, filters_valid_values); } - if (function_name == PerfectHashJoinScalarFun::NAME) { - if (!func_expr->BindInfo()) { - return true; - } - auto &data = func_expr->BindInfo()->Cast(); - if (!data.executor) { - return true; - } - filters_nulls = true; - return true; - } if (function_name == PrefixRangeScalarFun::NAME) { if (!func_expr->BindInfo()) { return true; diff --git a/src/duckdb/src/storage/serialization/serialize_nodes.cpp b/src/duckdb/src/storage/serialization/serialize_nodes.cpp index 85d4197f4..1045f1e56 100644 --- a/src/duckdb/src/storage/serialization/serialize_nodes.cpp +++ b/src/duckdb/src/storage/serialization/serialize_nodes.cpp @@ -531,6 +531,7 @@ void SampleOptions::Serialize(Serializer &serializer) const { serializer.WriteProperty(102, "method", method); serializer.WritePropertyWithDefault(103, "seed", GetSeed()); serializer.WritePropertyWithDefault(104, "repeatable", repeatable); + serializer.WritePropertyWithDefault(105, "sample_rate", sample_rate, -1.0); } unique_ptr SampleOptions::Deserialize(Deserializer &deserializer) { @@ -543,6 +544,7 @@ unique_ptr SampleOptions::Deserialize(Deserializer &deserializer) result->is_percentage = is_percentage; result->method = method; deserializer.ReadPropertyWithDefault(104, "repeatable", result->repeatable); + deserializer.ReadPropertyWithExplicitDefault(105, "sample_rate", result->sample_rate, -1.0); return result; } diff --git a/src/duckdb/src/storage/standard_buffer_manager.cpp 
b/src/duckdb/src/storage/standard_buffer_manager.cpp index 66245acd6..999393680 100644 --- a/src/duckdb/src/storage/standard_buffer_manager.cpp +++ b/src/duckdb/src/storage/standard_buffer_manager.cpp @@ -376,7 +376,8 @@ void StandardBufferManager::PurgeQueue(const BlockHandle &handle) { } void StandardBufferManager::AddToEvictionQueue(shared_ptr &handle) { - buffer_pool.AddToEvictionQueue(handle); + auto lock = handle->GetMemory().GetLock(); + buffer_pool.AddToEvictionQueue(lock, handle); } void StandardBufferManager::VerifyZeroReaders(BlockLock &lock, shared_ptr &handle) { @@ -412,7 +413,7 @@ void StandardBufferManager::Unpin(shared_ptr &handle) { if (new_readers == 0) { VerifyZeroReaders(lock, handle); if (block_memory.MustAddToEvictionQueue()) { - purge = buffer_pool.AddToEvictionQueue(handle); + purge = buffer_pool.AddToEvictionQueue(lock, handle); } else { block_memory.Unload(lock); } diff --git a/src/duckdb/src/storage/table/row_group.cpp b/src/duckdb/src/storage/table/row_group.cpp index 46b0524c9..c7b6b9dc3 100644 --- a/src/duckdb/src/storage/table/row_group.cpp +++ b/src/duckdb/src/storage/table/row_group.cpp @@ -3,6 +3,7 @@ #include "duckdb/transaction/commit_state.hpp" #include "duckdb/common/exception.hpp" +#include "duckdb/common/numeric_utils.hpp" #include "duckdb/common/serializer/binary_serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/common/serializer/serializer.hpp" @@ -654,6 +655,45 @@ void RowGroup::NextVector(CollectionScanState &state) { } } +static idx_t SystemRowsSelection(const ScanSamplingInfo &sampling_info, idx_t start_row, idx_t count, + SelectionVector &sel) { + auto rate = sampling_info.sample_rate; + if (rate >= 1) { + return count; + } + idx_t result_count = 0; + for (idx_t i = 0; i < count; i++) { + auto row_idx = start_row + i; + auto before = std::floor(LossyNumericCast(row_idx) * rate + sampling_info.sample_phase); + auto after = std::floor(LossyNumericCast(row_idx + 1) * rate + 
sampling_info.sample_phase); + if (after > before) { + sel.set_index(result_count++, i); + } + } + return result_count; +} + +static idx_t IntersectSelections(const SelectionVector &left, idx_t left_count, const SelectionVector &right, + idx_t right_count, SelectionVector &result) { + idx_t left_idx = 0; + idx_t right_idx = 0; + idx_t result_count = 0; + while (left_idx < left_count && right_idx < right_count) { + auto left_entry = left.get_index(left_idx); + auto right_entry = right.get_index(right_idx); + if (left_entry == right_entry) { + result.set_index(result_count++, left_entry); + left_idx++; + right_idx++; + } else if (left_entry < right_entry) { + left_idx++; + } else { + right_idx++; + } + } + return result_count; +} + FilterPropagateResult RowGroup::CheckRowIdFilter(const TableFilter &filter, idx_t beg_row, idx_t end_row) { // RowId columns dont have a zonemap, but we can trivially create stats to check the filter against. BaseStatistics dummy_stats = NumericStats::CreateEmpty(LogicalType::ROW_TYPE); @@ -678,7 +718,7 @@ bool RowGroup::CheckZonemap(optional_ptr context, ScanFilterInfo if (prune_result == FilterPropagateResult::FILTER_ALWAYS_FALSE) { return false; } - if (ExpressionFilter::IsRootOptionalFilter(filter)) { + if (ExpressionFilter::IsRootNonSelectivityOptionalFilter(filter)) { // these are only for row group checking, set as always true so we don't check it filters.SetFilterAlwaysTrue(i); } else if (prune_result == FilterPropagateResult::FILTER_ALWAYS_TRUE) { @@ -758,13 +798,37 @@ void RowGroup::Scan(ScanOptions options, CollectionScanState &state, DataChunk & return; } idx_t current_row = state.vector_index * STANDARD_VECTOR_SIZE; - auto max_count = MinValue(STANDARD_VECTOR_SIZE, state.max_row_group_row - current_row); + idx_t max_count = MinValue(STANDARD_VECTOR_SIZE, state.max_row_group_row - current_row); + bool has_sample_selection = false; + idx_t sample_count = max_count; + SelectionVector sample_sel(STANDARD_VECTOR_SIZE); // check the 
sampling info if we have to sample this chunk - if (state.GetSamplingInfo().do_system_sample && - state.random.NextRandom() > state.GetSamplingInfo().sample_rate) { - NextVector(state); - continue; + if (state.GetSamplingInfo().do_system_sample) { + auto &sampling_info = state.GetSamplingInfo(); + if (!sampling_info.is_percentage) { + double rate = sampling_info.sample_rate; + if (rate <= 0) { + NextVector(state); + continue; + } + if (rate < 1) { + auto row_group_start = state.row_group->GetRowStart(); + sample_count = + SystemRowsSelection(sampling_info, row_group_start + current_row, max_count, sample_sel); + if (sample_count == 0) { + NextVector(state); + continue; + } + has_sample_selection = true; + } + } else { + // Percentage-based system sampling: original behavior + if (state.random.NextRandom() > sampling_info.sample_rate) { + NextVector(state); + continue; + } + } } //! first check the zonemap if we have to scan this partition @@ -803,12 +867,30 @@ void RowGroup::Scan(ScanOptions options, CollectionScanState &state, DataChunk & // pass max_count explicitly so we never read past the row count we captured at scan // init time (concurrent inserts can grow the column past max_count) col_data.Scan(transaction, state.vector_index, state.column_scans[i], result.data[i], max_count); + if (has_sample_selection) { + result.data[i].Slice(sample_sel, sample_count); + } + } + if (has_sample_selection) { + count = sample_count; } } else { // partial scan: we have deletions or table filters idx_t approved_tuple_count = count; SelectionVector sel; - if (count != max_count) { + SelectionVector intersect_sel(STANDARD_VECTOR_SIZE); + if (has_sample_selection && count != max_count) { + approved_tuple_count = + IntersectSelections(state.valid_sel, count, sample_sel, sample_count, intersect_sel); + if (approved_tuple_count == 0) { + NextVector(state); + continue; + } + sel.Initialize(intersect_sel); + } else if (has_sample_selection) { + approved_tuple_count = sample_count; 
+ sel.Initialize(sample_sel); + } else if (count != max_count) { sel.Initialize(state.valid_sel); } else { sel.Initialize(nullptr); diff --git a/src/duckdb/src/storage/table/scan_state.cpp b/src/duckdb/src/storage/table/scan_state.cpp index 64bbcda2c..76c435600 100644 --- a/src/duckdb/src/storage/table/scan_state.cpp +++ b/src/duckdb/src/storage/table/scan_state.cpp @@ -7,6 +7,7 @@ #include "duckdb/storage/table/row_group.hpp" #include "duckdb/storage/table/row_group_collection.hpp" #include "duckdb/storage/table/row_group_segment_tree.hpp" +#include "duckdb/common/numeric_utils.hpp" namespace duckdb { @@ -17,15 +18,38 @@ TableScanState::~TableScanState() { } void TableScanState::Initialize(vector column_ids_p, optional_ptr context, - optional_ptr table_filters, - optional_ptr table_sampling) { + optional_ptr table_filters, optional_ptr table_sampling, + idx_t estimated_table_row_count) { this->column_ids = std::move(column_ids_p); if (table_filters) { filters.Initialize(*context, *table_filters, column_ids); } if (table_sampling) { sampling_info.do_system_sample = table_sampling->method == SampleMethod::SYSTEM_SAMPLE; - sampling_info.sample_rate = table_sampling->sample_size.GetValue() / 100.0; + if (table_sampling->is_percentage) { + // Percentage-based system sampling + sampling_info.is_percentage = true; + sampling_info.sample_rate = table_sampling->sample_size.GetValue() / 100.0; + } else { + // Row-count based system sampling: convert target row count to approximate rate. + // Prefer the pre-calculated sample_rate from the optimizer if available, + // otherwise derive from estimated_table_row_count. 
+ sampling_info.is_percentage = false; + sampling_info.target_sample_rows = NumericCast(table_sampling->sample_size.GetValue()); + if (table_sampling->sample_rate > 0) { + sampling_info.sample_rate = table_sampling->sample_rate; + } else if (estimated_table_row_count > 0) { + sampling_info.sample_rate = static_cast(sampling_info.target_sample_rows) / + static_cast(estimated_table_row_count); + } else { + // No estimate available, use a conservative rate + sampling_info.sample_rate = 1.0; + } + sampling_info.sample_rate = MinValue(1.0, MaxValue(0.0, sampling_info.sample_rate)); + RandomEngine random(table_sampling->seed.IsValid() ? static_cast(table_sampling->seed.GetIndex()) + : -1); + sampling_info.sample_phase = random.NextRandom(); + } if (table_sampling->seed.IsValid()) { table_state.random.SetSeed(table_sampling->seed.GetIndex()); } diff --git a/src/duckdb/src/storage/temporary_memory_manager.cpp b/src/duckdb/src/storage/temporary_memory_manager.cpp index 88072834b..0b7a3841b 100644 --- a/src/duckdb/src/storage/temporary_memory_manager.cpp +++ b/src/duckdb/src/storage/temporary_memory_manager.cpp @@ -76,6 +76,14 @@ idx_t TemporaryMemoryManager::DefaultMinimumReservation() const { memory_limit / MINIMUM_RESERVATION_MEMORY_LIMIT_DIVISOR); } +idx_t TemporaryMemoryManager::CapReservation(idx_t reservation) const { + return MinValue(reservation, memory_limit); +} + +idx_t TemporaryMemoryManager::MinimumReservation(const TemporaryMemoryState &temporary_memory_state) const { + return CapReservation(temporary_memory_state.GetMinimumReservation()); +} + void TemporaryMemoryManager::Unregister(TemporaryMemoryState &temporary_memory_state) { const annotated_lock_guard guard(lock); @@ -107,8 +115,8 @@ unique_ptr TemporaryMemoryManager::Register(ClientContext UpdateConfiguration(context); auto result = unique_ptr(new TemporaryMemoryState(*this, DefaultMinimumReservation())); - SetRemainingSize(*result, result->GetMinimumReservation()); - SetReservation(*result, 
result->GetMinimumReservation()); + SetRemainingSize(*result, MinimumReservation(*result)); + SetReservation(*result, MinimumReservation(*result)); active_states.insert(*result); Verify(); @@ -120,7 +128,7 @@ void TemporaryMemoryManager::UpdateState(ClientContext &context, TemporaryMemory // The lower bound for the reservation of this state is either the minimum reservation or the remaining size const auto lower_bound = - MinValue(temporary_memory_state.GetMinimumReservation(), temporary_memory_state.GetRemainingSize()); + MinValue(MinimumReservation(temporary_memory_state), temporary_memory_state.GetRemainingSize()); if (temporary_memory_state.GetRemainingSize() == 0) { // Sometimes set to 0 to denote end of state (before actually deleting the state) @@ -166,20 +174,21 @@ void TemporaryMemoryManager::SetRemainingSize(TemporaryMemoryState &temporary_me } void TemporaryMemoryManager::SetReservation(TemporaryMemoryState &temporary_memory_state, idx_t new_reservation) { + new_reservation = CapReservation(new_reservation); D_ASSERT(this->reservation >= temporary_memory_state.GetReservation()); this->reservation -= temporary_memory_state.GetReservation(); temporary_memory_state.reservation = new_reservation; this->reservation += temporary_memory_state.GetReservation(); } -//! 
Compute initial reservation for use in ComputeReservation -static idx_t ComputeInitialReservation(const TemporaryMemoryState &temporary_memory_state) { +idx_t TemporaryMemoryManager::ComputeInitialReservation(const TemporaryMemoryState &temporary_memory_state) const { // Maximum of minimum reservation and the current reservation - auto result = MaxValue(temporary_memory_state.GetMinimumReservation(), temporary_memory_state.GetReservation()); + auto result = + MaxValue(MinimumReservation(temporary_memory_state), CapReservation(temporary_memory_state.GetReservation())); // Bounded by the remaining size result = MinValue(result, temporary_memory_state.GetRemainingSize()); // At least 1 - return MaxValue(result, 1); + return MaxValue(result, MinValue(memory_limit, 1)); } static void ComputeDerivatives(const vector> &states, const vector &res, diff --git a/src/duckdb/ub_src_common_serializer.cpp b/src/duckdb/ub_src_common_serializer.cpp index c640128b7..99b04f84b 100644 --- a/src/duckdb/ub_src_common_serializer.cpp +++ b/src/duckdb/ub_src_common_serializer.cpp @@ -1,3 +1,7 @@ +#include "src/common/serializer/async_file_writer.cpp" + +#include "src/common/serializer/async_write_queue.cpp" + #include "src/common/serializer/binary_deserializer.cpp" #include "src/common/serializer/binary_serializer.cpp" diff --git a/src/duckdb/ub_src_optimizer.cpp b/src/duckdb/ub_src_optimizer.cpp index 5f936aea0..acd0f1400 100644 --- a/src/duckdb/ub_src_optimizer.cpp +++ b/src/duckdb/ub_src_optimizer.cpp @@ -32,6 +32,8 @@ #include "src/optimizer/filter_pushdown.cpp" +#include "src/optimizer/grouping_sets_optimizer.cpp" + #include "src/optimizer/in_clause_rewriter.cpp" #include "src/optimizer/join_elimination.cpp" diff --git a/src/duckdb/ub_src_planner_filter.cpp b/src/duckdb/ub_src_planner_filter.cpp index 6a6f4169d..2e6f65ddc 100644 --- a/src/duckdb/ub_src_planner_filter.cpp +++ b/src/duckdb/ub_src_planner_filter.cpp @@ -14,8 +14,6 @@ #include "src/planner/filter/optional_filter.cpp" 
-#include "src/planner/filter/perfect_hash_join_filter.cpp" - #include "src/planner/filter/prefix_range_filter.cpp" #include "src/planner/filter/selectivity_optional_filter.cpp" @@ -30,8 +28,6 @@ #include "src/planner/filter/table_filter_optional_function.cpp" -#include "src/planner/filter/table_filter_perfect_hash_join_function.cpp" - #include "src/planner/filter/table_filter_prefix_range_function.cpp" #include "src/planner/filter/table_filter_selectivity_optional_function.cpp"