From 73ed85af60060ea9118b13d2b498f5b9b4e80e50 Mon Sep 17 00:00:00 2001 From: DuckDB Labs GitHub Bot Date: Mon, 15 Jun 2026 14:27:29 +0000 Subject: [PATCH] Update vendored DuckDB sources to 72e5a0f30c --- .../aggregate/distributive/bitagg.cpp | 15 +- .../aggregate/distributive/bitstring_agg.cpp | 5 +- .../aggregate/distributive/string_agg.cpp | 17 ++ .../aggregate/holistic/approx_top_k.cpp | 2 +- .../holistic/approximate_quantile.cpp | 32 +-- .../aggregate/holistic/mode.cpp | 10 +- .../aggregate/holistic/quantile.cpp | 2 + .../aggregate/holistic/reservoir_quantile.cpp | 42 +-- .../aggregate/nested/binned_histogram.cpp | 2 +- .../aggregate/nested/histogram.cpp | 3 +- .../core_functions/aggregate/nested/list.cpp | 2 +- .../aggregate/quantile_state.hpp | 36 +-- .../scalar/list/list_aggregates.cpp | 2 +- .../parquet/include/parquet_reader.hpp | 2 +- .../extension/parquet/parquet_reader.cpp | 16 +- ...nerated_extension_loader_package_build.cpp | 1 + .../catalog_entry/table_catalog_entry.cpp | 17 ++ .../common/row_operations/row_aggregate.cpp | 16 +- src/duckdb/src/common/types/time.cpp | 7 +- .../common/types/variant/variant_value.cpp | 9 + .../aggregate/physical_streaming_window.cpp | 3 +- .../physical_ungrouped_aggregate.cpp | 2 +- .../operator/projection/physical_pivot.cpp | 2 +- .../execution/physical_plan/plan_insert.cpp | 5 +- .../physical_plan/plan_merge_into.cpp | 3 + .../execution/radix_partitioned_hashtable.cpp | 2 +- .../function/aggregate/distributive/count.cpp | 3 +- .../aggregate/distributive/minmax.cpp | 3 +- .../aggregate/sorted_aggregate_function.cpp | 167 +++++++++--- .../src/function/aggregate_function.cpp | 36 ++- src/duckdb/src/function/function_list.cpp | 3 +- .../scalar/geometry/geometry_functions.cpp | 203 ++++++++++++++ .../scalar/system/aggregate_export.cpp | 252 ++++++++++++++---- .../src/function/table/arrow_conversion.cpp | 6 +- .../function/table/version/pragma_version.cpp | 6 +- .../window/window_aggregate_states.cpp | 2 +- .../window/window_naive_aggregator.cpp | 2 +- .../function/window/window_segment_tree.cpp | 2 +- .../catalog_entry/table_catalog_entry.hpp | 6 + .../include/duckdb/common/index_vector.hpp | 8 + .../common/row_operations/row_operations.hpp | 3 + .../duckdb/common/types/variant_value.hpp | 3 + .../vector_operations/aggregate_executor.hpp | 14 +- .../function/aggregate/minmax_n_helpers.hpp | 2 +- .../duckdb/function/aggregate_function.hpp | 61 ++++- .../duckdb/function/aggregate_state.hpp | 41 ++- .../function/aggregate_state_layout.hpp | 4 + .../function/scalar/geometry_functions.hpp | 12 +- .../function/scalar/system_functions.hpp | 4 +- .../src/include/duckdb/main/database.hpp | 5 + .../src/include/duckdb/planner/binder.hpp | 8 + .../planner/operator/logical_insert.hpp | 2 +- .../planner/operator/logical_merge_into.hpp | 2 +- .../storage/statistics/geometry_stats.hpp | 57 ++++ src/duckdb/src/include/duckdb_extension.h | 11 + .../src/main/capi/aggregate_function-c.cpp | 2 +- .../src/main/extension/extension_load.cpp | 37 +++ .../optimizer/partial_aggregate_pushdown.cpp | 4 + .../peg/transformer/transform_expression.cpp | 6 +- .../expression/bind_operator_expression.cpp | 11 +- .../planner/binder/query_node/plan_setop.cpp | 46 +++- .../planner/binder/statement/bind_copy.cpp | 31 +-- .../planner/binder/statement/bind_insert.cpp | 119 +++++++-- .../binder/statement/bind_merge_into.cpp | 49 ++-- .../planner/binder/statement/bind_update.cpp | 20 +- .../tableref/bind_expressionlistref.cpp | 37 ++- .../serialize_logical_operator.cpp | 16 +- 67 files changed, 1241 insertions(+), 320 deletions(-) diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp index d84e7483b..6834734b4 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp @@ -227,9 +227,8 @@ AggregateFunctionSet BitAndFun::GetFunctions() { for (auto &type : LogicalType::Integral()) { bit_and.AddFunction(GetBitfieldUnaryAggregate(type)); } - bit_and.AddFunction( - AggregateFunction::UnaryAggregateDestructor( - LogicalType::BIT, LogicalType::BIT)); + bit_and.AddFunction(AggregateFunction::UnaryAggregate( + LogicalType::BIT, LogicalType::BIT)); return bit_and; } @@ -238,9 +237,8 @@ AggregateFunctionSet BitOrFun::GetFunctions() { for (auto &type : LogicalType::Integral()) { bit_or.AddFunction(GetBitfieldUnaryAggregate(type)); } - bit_or.AddFunction( - AggregateFunction::UnaryAggregateDestructor( - LogicalType::BIT, LogicalType::BIT)); + bit_or.AddFunction(AggregateFunction::UnaryAggregate( + LogicalType::BIT, LogicalType::BIT)); return bit_or; } @@ -249,9 +247,8 @@ AggregateFunctionSet BitXorFun::GetFunctions() { for (auto &type : LogicalType::Integral()) { bit_xor.AddFunction(GetBitfieldUnaryAggregate(type)); } - bit_xor.AddFunction( - AggregateFunction::UnaryAggregateDestructor( - LogicalType::BIT, LogicalType::BIT)); + bit_xor.AddFunction(AggregateFunction::UnaryAggregate( + LogicalType::BIT, LogicalType::BIT)); return bit_xor; } diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp index 2c2d4837e..84ec14460 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp @@ -257,9 +257,8 @@ unique_ptr BindBitstringAgg(BindAggregateFunctionInput &input) { template void BindBitString(AggregateFunctionSet &bitstring_agg, const LogicalTypeId &type) { - auto function = - AggregateFunction::UnaryAggregateDestructor, TYPE, string_t, BitStringAggOperation>( - type, LogicalType::BIT); + auto function = AggregateFunction::UnaryAggregate, TYPE, string_t, BitStringAggOperation>( + type, LogicalType::BIT); function.SetBindCallback(BindBitstringAgg); // create new a 'BitstringAggBindData' function.SetSerializeCallback(BitstringAggBindData::Serialize); function.SetDeserializeCallback(BitstringAggBindData::Deserialize); diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp index 84663c72c..bbd1a1f62 100644 --- a/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp +++ b/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp @@ -13,6 +13,8 @@ namespace duckdb { namespace { struct StringAggState { + using STATE_TYPE = OptionalStateType>; + string_t value; bool is_set; uint32_t alloc_size; @@ -158,6 +160,20 @@ unique_ptr StringAggDeserialize(Deserializer &deserializer, BoundA return make_uniq(std::move(sep)); } +AggregateStateLayout StringAggStateType(AggregateLayoutInput &input) { + auto &function = input.function; + using ST = StringAggState::STATE_TYPE; + AggregateStateLayout layout; + layout.type = AggregateFunction::BuildStateLogical(function); + layout.total_state_size = AlignValue(sizeof(StringAggState)); + layout.field = BuildStateField(); + if (function.GetOriginalArguments().size() == 2) { + // record the value of the separator if explicitly provided + layout.constant_parameters.emplace(1, Value(input.bind_data->Cast().sep)); + } + return layout; +} + } // namespace AggregateFunctionSet StringAggFun::GetFunctions() { @@ -172,6 +188,7 @@ AggregateFunctionSet StringAggFun::GetFunctions() { FunctionNullHandling::DEFAULT_NULL_HANDLING, AggregateFunction::NoClusterUpdate(), StringAggBind); string_agg_param.SetSerializeCallback(StringAggSerialize); string_agg_param.SetDeserializeCallback(StringAggDeserialize); + string_agg_param.SetStructStateExport(StringAggStateType); string_agg.AddFunction(string_agg_param); string_agg_param.GetSignature().AddParameter(LogicalType::VARCHAR); string_agg.AddFunction(string_agg_param); diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp index ba1030fe3..985e8392c 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp @@ -335,7 +335,7 @@ void ApproxTopKUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t inp } template -void ApproxTopKFinalize(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, idx_t offset) { +void ApproxTopKFinalize(Vector &state_vector, AggregateFinalizeInputData &, Vector &result, idx_t count, idx_t offset) { auto states = state_vector.Values(); auto old_len = ListVector::GetListSize(result); diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp index 90aec94a9..50aac35aa 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp @@ -169,31 +169,31 @@ struct ApproxQuantileScalarOperation : public ApproxQuantileOperation { AggregateFunction GetApproximateQuantileAggregateFunction(const LogicalType &type) { // Not binary comparable if (type == LogicalType::TIME_TZ) { - return AggregateFunction::UnaryAggregateDestructor(type, type); + return AggregateFunction::UnaryAggregate(type, type); } switch (type.InternalType()) { case PhysicalType::INT8: - return AggregateFunction::UnaryAggregateDestructor(type, type); + return AggregateFunction::UnaryAggregate( + type, type); case PhysicalType::INT16: - return AggregateFunction::UnaryAggregateDestructor(type, type); + return AggregateFunction::UnaryAggregate( + type, type); case PhysicalType::INT32: - return AggregateFunction::UnaryAggregateDestructor(type, type); + return AggregateFunction::UnaryAggregate( + type, type); case PhysicalType::INT64: - return AggregateFunction::UnaryAggregateDestructor(type, type); + return AggregateFunction::UnaryAggregate( + type, type); case PhysicalType::INT128: - return AggregateFunction::UnaryAggregateDestructor(type, type); + return AggregateFunction::UnaryAggregate(type, type); case PhysicalType::FLOAT: - return AggregateFunction::UnaryAggregateDestructor(type, type); + return AggregateFunction::UnaryAggregate( + type, type); case PhysicalType::DOUBLE: - return AggregateFunction::UnaryAggregateDestructor(type, type); + return AggregateFunction::UnaryAggregate( + type, type); default: throw InternalException("Unimplemented quantile aggregate"); } diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp index 11ecfabe6..d1203e788 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp @@ -462,9 +462,8 @@ template > AggregateFunction GetTypedModeFunction(const LogicalType &type) { using STATE = ModeState; using OP = ModeFunction; - auto func = - AggregateFunction::UnaryAggregateDestructor( - type, type); + auto func = AggregateFunction::UnaryAggregate( + type, type); func.SetWindowBatchCallback(OP::template Window); return func; } @@ -563,9 +562,8 @@ template > AggregateFunction GetTypedEntropyFunction(const LogicalType &type) { using STATE = ModeState; using OP = EntropyFunction; - auto func = - AggregateFunction::UnaryAggregateDestructor( - type, LogicalType::DOUBLE); + auto func = AggregateFunction::UnaryAggregate( + type, LogicalType::DOUBLE); func.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); return func; } diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp index 0a8de1852..41bb35d07 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp @@ -441,6 +441,7 @@ struct ScalarDiscreteQuantile { QuantileSortKeyUpdate, ListCombineFunction, AggregateFunction::StateVoidFinalize, nullptr, nullptr, AggregateFunction::StateDestroy); + fun.SetInitLocalStateFinalizeCallback(FlattenedQuantileValues::Init); return fun; } }; @@ -468,6 +469,7 @@ struct ListDiscreteQuantile { QuantileSortKeyUpdate, ListCombineFunction, AggregateFunction::StateFinalize, nullptr, nullptr, AggregateFunction::StateDestroy); + fun.SetInitLocalStateFinalizeCallback(FlattenedQuantileValues::Init); return fun; } }; diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp index 95fc63b02..6388962ab 100644 --- a/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp +++ b/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp @@ -167,37 +167,37 @@ struct ReservoirQuantileScalarOperation : public ReservoirQuantileOperation { AggregateFunction GetReservoirQuantileAggregateFunction(PhysicalType type) { switch (type) { case PhysicalType::INT8: - return AggregateFunction::UnaryAggregateDestructor, int8_t, int8_t, - ReservoirQuantileScalarOperation>(LogicalType::TINYINT, - LogicalType::TINYINT); + return AggregateFunction::UnaryAggregate, int8_t, int8_t, + ReservoirQuantileScalarOperation>(LogicalType::TINYINT, + LogicalType::TINYINT); case PhysicalType::INT16: - return AggregateFunction::UnaryAggregateDestructor, int16_t, int16_t, - ReservoirQuantileScalarOperation>(LogicalType::SMALLINT, - LogicalType::SMALLINT); + return AggregateFunction::UnaryAggregate, int16_t, int16_t, + ReservoirQuantileScalarOperation>(LogicalType::SMALLINT, + LogicalType::SMALLINT); case PhysicalType::INT32: - return AggregateFunction::UnaryAggregateDestructor, int32_t, int32_t, - ReservoirQuantileScalarOperation>(LogicalType::INTEGER, - LogicalType::INTEGER); + return AggregateFunction::UnaryAggregate, int32_t, int32_t, + ReservoirQuantileScalarOperation>(LogicalType::INTEGER, + LogicalType::INTEGER); case PhysicalType::INT64: - return AggregateFunction::UnaryAggregateDestructor, int64_t, int64_t, - ReservoirQuantileScalarOperation>(LogicalType::BIGINT, - LogicalType::BIGINT); + return AggregateFunction::UnaryAggregate, int64_t, int64_t, + ReservoirQuantileScalarOperation>(LogicalType::BIGINT, + LogicalType::BIGINT); case PhysicalType::INT128: - return AggregateFunction::UnaryAggregateDestructor, hugeint_t, hugeint_t, - ReservoirQuantileScalarOperation>(LogicalType::HUGEINT, - LogicalType::HUGEINT); + return AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, + ReservoirQuantileScalarOperation>(LogicalType::HUGEINT, + LogicalType::HUGEINT); case PhysicalType::FLOAT: - return AggregateFunction::UnaryAggregateDestructor, float, float, - ReservoirQuantileScalarOperation>(LogicalType::FLOAT, - LogicalType::FLOAT); + return AggregateFunction::UnaryAggregate, float, float, + ReservoirQuantileScalarOperation>(LogicalType::FLOAT, + LogicalType::FLOAT); case PhysicalType::DOUBLE: - return AggregateFunction::UnaryAggregateDestructor, double, double, - ReservoirQuantileScalarOperation>(LogicalType::DOUBLE, - LogicalType::DOUBLE); + return AggregateFunction::UnaryAggregate, double, double, + ReservoirQuantileScalarOperation>(LogicalType::DOUBLE, + LogicalType::DOUBLE); default: throw InternalException("Unimplemented reservoir quantile aggregate"); } diff --git a/src/duckdb/extension/core_functions/aggregate/nested/binned_histogram.cpp b/src/duckdb/extension/core_functions/aggregate/nested/binned_histogram.cpp index f6ff7a597..0313d3418 100644 --- a/src/duckdb/extension/core_functions/aggregate/nested/binned_histogram.cpp +++ b/src/duckdb/extension/core_functions/aggregate/nested/binned_histogram.cpp @@ -271,7 +271,7 @@ void IsHistogramOtherBinFunction(DataChunk &args, ExpressionState &state, Vector } template -void HistogramBinFinalizeFunction(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, +void HistogramBinFinalizeFunction(Vector &state_vector, AggregateFinalizeInputData &, Vector &result, idx_t count, idx_t offset) { auto states = state_vector.Values *>(); diff --git a/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp b/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp index 298ce0b8d..3038509d6 100644 --- a/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp +++ b/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp @@ -84,7 +84,8 @@ void HistogramUpdateFunction(Vector inputs[], AggregateInputData &aggr_input, id } template -void HistogramFinalizeFunction(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, idx_t offset) { +void HistogramFinalizeFunction(Vector &state_vector, AggregateFinalizeInputData &, Vector &result, idx_t count, + idx_t offset) { using HIST_STATE = HistogramAggState; auto states = state_vector.Values(); diff --git a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp index 7d3bc0012..79d7f5bbd 100644 --- a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp +++ b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp @@ -5,7 +5,7 @@ namespace duckdb { namespace { -void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_data, Vector &result, idx_t count, +void ListFinalize(Vector &states_vector, AggregateFinalizeInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { auto states = states_vector.Values(); diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_state.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_state.hpp index 312e00ab3..3bac85bfc 100644 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_state.hpp +++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_state.hpp @@ -18,24 +18,26 @@ namespace duckdb { //! Flattens the values of a linked list into a contiguous array for interpolation. //! The flattened values are mutable - the interpolators partially sort them in place. -//! The flattened chunk is cached in the finalize data's local state, so that it is allocated (at most) once per -//! result chunk instead of once per finalized group. +//! The flattened chunk lives in the finalize's local state, so that it is allocated (at most) once per +//! finalize call instead of once per finalized group - callers that keep the local state alive re-use it +//! across finalize calls. template struct FlattenedQuantileValues : FunctionLocalState { - FlattenedQuantileValues(const LogicalType &type, idx_t capacity_p) : capacity(capacity_p) { - chunk.Initialize(Allocator::DefaultAllocator(), {type}, capacity_p); + FlattenedQuantileValues() : capacity(0) { } - //! Flatten the values of the given linked list into the chunk cached in the finalize data + static unique_ptr Init(const BoundAggregateFunction &, optional_ptr) { + return make_uniq(); + } + + //! Flatten the values of the given linked list into the chunk cached in the finalize local state static FlattenedQuantileValues &Flatten(AggregateFinalizeData &finalize_data, const LinkedList &linked_list) { const auto type = PrimitiveToLogicalType(); const auto required_capacity = MaxValue(linked_list.total_capacity, 1); - if (!finalize_data.local_state) { - finalize_data.local_state = make_uniq(type, NextPowerOfTwo(required_capacity)); - } - auto &values = finalize_data.local_state->Cast(); + D_ASSERT(finalize_data.input.local_state); + auto &values = finalize_data.input.local_state->Cast(); if (values.capacity < required_capacity) { - // grow the cached chunk + // (re-)allocate the cached chunk values.capacity = NextPowerOfTwo(required_capacity); values.chunk.Destroy(); values.chunk.Initialize(Allocator::DefaultAllocator(), {type}, values.capacity); @@ -125,12 +127,14 @@ struct QuantileOperation { //! Quantiles ignore NULL values, so they are filtered out while appending. template AggregateFunction QuantileBufferingAggregate(const LogicalType &input_type, const LogicalType &result_type) { - return AggregateFunction({input_type}, result_type, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - ListUpdateFunction, ListCombineFunction, - AggregateFunction::StateFinalize, - FunctionNullHandling::DEFAULT_NULL_HANDLING, AggregateFunction::NoClusterUpdate(), - AggregateFunction::NoBind(), AggregateFunction::StateDestroy); + AggregateFunction fun({input_type}, result_type, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, + ListUpdateFunction, ListCombineFunction, + AggregateFunction::StateFinalize, + FunctionNullHandling::DEFAULT_NULL_HANDLING, AggregateFunction::NoClusterUpdate(), + AggregateFunction::NoBind(), AggregateFunction::StateDestroy); + fun.SetInitLocalStateFinalizeCallback(FlattenedQuantileValues::Init); + return fun; } template diff --git a/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp b/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp index c760f971b..08644c237 100644 --- a/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp +++ b/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp @@ -218,7 +218,7 @@ void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &res auto &aggr = info.aggr_expr->Cast(); auto &allocator = ExecuteFunctionState::GetFunctionState(state)->Cast().arena_allocator; allocator.Reset(); - AggregateInputData aggr_input_data(aggr, allocator); + AggregateFinalizeInputData aggr_input_data(aggr, allocator); D_ASSERT(aggr.Function().HasStateUpdateCallback()); diff --git a/src/duckdb/extension/parquet/include/parquet_reader.hpp b/src/duckdb/extension/parquet/include/parquet_reader.hpp index a684ff253..14e75d629 100644 --- a/src/duckdb/extension/parquet/include/parquet_reader.hpp +++ b/src/duckdb/extension/parquet/include/parquet_reader.hpp @@ -398,7 +398,7 @@ class ParquetReader : public BaseFileReader { idx_t GetGroupOffset(ParquetReaderScanState &state); //! Group span is the distance between the min page offset and the max page offset plus the max page compressed size uint64_t GetGroupSpan(ParquetReaderScanState &state); - void PrepareRowGroupBuffer(ParquetReaderScanState &state, idx_t out_col_idx); + void PrepareRowGroupBuffer(ClientContext &context, ParquetReaderScanState &state, idx_t out_col_idx); //! Whole-group prefetch strategy. ParquetPrefetchStrategy WholeGroupPrefetch(ParquetReaderScanState &state, ThriftFileTransport &trans, const duckdb_parquet::RowGroup &group, uint64_t total_row_group_span, diff --git a/src/duckdb/extension/parquet/parquet_reader.cpp b/src/duckdb/extension/parquet/parquet_reader.cpp index ef008e162..71631df59 100644 --- a/src/duckdb/extension/parquet/parquet_reader.cpp +++ b/src/duckdb/extension/parquet/parquet_reader.cpp @@ -1301,8 +1301,8 @@ idx_t ParquetReader::GetGroupOffset(ParquetReaderScanState &state) { return min_offset; } -static FilterPropagateResult CheckParquetFloatFilter(ColumnReader &reader, const Statistics &pq_col_stats, - const TableFilter &filter) { +static FilterPropagateResult CheckParquetFloatFilter(ClientContext &context, ColumnReader &reader, + const Statistics &pq_col_stats, const TableFilter &filter) { // floating point values can have values in the [min, max] domain AND nan values // check both stats against the filter auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "CheckParquetFloatFilter"); @@ -1311,10 +1311,10 @@ static FilterPropagateResult CheckParquetFloatFilter(ColumnReader &reader, const auto nan_value = Value("nan").DefaultCastAs(type); NumericStats::SetMin(nan_stats, nan_value); NumericStats::SetMax(nan_stats, nan_value); - auto nan_prune = expr_filter.CheckStatistics(nan_stats); + auto nan_prune = expr_filter.CheckStatistics(context, nan_stats); auto min_max_stats = ParquetStatisticsUtils::CreateNumericStats(reader.Type(), reader.Schema(), pq_col_stats); - auto prune = expr_filter.CheckStatistics(*min_max_stats); + auto prune = expr_filter.CheckStatistics(context, *min_max_stats); // if EITHER of them cannot be pruned - we cannot prune if (prune == FilterPropagateResult::NO_PRUNING_POSSIBLE || @@ -1334,7 +1334,7 @@ ColumnReader &ParquetReaderScanState::GetColumnReader(idx_t i) { return *column_readers[i]; } -void ParquetReader::PrepareRowGroupBuffer(ParquetReaderScanState &state, idx_t i) { +void ParquetReader::PrepareRowGroupBuffer(ClientContext &context, ParquetReaderScanState &state, idx_t i) { auto &group = GetGroup(state); auto col_idx = MultiFileLocalIndex(i); auto &column_reader = state.GetColumnReader(col_idx); @@ -1372,12 +1372,12 @@ void ParquetReader::PrepareRowGroupBuffer(ParquetReaderScanState &state, idx_t i // floating point columns can have NaN values in addition to the min/max bounds defined in the file // in order to do optimal pruning - we prune based on the [min, max] of the file followed by pruning // based on nan - prune_result = CheckParquetFloatFilter(column_reader, + prune_result = CheckParquetFloatFilter(context, column_reader, group.columns[schema_column_index].meta_data.statistics, filter); } else { auto &expr_filter = ExpressionFilter::GetExpressionFilter(filter, "ParquetReader::PrepareRowGroupBuffer"); - prune_result = expr_filter.CheckStatistics(*stats); + prune_result = expr_filter.CheckStatistics(context, *stats); } // check the bloom filter if present if (prune_result == FilterPropagateResult::NO_PRUNING_POSSIBLE && !column_reader.Type().IsNested() && @@ -1745,7 +1745,7 @@ AsyncResult ParquetReader::Schedule(ClientContext &context, ParquetReaderScanSta uint64_t to_scan_compressed_bytes = 0; for (idx_t i = 0; i < column_ids.size(); i++) { auto col_idx = MultiFileLocalIndex(i); - PrepareRowGroupBuffer(state, col_idx); + PrepareRowGroupBuffer(context, state, col_idx); to_scan_compressed_bytes += state.GetColumnReader(i).TotalCompressedSize(); } diff --git a/src/duckdb/generated_extension_loader_package_build.cpp b/src/duckdb/generated_extension_loader_package_build.cpp index 8f6906398..41aecec14 100644 --- a/src/duckdb/generated_extension_loader_package_build.cpp +++ b/src/duckdb/generated_extension_loader_package_build.cpp @@ -29,6 +29,7 @@ #include "duckdb/main/extension/generated_extension_loader.hpp" #include "duckdb/main/extension_helper.hpp" + namespace duckdb { //! Looks through the package_build.py-generated list of extensions that are linked into DuckDB currently to try load diff --git a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp index 9fcec52c0..e5d85fce4 100644 --- a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp +++ b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp @@ -16,6 +16,8 @@ #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/planner/operator/logical_update.hpp" #include "duckdb/storage/table_storage_info.hpp" +#include "duckdb/planner/column_binding.hpp" +#include "duckdb/planner/expression/bound_cast_expression.hpp" #include @@ -352,6 +354,21 @@ bool TableCatalogEntry::HasPrimaryKey() const { return GetPrimaryKey() != nullptr; } +LogicalType TableCatalogEntry::GetExpectedTypeForInsert(const ColumnDefinition &column) const { + return column.Type(); +} + +unique_ptr TableCatalogEntry::GetDefaultExpressionForColumn(ClientContext &context, + const LogicalType &input_type, + const LogicalType &result_type, + ColumnBinding binding, + const Expression &constant_value) const { + (void)context; + (void)constant_value; + return BoundCastExpression::AddCastToType(context, make_uniq(input_type, binding), + result_type); +} + virtual_column_map_t TableCatalogEntry::GetVirtualColumns() const { virtual_column_map_t virtual_columns; virtual_columns.insert(make_pair(COLUMN_IDENTIFIER_ROW_ID, TableColumn("rowid", LogicalType::ROW_TYPE))); diff --git a/src/duckdb/src/common/row_operations/row_aggregate.cpp b/src/duckdb/src/common/row_operations/row_aggregate.cpp index 06e94422d..933180295 100644 --- a/src/duckdb/src/common/row_operations/row_aggregate.cpp +++ b/src/duckdb/src/common/row_operations/row_aggregate.cpp @@ -161,11 +161,23 @@ void RowOperations::FinalizeStates(RowOperationsState &state, TupleDataLayout &l VectorOperations::AddInPlace(addresses_copy, UnsafeNumericCast(layout.GetAggrOffset())); auto &aggregates = layout.GetAggregates(); + // initialize the finalize local states once - they are re-used across all finalize calls of this state + if (state.local_states.size() < aggregates.size()) { + state.local_states.resize(aggregates.size()); + for (idx_t i = 0; i < aggregates.size(); i++) { + auto &callbacks = aggregates[i].function.GetCallbacks(); + if (callbacks.HasInitLocalStateFinalizeCallback()) { + AggregateInputData aggr_input_data(aggregates[i], state.allocator); + state.local_states[i] = + callbacks.GetInitLocalStateFinalizeCallback()(aggr_input_data.function, aggr_input_data.bind_data); + } + } + } for (idx_t i = 0; i < aggregates.size(); i++) { auto &target = result.data[aggr_idx + i]; auto &aggr = aggregates[i]; - AggregateInputData aggr_input_data(aggr, state.allocator); - aggr.function.GetStateFinalizeCallback()(addresses_copy, aggr_input_data, target, result.size(), 0); + AggregateFinalizeInputData finalize_input_data(aggr, state.allocator, state.local_states[i].get()); + aggr.function.GetStateFinalizeCallback()(addresses_copy, finalize_input_data, target, result.size(), 0); FlatVector::SetSize(target, count_t(result.size())); // Move to the next aggregate state diff --git a/src/duckdb/src/common/types/time.cpp b/src/duckdb/src/common/types/time.cpp index 971996da6..ee7d45e26 100644 --- a/src/duckdb/src/common/types/time.cpp +++ b/src/duckdb/src/common/types/time.cpp @@ -77,7 +77,12 @@ bool Time::TryConvertInternal(const char *buf, idx_t len, idx_t &pos, dtime_t &r if (pos > len) { return false; } - if (pos == len && (!strict || sep_pos + 2 == pos)) { + if (pos == len) { + // no seconds field: in strict mode this is only valid for a two-digit minute (HH:MM), + // otherwise there is no separator to read and we must not look past the end of the buffer + if (strict && sep_pos + 2 != pos) { + return false; + } sec = 0; } else { if (buf[pos++] != sep) { diff --git a/src/duckdb/src/common/types/variant/variant_value.cpp b/src/duckdb/src/common/types/variant/variant_value.cpp index b2766a4bf..47e60f7a5 100644 --- a/src/duckdb/src/common/types/variant/variant_value.cpp +++ b/src/duckdb/src/common/types/variant/variant_value.cpp @@ -81,6 +81,15 @@ const vector &VariantValue::ArrayItems() const { return array_items; } +Value VariantValue::GetValue(const Value &variant_val) { + D_ASSERT(variant_val.type().id() == LogicalTypeId::VARIANT && !variant_val.IsNull()); + Vector tmp(variant_val, count_t(1)); + RecursiveUnifiedVectorFormat format; + Vector::RecursiveToUnifiedFormat(tmp, format); + UnifiedVariantVectorData vector_data(format); + return VariantUtils::ConvertVariantToValue(vector_data, 0, 0); +} + static void AnalyzeValue(const VariantValue &value, idx_t row, DataChunk &offsets) { auto &keys_offset = variant::OffsetData::GetKeys(offsets)[row]; auto &children_offset = variant::OffsetData::GetChildren(offsets)[row]; diff --git a/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp b/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp index b13e08b03..02afe4b70 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp @@ -289,7 +289,8 @@ void StreamingWindowState::AggregateState::Execute(ExecutionContext &context, Da } // Update the state and finalize it one row at a time. - AggregateInputData aggr_input_data(*wexpr.AggregateFunction(), wexpr.BindInfo().get(), aggr_state.arena_allocator); + AggregateFinalizeInputData aggr_input_data(*wexpr.AggregateFunction(), wexpr.BindInfo().get(), + aggr_state.arena_allocator); for (idx_t i = 0; i < count; ++i) { sel.set_index(0, i); for (const auto struct_idx : structs) { diff --git a/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp index 2a3098c1d..a59d46bd0 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp @@ -683,7 +683,7 @@ void GlobalUngroupedAggregateState::Finalize(DataChunk &result, idx_t column_off for (idx_t aggr_idx = 0; aggr_idx < state.functions.size(); aggr_idx++) { auto &func = state.functions[aggr_idx]; Vector state_vector(Value::POINTER(CastPointerToValue(state.aggregate_data[aggr_idx].get())), count_t(1)); - AggregateInputData aggr_input_data(func, state.bind_data[aggr_idx].get(), allocator); + AggregateFinalizeInputData aggr_input_data(func, state.bind_data[aggr_idx].get(), allocator); func.GetStateFinalizeCallback()(state_vector, aggr_input_data, result.data[column_offset + aggr_idx], 1, 0); FlatVector::SetSize(result.data[column_offset + aggr_idx], count_t(1)); } diff --git a/src/duckdb/src/execution/operator/projection/physical_pivot.cpp b/src/duckdb/src/execution/operator/projection/physical_pivot.cpp index 52188c8a8..c44170a23 100644 --- a/src/duckdb/src/execution/operator/projection/physical_pivot.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_pivot.cpp @@ -27,7 +27,7 @@ PhysicalPivot::PhysicalPivot(PhysicalPlan &physical_plan, vector ty aggr.Function().GetStateInitCallback()(aggr.Function(), state.get()); Vector state_vector(Value::POINTER(CastPointerToValue(state.get())), count_t(1)); Vector result_vector(aggr_expr->GetReturnType()); - AggregateInputData aggr_input_data(aggr, physical_plan.ArenaRef()); + AggregateFinalizeInputData aggr_input_data(aggr, physical_plan.ArenaRef()); aggr.Function().GetStateFinalizeCallback()(state_vector, aggr_input_data, result_vector, 1, 0); empty_aggregates.push_back(result_vector.GetValue(0)); } diff --git a/src/duckdb/src/execution/physical_plan/plan_insert.cpp b/src/duckdb/src/execution/physical_plan/plan_insert.cpp index e49f7c6f6..e882fe110 100644 --- a/src/duckdb/src/execution/physical_plan/plan_insert.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_insert.cpp @@ -5,11 +5,11 @@ #include "duckdb/planner/operator/logical_insert.hpp" #include "duckdb/main/config.hpp" #include "duckdb/execution/operator/persistent/physical_batch_insert.hpp" -#include "duckdb/execution/operator/projection/physical_projection.hpp" #include "duckdb/parallel/task_scheduler.hpp" #include "duckdb/catalog/duck_catalog.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/main/settings.hpp" +#include "duckdb/execution/operator/projection/physical_projection.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" namespace duckdb { @@ -118,6 +118,7 @@ PhysicalOperator &DuckCatalog::PlanInsert(ClientContext &context, PhysicalPlanGe parallel_streaming_insert = false; } if (!op.column_index_map.empty()) { + //! Deprecated: The column_index_map is only populated by older versions. plan = planner.ResolveDefaultsProjection(op, *plan); } if (use_batch_index && !parallel_streaming_insert) { diff --git a/src/duckdb/src/execution/physical_plan/plan_merge_into.cpp b/src/duckdb/src/execution/physical_plan/plan_merge_into.cpp index 9fee8beb8..e8f64c139 100644 --- a/src/duckdb/src/execution/physical_plan/plan_merge_into.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_merge_into.cpp @@ -61,8 +61,10 @@ unique_ptr PlanMergeIntoAction(ClientContext &context, Logica std::move(set_expressions), std::move(set_columns), std::move(set_types), cardinality, op.return_chunk, !op.return_chunk, OnConflictAction::THROW, nullptr, nullptr, std::move(on_conflict_filter), std::move(columns_to_fetch), false); + // transform expressions if required if (!action.column_index_map.empty()) { + //! Deprecated: plan expressions for default expressions, now set at bind time vector> new_expressions; for (auto &col : op.table.GetColumns().Physical()) { auto storage_idx = col.StorageOid(); @@ -77,6 +79,7 @@ unique_ptr PlanMergeIntoAction(ClientContext &context, Logica } action.expressions = std::move(new_expressions); } + result->expressions = std::move(action.expressions); break; } diff --git a/src/duckdb/src/execution/radix_partitioned_hashtable.cpp b/src/duckdb/src/execution/radix_partitioned_hashtable.cpp index a3a559fd8..c9d7ea8e5 100644 --- a/src/duckdb/src/execution/radix_partitioned_hashtable.cpp +++ b/src/duckdb/src/execution/radix_partitioned_hashtable.cpp @@ -1076,7 +1076,7 @@ SourceResultType RadixPartitionedHashTable::GetData(ExecutionContext &context, D aggr.Function().GetStateSizeCallback()(aggr.Function())); aggr.Function().GetStateInitCallback()(aggr.Function(), aggr_state.get()); - AggregateInputData aggr_input_data(aggr, allocator); + AggregateFinalizeInputData aggr_input_data(aggr, allocator); Vector state_vector(Value::POINTER(CastPointerToValue(aggr_state.get())), count_t(1)); auto &agg_result = chunk.data[null_groups.size() + i]; aggr.Function().GetStateFinalizeCallback()(state_vector, aggr_input_data, agg_result, 1, 0); diff --git a/src/duckdb/src/function/aggregate/distributive/count.cpp b/src/duckdb/src/function/aggregate/distributive/count.cpp index aae3baf67..be3b329bc 100644 --- a/src/duckdb/src/function/aggregate/distributive/count.cpp +++ b/src/duckdb/src/function/aggregate/distributive/count.cpp @@ -256,7 +256,8 @@ struct CountFunction : public BaseCountFunction { } }; -AggregateStateLayout GetCountStateType(const BoundAggregateFunction &function) { +AggregateStateLayout GetCountStateType(AggregateLayoutInput &input) { + auto &function = input.function; return AggregateStateLayout(LogicalType::BIGINT, AlignValue(function.GetStateSizeCallback()(function))); } diff --git a/src/duckdb/src/function/aggregate/distributive/minmax.cpp b/src/duckdb/src/function/aggregate/distributive/minmax.cpp index 83135ccde..cdccf99cd 100644 --- a/src/duckdb/src/function/aggregate/distributive/minmax.cpp +++ b/src/duckdb/src/function/aggregate/distributive/minmax.cpp @@ -531,7 +531,8 @@ AggregateFunction GetMinMaxNFunction() { MinMaxNBind, nullptr); } -AggregateStateLayout GetExportStateType(const BoundAggregateFunction &function) { +AggregateStateLayout GetExportStateType(AggregateLayoutInput &input) { + auto &function = input.function; child_list_t struct_children_types; struct_children_types.emplace_back("value", function.GetReturnType()); struct_children_types.emplace_back("is_set", LogicalType::BOOLEAN); diff --git a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp index ce3c7c106..e02bfc322 100644 --- a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp +++ b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp @@ -167,6 +167,52 @@ struct SortedAggregateBindData : public FunctionData { //! The sorted aggregate buffers its input rows in a linked list of structs, sharing the "list" callbacks struct SortedAggregateState : ListAggState {}; +//! Caches the chunks, contexts and inner aggregate state used while finalizing the groups of a sorted aggregate. +//! When the caller provides a local state slot (e.g. the hash table scan), this state survives across finalize +//! calls instead of being re-instantiated for every result chunk. +struct SortedAggregateFinalizeState : FunctionLocalState { + explicit SortedAggregateFinalizeState(const SortedAggregateBindData &order_bind) + : thread(order_bind.context), context(order_bind.context, thread, nullptr), + agg_state(order_bind.function.GetCallbacks().GetStateSizeCallback()(order_bind.function)), + agg_state_vec(Value::POINTER(CastPointerToValue(agg_state.data())), count_t(1)) { + auto &buffer_allocator = BufferManager::GetBufferManager(order_bind.context).GetBufferAllocator(); + rows.Initialize(buffer_allocator, {order_bind.buffered_struct_type}); + scanned.Initialize(buffer_allocator, order_bind.scan_types); + sliced.Initialize(buffer_allocator, order_bind.scan_types); + prefixed.Initialize(buffer_allocator, order_bind.sort_types); + + // The local state of the inner aggregate's finalize is kept alive across finalize calls as well + const auto &callbacks = order_bind.function.GetCallbacks(); + if (callbacks.HasInitLocalStateFinalizeCallback()) { + inner_local_state = + callbacks.GetInitLocalStateFinalizeCallback()(order_bind.function, order_bind.bind_info.get()); + } + } + + static unique_ptr Init(const BoundAggregateFunction &, optional_ptr bind_data) { + return make_uniq(bind_data->Cast()); + } + + //! The execution context for the sort operator + ThreadContext thread; + ExecutionContext context; + InterruptState interrupt; + //! The buffered rows of (possibly many) groups, accumulated before they are sunk into the sort + DataChunk rows; + //! The chunk for scanning the sorted data + DataChunk scanned; + //! The scanned data sliced to the rows of a single group + DataChunk sliced; + //! The sink chunk holding the buffered rows prefixed with the group number + DataChunk prefixed; + //! The state of the inner aggregate + vector agg_state; + //! A vector pointing to the inner aggregate state + Vector agg_state_vec; + //! The local state used by the inner aggregate's finalize (may be null) + unique_ptr inner_local_state; +}; + struct SortedAggregateFunction { static LogicalType GetElementType(AggregateInputData &aggr_input_data) { return aggr_input_data.bind_data->Cast().buffered_struct_type; @@ -177,7 +223,7 @@ struct SortedAggregateFunction { if (!count) { return; } - // Pack the buffered columns into a single struct vector for the list update + // Pack the buffered columns into a single struct vector and append the rows through the list update const auto &order_bind = aggr_input_data.bind_data->Cast(); Vector packed(order_bind.buffered_struct_type, count); auto &entries = StructVector::GetEntries(packed); @@ -204,51 +250,89 @@ struct SortedAggregateFunction { } } - //! Sinks all buffered rows of the state into the sort, prefixed with the group number - static void SinkState(const SortedAggregateBindData &order_bind, SortedAggregateState &state, idx_t group_number, - ExecutionContext &context, OperatorSinkInput &sink, DataChunk &prefixed) { - auto &sort = *order_bind.sort; - ListSegmentScanState scan_state; - order_bind.buffered_funcs.InitializeScan(state.linked_list, scan_state); - for (;;) { - Vector rows(order_bind.buffered_struct_type, STANDARD_VECTOR_SIZE); - const auto chunk_count = order_bind.buffered_funcs.Scan(scan_state, rows); - if (!chunk_count) { - break; + //! Sinks the rows accumulated in the rows chunk into the sort, prefixed with their group numbers + static void FlushAccumulated(const SortedAggregateBindData &order_bind, idx_t &accumulated, + SortedAggregateFinalizeState &finalize_state, ExecutionContext &context, + OperatorSinkInput &sink) { + if (!accumulated) { + return; + } + auto &prefixed = finalize_state.prefixed; + FlatVector::SetSize(prefixed.data[0], count_t(accumulated)); + auto &entries = StructVector::GetEntries(finalize_state.rows.data[0]); + for (column_t col_idx = 0; col_idx < entries.size(); ++col_idx) { + prefixed.data[col_idx + 1].Reference(entries[col_idx]); + FlatVector::SetSize(prefixed.data[col_idx + 1], count_t(accumulated)); + } + order_bind.sort->Sink(context, prefixed, sink); + finalize_state.rows.Reset(); + accumulated = 0; + } + + //! Buffers the rows of the state into the cached rows chunk, prefixed with the group number, flushing into + //! the sort whenever the chunk fills up - this batches many small groups into a single sink call + static void SinkState(const SortedAggregateBindData &order_bind, SortedAggregateState &state, + const idx_t group_number, idx_t &accumulated, SortedAggregateFinalizeState &finalize_state, + ExecutionContext &context, OperatorSinkInput &sink) { + const auto group_count = state.linked_list.total_capacity; + if (!group_count) { + return; + } + auto &rows = finalize_state.rows.data[0]; + auto group_numbers = FlatVector::GetDataMutable(finalize_state.prefixed.data[0]); + if (group_count <= STANDARD_VECTOR_SIZE) { + // The group fits in the rows chunk - flush first if there is not enough space left + if (accumulated + group_count > STANDARD_VECTOR_SIZE) { + FlushAccumulated(order_bind, accumulated, finalize_state, context, sink); } - auto &entries = StructVector::GetEntries(rows); - prefixed.Reset(); - prefixed.data[0].Reference(Value::USMALLINT(UnsafeNumericCast(group_number)), count_t(1)); - FlatVector::SetSize(prefixed.data[0], count_t(chunk_count)); - for (column_t col_idx = 0; col_idx < entries.size(); ++col_idx) { - prefixed.data[col_idx + 1].Reference(entries[col_idx]); - FlatVector::SetSize(prefixed.data[col_idx + 1], count_t(chunk_count)); + // Append the group's rows to the accumulated rows + order_bind.buffered_funcs.BuildListVector(state.linked_list, rows, accumulated); + for (idx_t i = 0; i < group_count; ++i) { + group_numbers[accumulated + i] = UnsafeNumericCast(group_number); + } + accumulated += group_count; + } else { + // The group does not fit in a single chunk - flush, then stream it chunk at a time + FlushAccumulated(order_bind, accumulated, finalize_state, context, sink); + ListSegmentScanState scan_state; + order_bind.buffered_funcs.InitializeScan(state.linked_list, scan_state); + for (;;) { + const auto chunk_count = order_bind.buffered_funcs.Scan(scan_state, rows); + if (!chunk_count) { + break; + } + for (idx_t i = 0; i < chunk_count; ++i) { + group_numbers[i] = UnsafeNumericCast(group_number); + } + accumulated = chunk_count; + FlushAccumulated(order_bind, accumulated, finalize_state, context, sink); } - sort.Sink(context, prefixed, sink); } // Release the state - the rows are freed with the arena allocator state.linked_list = LinkedList(); } - static void Finalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, + static void Finalize(Vector &states, AggregateFinalizeInputData &finalize_input_data, Vector &result, idx_t count, const idx_t offset) { - auto &order_bind = aggr_input_data.bind_data->Cast(); + auto &order_bind = finalize_input_data.bind_data->Cast(); auto &client = order_bind.context; - auto &buffer_allocator = BufferManager::GetBufferManager(client).GetBufferAllocator(); - DataChunk scanned; - scanned.Initialize(buffer_allocator, order_bind.scan_types); - DataChunk sliced; - sliced.Initialize(buffer_allocator, order_bind.scan_types); - - // Reusable inner state - auto &aggr = order_bind.function; - vector agg_state(aggr.GetCallbacks().GetStateSizeCallback()(aggr)); - Vector agg_state_vec(Value::POINTER(CastPointerToValue(agg_state.data())), count_t(1)); + // The local state holds the chunks and contexts - callers can keep it alive across finalize calls + // so they do not have to be re-instantiated for every finalize call + D_ASSERT(finalize_input_data.local_state); + auto &finalize_state = finalize_input_data.local_state->Cast(); + auto &scanned = finalize_state.scanned; + auto &sliced = finalize_state.sliced; + auto &agg_state = finalize_state.agg_state; + auto &agg_state_vec = finalize_state.agg_state_vec; + auto &context = finalize_state.context; + auto &interrupt = finalize_state.interrupt; // State variables + auto &aggr = order_bind.function; auto bind_info = order_bind.bind_info.get(); - AggregateInputData aggr_bind_info(aggr, bind_info, aggr_input_data.allocator); + AggregateFinalizeInputData aggr_bind_info(aggr, bind_info, finalize_input_data.allocator, + finalize_state.inner_local_state.get()); // Inner aggregate APIs auto initialize = aggr.GetCallbacks().GetStateInitCallback(); @@ -264,24 +348,19 @@ struct SortedAggregateFunction { state_unprocessed[i] = sdata[i].GetValueUnsafe()->linked_list.total_capacity; } - ThreadContext thread(client); - ExecutionContext context(client, thread, nullptr); - InterruptState interrupt; auto &sort = order_bind.sort; auto global_sink = sort->GetGlobalSinkState(client); auto local_sink = sort->GetLocalSinkState(context); - DataChunk prefixed; - prefixed.Initialize(buffer_allocator, order_bind.sort_types); - // Go through the states accumulating values to sort until we hit the sort threshold idx_t unsorted_count = 0; idx_t sorted = 0; + idx_t accumulated = 0; for (idx_t finalized = 0; finalized < count;) { if (unsorted_count < order_bind.threshold) { auto state = sdata[finalized].GetValueUnsafe(); OperatorSinkInput sink {*global_sink, *local_sink, interrupt}; - SinkState(order_bind, *state, finalized, context, sink, prefixed); + SinkState(order_bind, *state, finalized, accumulated, finalize_state, context, sink); unsorted_count += state_unprocessed[finalized]; // Go to the next aggregate unless this is the last one @@ -290,6 +369,12 @@ struct SortedAggregateFunction { } } + // Sink any remaining accumulated rows before sorting + { + OperatorSinkInput sink {*global_sink, *local_sink, interrupt}; + FlushAccumulated(order_bind, accumulated, finalize_state, context, sink); + } + // If they were all empty (filtering) flush them // (This can only happen on the last range) if (!unsorted_count) { @@ -428,6 +513,7 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateE ListCombineFunction, SortedAggregateFunction::Finalize, bound_function.GetProperties().GetNullHandling(), nullptr, nullptr, nullptr, nullptr, SortedAggregateFunction::WindowBatch); + ordered_aggregate.SetInitLocalStateFinalizeCallback(SortedAggregateFinalizeState::Init); expr.FunctionMutable().ReplaceImplementation(ordered_aggregate); expr.BindInfoMutable() = std::move(sorted_bind); @@ -482,6 +568,7 @@ void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundWindowExpr aggregate.GetProperties().GetNullHandling(), nullptr, nullptr, nullptr, nullptr, SortedAggregateFunction::WindowBatch); ordered_aggregate.SetWindowCallback(SortedAggregateFunction::Window); + ordered_aggregate.SetInitLocalStateFinalizeCallback(SortedAggregateFinalizeState::Init); aggregate.ReplaceImplementation(ordered_aggregate); expr.BindInfoMutable() = std::move(sorted_bind); diff --git a/src/duckdb/src/function/aggregate_function.cpp b/src/duckdb/src/function/aggregate_function.cpp index 0a86519a5..89af82665 100644 --- a/src/duckdb/src/function/aggregate_function.cpp +++ b/src/duckdb/src/function/aggregate_function.cpp @@ -17,6 +17,39 @@ AggregateInputData::AggregateInputData(const AggregateObject &aggr, ArenaAllocat allocator_p, combine_type_p) { } +AggregateFinalizeInputData::AggregateFinalizeInputData(const BoundAggregateFunction &function_p, + optional_ptr bind_data_p, + ArenaAllocator &allocator_p, + optional_ptr local_state_p) + : AggregateInputData(function_p, bind_data_p, allocator_p), local_state(local_state_p) { + InitializeLocalState(); +} + +AggregateFinalizeInputData::AggregateFinalizeInputData(const BoundAggregateExpression &expr, + ArenaAllocator &allocator_p, + optional_ptr local_state_p) + : AggregateInputData(expr, allocator_p), local_state(local_state_p) { + InitializeLocalState(); +} + +AggregateFinalizeInputData::AggregateFinalizeInputData(const AggregateObject &aggr, ArenaAllocator &allocator_p, + optional_ptr local_state_p) + : AggregateInputData(aggr, allocator_p), local_state(local_state_p) { + InitializeLocalState(); +} + +void AggregateFinalizeInputData::InitializeLocalState() { + if (local_state) { + // the caller passed in an externally-owned local state + return; + } + auto &callbacks = function.GetCallbacks(); + if (callbacks.HasInitLocalStateFinalizeCallback()) { + owned_state = callbacks.GetInitLocalStateFinalizeCallback()(function, bind_data); + local_state = owned_state.get(); + } +} + bool AggregateFunctionProperties::operator==(const AggregateFunctionProperties &rhs) const { return FunctionProperties::operator==(rhs) && order_dependent == rhs.order_dependent && distinct_dependent == rhs.distinct_dependent; @@ -27,7 +60,8 @@ bool AggregateFunctionProperties::operator!=(const AggregateFunctionProperties & bool AggregateFunctionCallbacks::operator==(const AggregateFunctionCallbacks &rhs) const { return state_size == rhs.state_size && initialize == rhs.initialize && update == rhs.update && - combine == rhs.combine && finalize == rhs.finalize && cluster_update == rhs.cluster_update && + combine == rhs.combine && finalize == rhs.finalize && + init_local_state_finalize == rhs.init_local_state_finalize && cluster_update == rhs.cluster_update && window == rhs.window && window_init == rhs.window_init && window_batch == rhs.window_batch && bind == rhs.bind && destructor == rhs.destructor && statistics == rhs.statistics && serialize == rhs.serialize && deserialize == rhs.deserialize && get_state_type == rhs.get_state_type; diff --git a/src/duckdb/src/function/function_list.cpp b/src/duckdb/src/function/function_list.cpp index 068cbf201..a61c9a55f 100644 --- a/src/duckdb/src/function/function_list.cpp +++ b/src/duckdb/src/function/function_list.cpp @@ -258,7 +258,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION_SET(SubstringGraphemeFun), DUCKDB_SCALAR_FUNCTION_SET_ALIAS(SubtractFun), DUCKDB_SCALAR_FUNCTION(SuffixFun), - DUCKDB_SCALAR_FUNCTION(ToAggregateStateFun), + DUCKDB_SCALAR_FUNCTION_SET(ToAggregateStateFun), DUCKDB_SCALAR_FUNCTION_SET(TryStrpTimeFun), DUCKDB_SCALAR_FUNCTION_ALIAS(UcaseFun), DUCKDB_SCALAR_FUNCTION(UpperFun), @@ -268,6 +268,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION_SET(VariantKeysFun), DUCKDB_SCALAR_FUNCTION(VariantNormalizeFun), DUCKDB_SCALAR_FUNCTION(VariantTypeofFun), + DUCKDB_SCALAR_FUNCTION(VertexExtractFun), DUCKDB_SCALAR_FUNCTION_SET(WriteLogFun), DUCKDB_SCALAR_FUNCTION(ConcatOperatorFun), DUCKDB_SCALAR_FUNCTION(LikeFun), diff --git a/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp b/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp index 1a410c1e6..9842b0171 100644 --- a/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp +++ b/src/duckdb/src/function/scalar/geometry/geometry_functions.cpp @@ -150,4 +150,207 @@ ScalarFunction StSetcrsFun::GetFunction() { return geom_func; } +namespace { + +struct VertexExtractBindData final : public FunctionData { + explicit VertexExtractBindData(idx_t vertex_index) : vertex_index(vertex_index) { + } + + idx_t vertex_index; + + unique_ptr Copy() const override { + return make_uniq(vertex_index); + } + + auto Equals(const FunctionData &other) const -> bool override { + auto &other_bind = other.Cast(); + return vertex_index == other_bind.vertex_index; + } +}; + +} // namespace + +static auto VertexExtractBind(BindScalarFunctionInput &input) -> unique_ptr { + auto &arguments = input.GetArguments(); + + if (arguments[1]->HasParameter()) { + throw ParameterNotResolvedException(); + } + if (!arguments[1]->IsFoldable()) { + throw BinderException("vertex_extract: vertex argument must be constant!"); + } + const auto vertex_val = ExpressionExecutor::EvaluateScalar(input.GetClientContext(), *arguments[1]); + if (vertex_val.IsNull()) { + throw BinderException("vertex_extract: vertex argument cannot be NULL!"); + } + const auto vertex_str = StringUtil::Lower(StringValue::Get(vertex_val)); + if (vertex_str == "x") { + return make_uniq(static_cast(0)); + } + if (vertex_str == "y") { + return make_uniq(static_cast(1)); + } + if (vertex_str == "z") { + return make_uniq(static_cast(2)); + } + if (vertex_str == "m") { + return make_uniq(static_cast(3)); + } + throw BinderException("vertex_extract: invalid vertex argument '%s', expected one of 'x', 'y', 'z', or 'm'", + vertex_str); +} + +static auto VertexExtractStats(ClientContext &context, FunctionStatisticsInput &input) -> unique_ptr { + const auto &child_stats = input.child_stats; + const auto &bind_data = input.bind_data->Cast(); + + const auto &extent = GeometryStats::GetExtent(child_stats[0]); + const auto &types = GeometryStats::GetTypes(child_stats[0]); + const auto &flags = GeometryStats::GetFlags(child_stats[0]); + + auto new_stats = NumericStats::CreateUnknown(LogicalType::DOUBLE); + + if (!types.Has(GeometryType::POINT) || !flags.HasNonEmptyGeometry()) { + // No non-empty points present, so vertex extraction will always return NULL + *input.expr_ptr = make_uniq(Value(LogicalType::DOUBLE)); + new_stats.Set(StatsInfo::CANNOT_HAVE_VALID_VALUES); + new_stats.Set(StatsInfo::CAN_HAVE_NULL_VALUES); + return new_stats.ToUnique(); + } + + new_stats.CopyValidity(child_stats[0]); + + if (!types.HasOnly(GeometryType::POINT)) { + // If there are non-point geometries, we cannot guarantee that all rows will yield a valid value + new_stats.Set(StatsInfo::CAN_HAVE_NULL_VALUES); + } + + // If there are empty geometries, we cannot guarantee that all rows will yield a valid value + if (flags.HasEmptyGeometry()) { + new_stats.Set(StatsInfo::CAN_HAVE_NULL_VALUES); + } + + if (bind_data.vertex_index == 2) { + // Z is absent on XY and XYM points + if (types.Has(VertexType::XY) || types.Has(VertexType::XYM)) { + new_stats.Set(StatsInfo::CAN_HAVE_NULL_VALUES); + } + if (!types.Has(VertexType::XYZ) && !types.Has(VertexType::XYZM)) { + // If there are no vertex types with Z, we can guarantee that all rows will yield NULL + *input.expr_ptr = make_uniq(Value(LogicalType::DOUBLE)); + new_stats.Set(StatsInfo::CANNOT_HAVE_VALID_VALUES); + new_stats.Set(StatsInfo::CAN_HAVE_NULL_VALUES); + return new_stats.ToUnique(); + } + } + + if (bind_data.vertex_index == 3) { + // M is absent on XY and XYZ points + if (types.Has(VertexType::XY) || types.Has(VertexType::XYZ)) { + new_stats.Set(StatsInfo::CAN_HAVE_NULL_VALUES); + } + if (!types.Has(VertexType::XYM) && !types.Has(VertexType::XYZM)) { + // If there are no vertex types with M, we can guarantee that all rows will yield NULL + *input.expr_ptr = make_uniq(Value(LogicalType::DOUBLE)); + new_stats.Set(StatsInfo::CANNOT_HAVE_VALID_VALUES); + new_stats.Set(StatsInfo::CAN_HAVE_NULL_VALUES); + return new_stats.ToUnique(); + } + } + + if (bind_data.vertex_index == 0 && extent.HasXY()) { // X + NumericStats::SetMin(new_stats, extent.x_min); + NumericStats::SetMax(new_stats, extent.x_max); + return new_stats.ToUnique(); + } + if (bind_data.vertex_index == 1 && extent.HasXY()) { // Y + NumericStats::SetMin(new_stats, extent.y_min); + NumericStats::SetMax(new_stats, extent.y_max); + return new_stats.ToUnique(); + } + if (bind_data.vertex_index == 2 && extent.HasZ()) { // Z + NumericStats::SetMin(new_stats, extent.z_min); + NumericStats::SetMax(new_stats, extent.z_max); + return new_stats.ToUnique(); + } + if (bind_data.vertex_index == 3 && extent.HasM()) { // M + NumericStats::SetMin(new_stats, extent.m_min); + NumericStats::SetMax(new_stats, extent.m_max); + return new_stats.ToUnique(); + } + + return nullptr; +} + +static auto VertexExtractFunction(DataChunk &input, ExpressionState &state, Vector &result) { + const auto &func_expr = state.expr.Cast(); + const auto &bind_data = func_expr.BindInfo()->Cast(); + + UnaryExecutor::Execute(input.data[0], result, [&](const string_t &geom_str) { + const auto data = const_data_ptr_cast(geom_str.GetData()); + const auto size = geom_str.GetSize(); + const auto meta = Load(data + sizeof(uint8_t)); + + const auto type_id = (meta & 0x0000FFFF) % 1000; + const auto flag_id = (meta & 0x0000FFFF) / 1000; + const auto has_z = ((flag_id & 0x01) != 0); + const auto has_m = ((flag_id & 0x02) != 0); + + if (type_id != 1) { + return optional(); + } + + auto value = std::numeric_limits::quiet_NaN(); + + if (bind_data.vertex_index == 0) { // X + constexpr auto offset = sizeof(uint8_t) + sizeof(uint32_t); + if (size < offset + sizeof(double)) { + return optional(); + } + value = Load(data + offset); + } else if (bind_data.vertex_index == 1) { // Y + constexpr auto offset = sizeof(uint8_t) + sizeof(uint32_t) + sizeof(double); + if (size < offset + sizeof(double)) { + return optional(); + } + value = Load(data + offset); + } else if (bind_data.vertex_index == 2) { // Z + if (!has_z) { + return optional(); + } + constexpr auto offset = sizeof(uint8_t) + sizeof(uint32_t) + 2 * sizeof(double); + if (size < offset + sizeof(double)) { + return optional(); + } + value = Load(data + offset); + } else if (bind_data.vertex_index == 3) { // M + if (!has_m) { + return optional(); + } + const auto offset = sizeof(uint8_t) + sizeof(uint32_t) + (2 + has_z) * sizeof(double); + if (size < offset + sizeof(double)) { + return optional(); + } + value = Load(data + offset); + } + + if (std::isnan(value)) { + return optional(); + } + + return optional(value); + }); +} + +ScalarFunction VertexExtractFun::GetFunction() { + auto fun = ScalarFunction({}, LogicalTypeId::DOUBLE, VertexExtractFunction, VertexExtractBind, VertexExtractStats); + fun.GetSignature() + .AddParameter("geom", LogicalType::GEOMETRY()) + .AddParameter("coordinate", LogicalTypeId::VARCHAR) + .SetReturnType(LogicalType::DOUBLE); + + fun.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + return fun; +} + } // namespace duckdb diff --git a/src/duckdb/src/function/scalar/system/aggregate_export.cpp b/src/duckdb/src/function/scalar/system/aggregate_export.cpp index 6f5e63711..d3f95073a 100644 --- a/src/duckdb/src/function/scalar/system/aggregate_export.cpp +++ b/src/duckdb/src/function/scalar/system/aggregate_export.cpp @@ -2,11 +2,14 @@ #include "duckdb/common/vector/list_vector.hpp" #include "duckdb/common/vector/struct_vector.hpp" #include "duckdb/common/types/list_segment.hpp" +#include "duckdb/common/types/variant_value.hpp" #include "duckdb/function/aggregate_state_layout.hpp" #include "duckdb/function/create_sort_key.hpp" #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" #include "duckdb/common/extension_type_info.hpp" #include "duckdb/common/types/value.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/common/operator/cast_operators.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/function/function_binder.hpp" #include "duckdb/function/scalar/generic_common.hpp" @@ -110,8 +113,8 @@ void TemplateDispatch(PhysicalType type, ARGS &&... args) { } } -static AggregateStateLayout GetLayout(const BoundAggregateFunction &aggr) { - return aggr.GetStateTypeCallback()(aggr); +static AggregateStateLayout GetLayout(const BoundAggregateFunction &aggr, optional_ptr bind_data) { + return aggr.GetStateType(bind_data); } // Load rows from input_vec into the packed binary state buffer. Skips null rows. @@ -283,7 +286,7 @@ struct CombineState : public FunctionLocalState { ArenaAllocator allocator; explicit CombineState(const ExportAggregateBindData &bind_data) - : layout(GetLayout(bind_data.aggr)), + : layout(GetLayout(bind_data.aggr, bind_data.bind_data.get())), state_buffer0(make_unsafe_uniq_array(STANDARD_VECTOR_SIZE * layout.total_state_size)), state_buffer1(make_unsafe_uniq_array(STANDARD_VECTOR_SIZE * layout.total_state_size)), addresses0(LogicalType::POINTER), addresses1(LogicalType::POINTER), allocator(Allocator::DefaultAllocator()) { @@ -306,7 +309,7 @@ struct FinalizeState : public FunctionLocalState { ArenaAllocator allocator; explicit FinalizeState(const ExportAggregateBindData &bind_data) - : layout(GetLayout(bind_data.aggr)), + : layout(GetLayout(bind_data.aggr, bind_data.bind_data.get())), state_buffer(make_unsafe_uniq_array(STANDARD_VECTOR_SIZE * layout.total_state_size)), addresses(LogicalType::POINTER), allocator(Allocator::DefaultAllocator()) { } @@ -336,7 +339,7 @@ void AggregateStateFinalize(DataChunk &input, ExpressionState &state_p, Vector & DeserializeState(layout, input.data[0], count, local_state.state_buffer.get(), local_state.allocator); - AggregateInputData aggr_input_data(bind_data.aggr, bind_data.bind_data.get(), local_state.allocator); + AggregateFinalizeInputData aggr_input_data(bind_data.aggr, bind_data.bind_data.get(), local_state.allocator); bind_data.aggr.GetStateFinalizeCallback()(local_state.addresses, aggr_input_data, result, count, 0); auto validity = input.data[0].Validity(); @@ -400,8 +403,11 @@ void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Vector &r } // looks up the aggregate function with the given name in the catalog and binds it with the given argument types +// constant_parameters holds the values of arguments that must be re-bound with a specific constant +// (e.g. string_agg's separator), keyed by argument index - all other arguments are bound with a NULL value unique_ptr BindExportedAggregate(ClientContext &context, const string &function_name, - const vector &argument_types) { + const vector &argument_types, + const map &constant_parameters) { auto &func = Catalog::GetSystemCatalog(context).GetEntry( context, Identifier::DefaultSchema(), Identifier(function_name)); if (func.type != CatalogType::AGGREGATE_FUNCTION_ENTRY) { @@ -421,20 +427,69 @@ unique_ptr BindExportedAggregate(ClientContext &context // but the aggregate state export needs a rework around how it handles more complex aggregates anyway vector> args; args.reserve(argument_types.size()); - for (auto &arg_type : argument_types) { - args.push_back(make_uniq(Value(arg_type))); + for (idx_t arg_idx = 0; arg_idx < argument_types.size(); arg_idx++) { + auto constant_entry = constant_parameters.find(arg_idx); + if (constant_entry != constant_parameters.end()) { + args.push_back(make_uniq(constant_entry->second)); + } else { + args.push_back(make_uniq(Value(argument_types[arg_idx]))); + } } auto [bound_aggr, bind_info] = function_binder.ResolveFunction(aggr, args); - if (bound_aggr.GetArguments() != argument_types) { - throw InternalException("Type mismatch for exported aggregate %s", function_name); + // the bind callback can erase constant arguments (e.g. string_agg's separator) - in that case the original + // argument list holds the pre-erase arguments that the exported signature refers to + const auto &bound_args = + bound_aggr.GetOriginalArguments().empty() ? bound_aggr.GetArguments() : bound_aggr.GetOriginalArguments(); + bool signature_matches = bound_args.size() == argument_types.size(); + for (idx_t arg_idx = 0; signature_matches && arg_idx < bound_args.size(); arg_idx++) { + // an ANY argument in the function signature (e.g. string_agg's data argument) matches any requested type + if (bound_args[arg_idx].id() != LogicalTypeId::ANY && bound_args[arg_idx] != argument_types[arg_idx]) { + signature_matches = false; + } + } + if (!signature_matches) { + throw InternalException("Type mismatch for exported aggregate %s: bound=[%s] requested=[%s]", function_name, + StringUtil::ToString(bound_args, ", "), StringUtil::ToString(argument_types, ", ")); } return make_uniq(bound_aggr, std::move(bind_info), bound_aggr.GetStateSizeCallback()(bound_aggr)); } +// parses the "parameters" property of an AGGREGATE_STATE type +// each parameter is either a plain type, or a (type, value) pair for parameters that were bound to a constant +// (e.g. string_agg's separator) - the latter are returned in constant_parameters, keyed by argument index +void ParseStateParameters(const Value ¶meters, vector &argument_types, + map &constant_parameters) { + for (auto &val : ListValue::GetChildren(parameters)) { + const idx_t arg_idx = argument_types.size(); + if (!val.IsNull() && val.type().id() == LogicalTypeId::TYPE) { + // plain type + argument_types.push_back(TypeValue::GetType(val)); + continue; + } + if (!val.IsNull() && val.type().id() == LogicalTypeId::STRUCT) { + // (type, value) pair + auto &children = StructValue::GetChildren(val); + if (children.size() != 2 || children[0].IsNull() || children[0].type().id() != LogicalTypeId::TYPE || + children[1].type().id() != LogicalTypeId::VARIANT) { + throw InternalException("Aggregate state parameter entry should be a (type, value) pair"); + } + argument_types.push_back(TypeValue::GetType(children[0])); + if (!children[1].IsNull()) { + // the parameter is bound to a constant - decode it and cast it back to the declared argument type + constant_parameters.emplace(arg_idx, + VariantValue::GetValue(children[1]).DefaultCastAs(argument_types.back())); + } + continue; + } + throw InternalException( + "Aggregate state object should have a property called parameters that is a list of types"); + } +} + unique_ptr BindAggregateStateInternal(ClientContext &context, BoundSimpleFunction &function, vector> &arguments) { auto &arg_return_type = arguments[0]->GetReturnType(); @@ -458,15 +513,10 @@ unique_ptr BindAggregateStateInternal(ClientContext &co "Aggregate state object should have a property called parameters that is a list of types"); } vector argument_types; - for (auto &val : ListValue::GetChildren(entry->second)) { - if (val.IsNull() || val.type().id() != LogicalTypeId::TYPE) { - throw InternalException( - "Aggregate state object should have a property called parameters that is a list of types"); - } - argument_types.push_back(TypeValue::GetType(val)); - } + map constant_parameters; + ParseStateParameters(entry->second, argument_types, constant_parameters); - return BindExportedAggregate(context, function_name, argument_types); + return BindExportedAggregate(context, function_name, argument_types, constant_parameters); } unique_ptr BindAggregateState(BindScalarFunctionInput &input) { @@ -477,8 +527,28 @@ unique_ptr BindAggregateState(BindScalarFunctionInput &input) { // combine - both arguments must be aggregate states of the same function with the same signature if (arguments.size() == 2 && arguments[0]->GetReturnType() != arguments[1]->GetReturnType()) { + auto &left_type = arguments[0]->GetReturnType(); + auto &right_type = arguments[1]->GetReturnType(); + if (left_type.IsAggregateState() && right_type.IsAggregateState()) { + // both are aggregate states but they are not equal - if the function and signature match, the states + // were bound with different constant parameters (e.g. string_agg states with different separators) + auto &left_props = left_type.GetExtensionInfo()->properties; + auto &right_props = right_type.GetExtensionInfo()->properties; + auto left_name = left_props.find("function_name"); + auto right_name = right_props.find("function_name"); + if (left_name != left_props.end() && right_name != right_props.end() && + left_name->second == right_name->second) { + auto left_params = left_props.find("parameters"); + auto right_params = right_props.find("parameters"); + throw BinderException( + "Cannot COMBINE aggregate states of \"%s\" that were created with different parameters: %s <> %s", + StringValue::Get(left_name->second), + left_params == left_props.end() ? "?" : left_params->second.ToString(), + right_params == right_props.end() ? "?" : right_params->second.ToString()); + } + } throw BinderException("Cannot COMBINE aggregate states from different functions, %s <> %s", - arguments[0]->GetReturnType().ToString(), arguments[1]->GetReturnType().ToString()); + left_type.ToString(), right_type.ToString()); } if (bound_function.GetName() == "finalize") { @@ -491,7 +561,7 @@ unique_ptr BindAggregateState(BindScalarFunctionInput &input) { return std::move(bind_data); } -void ExportAggregateFinalize(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count, +void ExportAggregateFinalize(Vector &state, AggregateFinalizeInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { D_ASSERT(offset == 0); const data_ptr_t *addresses_ptrs; @@ -504,7 +574,7 @@ void ExportAggregateFinalize(Vector &state, AggregateInputData &aggr_input_data, addresses_ptrs = FlatVector::GetData(state); } - auto layout = GetLayout(aggr_input_data.function); + auto layout = GetLayout(aggr_input_data.function, aggr_input_data.bind_data); result.Flatten(); SerializeState(layout, result, count, addresses_ptrs); @@ -543,7 +613,7 @@ void CombineAggrUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx auto &bind_data = aggr_input_data.bind_data->Cast(); auto &underlying_aggr = bind_data.aggr; - auto layout = GetLayout(underlying_aggr); + auto layout = GetLayout(underlying_aggr, bind_data.bind_data.get()); auto aligned_size = layout.total_state_size; unsafe_unique_array temp_state_buf = make_unsafe_uniq_array(count * aligned_size); @@ -573,7 +643,7 @@ void CombineAggrUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx underlying_aggr.GetStateCombineCallback()(source_vec, target_vec, combine_input, count); } -void CombineAggrFinalize(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count, +void CombineAggrFinalize(Vector &state, AggregateFinalizeInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { D_ASSERT(offset == 0); auto &bind_data = aggr_input_data.bind_data->Cast(); @@ -588,7 +658,7 @@ void CombineAggrFinalize(Vector &state, AggregateInputData &aggr_input_data, Vec addresses_ptrs = FlatVector::GetData(state); } - auto layout = GetLayout(underlying_aggr); + auto layout = GetLayout(underlying_aggr, bind_data.bind_data.get()); result.Flatten(); SerializeState(layout, result, count, addresses_ptrs); @@ -597,33 +667,98 @@ void CombineAggrFinalize(Vector &state, AggregateInputData &aggr_input_data, Vec // constructs the AGGREGATE_STATE type for the given bound aggregate function // the state layout (a struct) is aliased to AGGREGATE_STATE, with the function name and signature stored in the // extension type info so that the aggregate can be re-bound later (e.g. by FINALIZE/COMBINE) -LogicalType CreateAggregateStateType(const BoundAggregateFunction &bound_function) { +LogicalType CreateAggregateStateType(const BoundAggregateFunction &bound_function, + optional_ptr bind_data) { + auto layout = bound_function.GetStateType(bind_data); // deep copy the type before modifying it - SetAlias/SetExtensionInfo modify the (shared) extra type info in // place, and the state layout type can share its type info with e.g. the aggregate's input expressions - LogicalType state_layout = bound_function.GetStateType().type.DeepCopy(); + LogicalType state_layout = layout.type.DeepCopy(); state_layout.SetAlias("AGGREGATE_STATE"); auto ext_info = make_uniq(); ext_info->properties.emplace("function_name", bound_function.GetName()); + auto &original_arguments = bound_function.GetOriginalArguments().empty() ? bound_function.GetArguments() + : bound_function.GetOriginalArguments(); vector arguments; - for (auto &arg : bound_function.GetOriginalArguments().empty() ? bound_function.GetArguments() - : bound_function.GetOriginalArguments()) { - arguments.push_back(Value::TYPE(arg)); + if (layout.constant_parameters.empty()) { + // all parameters are plain types - store the parameters as a list of types + for (auto &arg : original_arguments) { + arguments.push_back(Value::TYPE(arg)); + } + ext_info->properties.emplace("parameters", Value::LIST(LogicalType::TYPE(), std::move(arguments))); + } else { + // some parameters were bound to a constant (e.g. string_agg's separator) - store the parameters as a list of + // (type, value) pairs, where the value holds the constant the parameter must be re-bound with + for (idx_t arg_idx = 0; arg_idx < original_arguments.size(); arg_idx++) { + child_list_t children; + children.emplace_back("type", Value::TYPE(original_arguments[arg_idx])); + auto constant_entry = layout.constant_parameters.find(arg_idx); + if (constant_entry == layout.constant_parameters.end()) { + children.emplace_back("value", Value(LogicalType::VARIANT())); + } else { + children.emplace_back("value", constant_entry->second.DefaultCastAs(LogicalType::VARIANT())); + } + arguments.push_back(Value::STRUCT(std::move(children))); + } + auto entry_type = LogicalType::STRUCT({{"type", LogicalType::TYPE()}, {"value", LogicalType::VARIANT()}}); + ext_info->properties.emplace("parameters", Value::LIST(entry_type, std::move(arguments))); } - ext_info->properties.emplace("parameters", Value::LIST(LogicalType::TYPE(), std::move(arguments))); state_layout.SetExtensionInfo(std::move(ext_info)); return state_layout; } +// parses a single entry of the to_aggregate_state signature list - either a TYPE value (e.g. make_type('VARCHAR')) +// or a string naming a type (e.g. 'VARCHAR') +LogicalType ParseSignatureType(ClientContext &context, const Value &arg) { + if (arg.IsNull()) { + throw BinderException("to_aggregate_state: the signature cannot contain NULL values"); + } + if (arg.type().id() == LogicalTypeId::TYPE) { + return TypeValue::GetType(arg); + } + if (arg.type().id() == LogicalTypeId::VARCHAR) { + return TransformStringToLogicalType(StringValue::Get(arg), context); + } + throw BinderException("to_aggregate_state: the signature must be a list of types"); +} + +// parses the optional fourth argument of to_aggregate_state: the constant values for arguments that must be bound to +// a specific constant rather than a NULL value of the argument type (e.g. string_agg's separator). It is supplied as +// a LIST with one entry per argument (i.e. the same length as the signature): NULL for arguments that are not bound to +// a constant, and the constant value for arguments that are, e.g. [NULL, '|'] to bind argument 1 to the constant '|' +void ParseConstantParameters(const Value &constants, idx_t argument_count, map &constant_parameters) { + if (constants.IsNull()) { + return; + } + if (constants.type().id() != LogicalTypeId::LIST) { + throw BinderException("to_aggregate_state: the constant parameters must be a list with one entry per argument " + "(use NULL for arguments that are not bound to a constant), e.g. [NULL, '|']"); + } + auto &children = ListValue::GetChildren(constants); + if (children.size() != argument_count) { + throw BinderException("to_aggregate_state: the constant parameters list has %llu entries but the aggregate has " + "%llu arguments - it must have exactly one entry per argument (use NULL for arguments " + "that are not bound to a constant)", + children.size(), argument_count); + } + for (idx_t i = 0; i < children.size(); i++) { + if (children[i].IsNull()) { + continue; + } + constant_parameters[i] = children[i]; + } +} + unique_ptr ToAggregateStateBind(BindScalarFunctionInput &input) { auto &bound_function = input.GetBoundFunction(); auto &arguments = input.GetArguments(); auto &context = input.GetClientContext(); - for (idx_t i = 1; i < 3; i++) { + for (idx_t i = 1; i < arguments.size(); i++) { if (arguments[i]->HasParameter()) { throw ParameterNotResolvedException(); } if (!arguments[i]->IsFoldable()) { - throw BinderException("to_aggregate_state: the aggregate name and signature must be constant"); + throw BinderException("to_aggregate_state: the aggregate name, signature and constant parameters must be " + "constant"); } } auto function_name_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); @@ -636,30 +771,34 @@ unique_ptr ToAggregateStateBind(BindScalarFunctionInput &input) { if (signature_val.IsNull()) { throw BinderException("to_aggregate_state: the signature must be a list of types"); } + // the signature lists all of the argument types in order vector argument_types; for (auto &arg : ListValue::GetChildren(signature_val)) { - if (arg.IsNull()) { - throw BinderException("to_aggregate_state: the signature cannot contain NULL values"); - } - if (arg.type().id() == LogicalTypeId::TYPE) { - argument_types.push_back(TypeValue::GetType(arg)); - } else if (arg.type().id() == LogicalTypeId::VARCHAR) { - argument_types.push_back(TransformStringToLogicalType(StringValue::Get(arg), context)); - } else { - throw BinderException("to_aggregate_state: the signature must be a list of types"); - } + argument_types.push_back(ParseSignatureType(context, arg)); + } + + // constant_parameters holds the values of arguments that must be re-bound with a specific constant rather than a + // NULL value of the argument type (e.g. string_agg's separator), keyed by the argument's index + map constant_parameters; + if (arguments.size() > 3) { + auto constants_val = ExpressionExecutor::EvaluateScalar(context, *arguments[3]); + ParseConstantParameters(constants_val, argument_types.size(), constant_parameters); + } + for (auto &entry : constant_parameters) { + // cast each constant to the argument type declared in the signature + entry.second = entry.second.DefaultCastAs(argument_types[entry.first]); } - auto bind_data = BindExportedAggregate(context, function_name, argument_types); + auto bind_data = BindExportedAggregate(context, function_name, argument_types, constant_parameters); auto &aggr = bind_data->aggr; if (!aggr.HasGetStateTypeCallback()) { throw BinderException( "Aggregate function \"%s\" does not have a state type callback defined - cannot convert to its state", function_name); } - auto state_layout = aggr.GetStateType().type; + auto state_layout = aggr.GetStateType(bind_data->bind_data.get()).type; bound_function.GetArguments()[0] = state_layout; - bound_function.SetReturnType(CreateAggregateStateType(aggr)); + bound_function.SetReturnType(CreateAggregateStateType(aggr, bind_data->bind_data.get())); return std::move(bind_data); } @@ -701,7 +840,7 @@ ExportAggregateFunction::Bind(unique_ptr child_aggrega "Aggregate function \"%s\" does not have a state type callback defined - cannot export state", bound_function.GetName()); } - SetStateExport(*child_aggregate, CreateAggregateStateType(bound_function)); + SetStateExport(*child_aggregate, CreateAggregateStateType(bound_function, child_aggregate->BindInfo().get())); return child_aggregate; } @@ -734,12 +873,21 @@ ScalarFunction CombineFun::GetFunction() { return function; } -ScalarFunction ToAggregateStateFun::GetFunction() { - auto function = ScalarFunction("to_aggregate_state", - {LogicalTypeId::ANY, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::ANY)}, - LogicalTypeId::ANY, ToAggregateStateFunction, ToAggregateStateBind); - function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); - return function; +ScalarFunctionSet ToAggregateStateFun::GetFunctions() { + ScalarFunctionSet set("to_aggregate_state"); + vector arguments {LogicalTypeId::ANY, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::ANY)}; + for (idx_t constant_params = 0; constant_params < 2; constant_params++) { + if (constant_params) { + // optional fourth argument: constant parameter values as a list with one entry per argument (e.g. + // [NULL, ',']) + arguments.emplace_back(LogicalTypeId::ANY); + } + ScalarFunction function("to_aggregate_state", arguments, LogicalTypeId::ANY, ToAggregateStateFunction, + ToAggregateStateBind); + function.SetNullHandling(FunctionNullHandling::SPECIAL_HANDLING); + set.AddFunction(std::move(function)); + } + return set; } AggregateFunction CombineAggrFun::GetFunction() { diff --git a/src/duckdb/src/function/table/arrow_conversion.cpp b/src/duckdb/src/function/table/arrow_conversion.cpp index e93003233..351f6e2e8 100644 --- a/src/duckdb/src/function/table/arrow_conversion.cpp +++ b/src/duckdb/src/function/table/arrow_conversion.cpp @@ -1459,7 +1459,11 @@ void ArrowToDuckDBConversion::ColumnArrowToDuckDBDictionary(Vector &vector, Arro default: throw NotImplementedException("ArrowArrayPhysicalType not recognized"); }; - // the dictionary buffer holds dict_length entries plus one trailing NULL sentinel slot + // the dictionary buffer holds dict_length entries plus one trailing NULL sentinel slot. + // the inner ColumnArrowToDuckDB call may have replaced the buffer via FlatVector::SetData + // (e.g. DirectConversion's zero-copy path), shrinking capacity from dict_length+1 to + // dict_length. Re-extend before sizing so the sentinel slot stays in bounds. + base_vector->Reserve(dict_length + 1); FlatVector::SetSize(*base_vector, count_t(dict_length + 1)); array_state.AddDictionary(std::move(base_vector), array.dictionary); } diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index 808d2e861..594b1c55b 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev8617" +#define DUCKDB_PATCH_VERSION "0-dev8694" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 6 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.6.0-dev8617" +#define DUCKDB_VERSION "v1.6.0-dev8694" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "ccbdf9f7c7" +#define DUCKDB_SOURCE_ID "72e5a0f30c" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" diff --git a/src/duckdb/src/function/window/window_aggregate_states.cpp b/src/duckdb/src/function/window/window_aggregate_states.cpp index b18523b30..5d928cb79 100644 --- a/src/duckdb/src/function/window/window_aggregate_states.cpp +++ b/src/duckdb/src/function/window/window_aggregate_states.cpp @@ -31,7 +31,7 @@ void WindowAggregateStates::Combine(WindowAggregateStates &target) { } void WindowAggregateStates::Finalize(Vector &result) { - AggregateInputData aggr_input_data(aggr, allocator); + AggregateFinalizeInputData aggr_input_data(aggr, allocator); aggr.function.GetStateFinalizeCallback()(*statef, aggr_input_data, result, GetCount(), 0); } diff --git a/src/duckdb/src/function/window/window_naive_aggregator.cpp b/src/duckdb/src/function/window/window_naive_aggregator.cpp index dd0a47c9b..217fd6d5e 100644 --- a/src/duckdb/src/function/window/window_naive_aggregator.cpp +++ b/src/duckdb/src/function/window/window_naive_aggregator.cpp @@ -351,7 +351,7 @@ void WindowNaiveLocalState::Evaluate(ExecutionContext &context, const WindowAggr FlushStates(gsink); // Finalise the result aggregates and write to the result - AggregateInputData aggr_input_data(aggr, allocator); + AggregateFinalizeInputData aggr_input_data(aggr, allocator); aggr.function.GetStateFinalizeCallback()(statef, aggr_input_data, result, count, 0); // Destruct the result aggregates diff --git a/src/duckdb/src/function/window/window_segment_tree.cpp b/src/duckdb/src/function/window/window_segment_tree.cpp index af07fcfaa..a44087761 100644 --- a/src/duckdb/src/function/window/window_segment_tree.cpp +++ b/src/duckdb/src/function/window/window_segment_tree.cpp @@ -281,7 +281,7 @@ void WindowSegmentTreePart::WindowSegmentValue(const WindowSegmentTreeGlobalStat } void WindowSegmentTreePart::Finalize(Vector &result, idx_t count) { // Finalise the result aggregates and write to result if write_result is set - AggregateInputData aggr_input_data(aggr, allocator); + AggregateFinalizeInputData aggr_input_data(aggr, allocator); aggr.function.GetStateFinalizeCallback()(statef, aggr_input_data, result, count, 0); // Destruct the result aggregates diff --git a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp index cbad23af6..10dc1a26e 100644 --- a/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp +++ b/src/duckdb/src/include/duckdb/catalog/catalog_entry/table_catalog_entry.hpp @@ -43,6 +43,7 @@ struct DropNotNullInfo; struct SetColumnCommentInfo; struct CreateTableInfo; struct BoundCreateTableInfo; +struct ColumnBinding; class TableFunction; struct FunctionData; @@ -138,6 +139,11 @@ class TableCatalogEntry : public StandardEntry { //! Returns true, if the table has a primary key, else false. bool HasPrimaryKey() const; + virtual LogicalType GetExpectedTypeForInsert(const ColumnDefinition &column) const; + virtual unique_ptr GetDefaultExpressionForColumn(ClientContext &context, const LogicalType &input_type, + const LogicalType &result_type, ColumnBinding binding, + const Expression &constant_value) const; + //! Returns the virtual columns for this table virtual virtual_column_map_t GetVirtualColumns() const; diff --git a/src/duckdb/src/include/duckdb/common/index_vector.hpp b/src/duckdb/src/include/duckdb/common/index_vector.hpp index abce65f03..07d2b34a0 100644 --- a/src/duckdb/src/include/duckdb/common/index_vector.hpp +++ b/src/duckdb/src/include/duckdb/common/index_vector.hpp @@ -39,6 +39,14 @@ class IndexVector { return internal_vector.empty(); } + bool operator==(const IndexVector &other) const { + return internal_vector == other.internal_vector; + } + + bool operator!=(const IndexVector &other) const { + return !(*this == other); + } + void reserve(idx_t size) { // NOLINT: match stl API internal_vector.reserve(size); } diff --git a/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp b/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp index 943c5cadc..6906b0095 100644 --- a/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp +++ b/src/duckdb/src/include/duckdb/common/row_operations/row_operations.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/enums/expression_type.hpp" #include "duckdb/common/types.hpp" #include "duckdb/common/types/vector.hpp" +#include "duckdb/function/function.hpp" namespace duckdb { @@ -32,6 +33,8 @@ struct RowOperationsState { ArenaAllocator &allocator; unique_ptr addresses; // Re-usable vector for row_aggregate.cpp + //! Per-aggregate slots in which the aggregates can cache state across finalize calls + vector> local_states; }; // RowOperations contains a set of operations that operate on data using a TupleDataLayout diff --git a/src/duckdb/src/include/duckdb/common/types/variant_value.hpp b/src/duckdb/src/include/duckdb/common/types/variant_value.hpp index 8d26fb02c..e9690012c 100644 --- a/src/duckdb/src/include/duckdb/common/types/variant_value.hpp +++ b/src/duckdb/src/include/duckdb/common/types/variant_value.hpp @@ -49,6 +49,9 @@ struct VariantValue { return VariantValue(Value(LogicalType::SQLNULL)); } + //! Convert a (non-null) VARIANT-typed Value back to a plain Value + static Value GetValue(const Value &variant_val); + public: void AddChild(const string &key, VariantValue &&val); void AddItem(VariantValue &&val); diff --git a/src/duckdb/src/include/duckdb/common/vector_operations/aggregate_executor.hpp b/src/duckdb/src/include/duckdb/common/vector_operations/aggregate_executor.hpp index cd036822c..ef96698aa 100644 --- a/src/duckdb/src/include/duckdb/common/vector_operations/aggregate_executor.hpp +++ b/src/duckdb/src/include/duckdb/common/vector_operations/aggregate_executor.hpp @@ -787,7 +787,7 @@ class AggregateExecutor { } template - static void Finalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, + static void Finalize(Vector &states, AggregateFinalizeInputData &finalize_input_data, Vector &result, idx_t count, idx_t offset) { if (states.GetVectorType() == VectorType::CONSTANT_VECTOR) { result.SetVectorType(VectorType::CONSTANT_VECTOR); @@ -795,7 +795,7 @@ class AggregateExecutor { auto sdata = ConstantVector::GetData(states); auto rdata = ConstantVector::GetData(result); - AggregateFinalizeData finalize_data(result, aggr_input_data, count); + AggregateFinalizeData finalize_data(result, finalize_input_data, count); OP::template Finalize(**sdata, *rdata, finalize_data); } else { D_ASSERT(states.GetVectorType() == VectorType::FLAT_VECTOR); @@ -803,7 +803,7 @@ class AggregateExecutor { auto sdata = FlatVector::GetData(states); auto rdata = FlatVector::GetDataMutable(result); - AggregateFinalizeData finalize_data(result, aggr_input_data, count); + AggregateFinalizeData finalize_data(result, finalize_input_data, count); for (idx_t i = 0; i < count; i++) { finalize_data.result_idx = i + offset; OP::template Finalize(*sdata[i], rdata[finalize_data.result_idx], @@ -813,21 +813,21 @@ class AggregateExecutor { } template - static void VoidFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, - idx_t offset) { + static void VoidFinalize(Vector &states, AggregateFinalizeInputData &finalize_input_data, Vector &result, + idx_t count, idx_t offset) { if (states.GetVectorType() == VectorType::CONSTANT_VECTOR) { result.SetVectorType(VectorType::CONSTANT_VECTOR); FlatVector::SetSize(result, count); auto sdata = ConstantVector::GetData(states); - AggregateFinalizeData finalize_data(result, aggr_input_data, count); + AggregateFinalizeData finalize_data(result, finalize_input_data, count); OP::template Finalize(**sdata, finalize_data); } else { D_ASSERT(states.GetVectorType() == VectorType::FLAT_VECTOR); result.SetVectorType(VectorType::FLAT_VECTOR); auto sdata = FlatVector::GetData(states); - AggregateFinalizeData finalize_data(result, aggr_input_data); + AggregateFinalizeData finalize_data(result, finalize_input_data); for (idx_t i = 0; i < count; i++) { finalize_data.result_idx = i + offset; OP::template Finalize(*sdata[i], finalize_data); diff --git a/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp b/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp index d82e90c91..67d1d5b98 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate/minmax_n_helpers.hpp @@ -442,7 +442,7 @@ struct MinMaxNOperation { } template - static void Finalize(Vector &state_vector, AggregateInputData &input_data, Vector &result, idx_t count, + static void Finalize(Vector &state_vector, AggregateFinalizeInputData &input_data, Vector &result, idx_t count, idx_t offset) { // We only expect bind data from arg_max, otherwise nulls last is the default const bool nulls_last = diff --git a/src/duckdb/src/include/duckdb/function/aggregate_function.hpp b/src/duckdb/src/include/duckdb/function/aggregate_function.hpp index 41275fc96..db1eea887 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate_function.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate_function.hpp @@ -87,8 +87,13 @@ typedef void (*aggregate_update_t)(Vector inputs[], AggregateInputData &aggr_inp //! The type used for combining hashed aggregate states typedef void (*aggregate_combine_t)(Vector &state, Vector &combined, AggregateInputData &aggr_input_data, idx_t count); //! The type used for finalizing hashed aggregate function payloads -typedef void (*aggregate_finalize_t)(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count, - idx_t offset); +typedef void (*aggregate_finalize_t)(Vector &state, AggregateFinalizeInputData &finalize_input_data, Vector &result, + idx_t count, idx_t offset); +//! Initializes the local state used by the finalize of the aggregate (optional). +//! The local state can be used to cache expensive intermediates between finalized groups - callers can keep the +//! local state alive to re-use it across multiple finalize calls (see AggregateFinalizeInputData). +typedef unique_ptr (*aggregate_init_local_state_finalize_t)(const BoundAggregateFunction &function, + optional_ptr bind_data); //! The type used for propagating statistics in aggregate functions (optional) typedef unique_ptr (*aggregate_statistics_t)(ClientContext &context, BoundAggregateExpression &expr, AggregateStatisticsInput &input); @@ -126,7 +131,7 @@ typedef void (*aggregate_serialize_t)(Serializer &serializer, const optional_ptr typedef unique_ptr (*aggregate_deserialize_t)(Deserializer &deserializer, BoundAggregateFunction &function); -typedef AggregateStateLayout (*aggregate_get_state_type_t)(const BoundAggregateFunction &function); +typedef AggregateStateLayout (*aggregate_get_state_type_t)(AggregateLayoutInput &input); struct AggregateFunctionInfo { DUCKDB_API virtual ~AggregateFunctionInfo(); @@ -185,6 +190,10 @@ class AggregateFunctionCallbacks { aggregate_finalize_t GetStateFinalizeCallback() const { return finalize; } bool HasStateFinalizeCallback() const { return finalize != nullptr; } + void SetInitLocalStateFinalizeCallback(aggregate_init_local_state_finalize_t callback) { init_local_state_finalize = callback; } + aggregate_init_local_state_finalize_t GetInitLocalStateFinalizeCallback() const { return init_local_state_finalize; } + bool HasInitLocalStateFinalizeCallback() const { return init_local_state_finalize != nullptr; } + bool HasWindowCallback() const { return window != nullptr; } aggregate_window_t GetWindowCallback() const { return window; } void SetWindowCallback(aggregate_window_t callback) { window = callback; } @@ -220,6 +229,8 @@ class AggregateFunctionCallbacks { aggregate_combine_t combine = nullptr; //! The hashed aggregate finalization function (may be null, if window is set) aggregate_finalize_t finalize = nullptr; + //! Initializes the local state used by the finalize (may be null) + aggregate_init_local_state_finalize_t init_local_state_finalize = nullptr; //! The clustered aggregate update function (may be null) aggregate_cluster_update_t cluster_update = nullptr; //! The windowed aggregate custom function (may be null) @@ -336,6 +347,10 @@ class BaseAggregateFunction { auto GetStateFinalizeCallback() const -> aggregate_finalize_t { return callbacks.finalize; } auto SetStateFinalizeCallback(aggregate_finalize_t callback) -> void { callbacks.finalize = callback; } + auto HasInitLocalStateFinalizeCallback() const -> bool { return callbacks.init_local_state_finalize != nullptr; } + auto GetInitLocalStateFinalizeCallback() const -> aggregate_init_local_state_finalize_t { return callbacks.init_local_state_finalize; } + auto SetInitLocalStateFinalizeCallback(aggregate_init_local_state_finalize_t callback) -> void { callbacks.init_local_state_finalize = callback; } + auto HasWindowCallback() const -> bool { return callbacks.window != nullptr; } auto GetWindowCallback() const -> aggregate_window_t { return callbacks.window; } auto SetWindowCallback(aggregate_window_t callback) -> void { callbacks.window = callback; } @@ -538,13 +553,20 @@ class AggregateFunction : public BaseAggregateFunction, public SimpleFunction { AggregateFunction::StateCombine, AggregateFunction::StateFinalize, null_handling, UnaryClusterUpdateCallback()); + // automatically wire up the destructor if the operation defines a Destroy method + if constexpr (OperationHasDestroy::value) { + result.callbacks.destructor = AggregateFunction::StateDestroy; + } WireStructStateType(result); return result; } + //! Deprecated: use UnaryAggregate instead - the destructor is now automatically wired up when the operation + //! defines a Destroy method template - static AggregateFunction UnaryAggregateDestructor(LogicalType input_type, LogicalType return_type) { + [[deprecated("Use UnaryAggregate instead - the destructor is now wired up automatically")]] static AggregateFunction + UnaryAggregateDestructor(LogicalType input_type, LogicalType return_type) { auto aggregate = UnaryAggregate(input_type, return_type); aggregate.callbacks.destructor = AggregateFunction::StateDestroy; return aggregate; @@ -589,6 +611,15 @@ class AggregateFunction : public BaseAggregateFunction, public SimpleFunction { struct OperationHasInitialize()))>> : std::true_type {}; + //! Detects whether "OP" provides a "Destroy(STATE &, AggregateInputData &)" method + template + struct OperationHasDestroy : std::false_type {}; + template + struct OperationHasDestroy( + std::declval(), std::declval()))>> : std::true_type { + }; + template static void StateInitialize(const BoundAggregateFunction &, data_ptr_t state) { // FIXME: we should remove the "destructor_type" option in the future @@ -668,15 +699,15 @@ class AggregateFunction : public BaseAggregateFunction, public SimpleFunction { } template - static void StateFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, - idx_t offset) { - AggregateExecutor::Finalize(states, aggr_input_data, result, count, offset); + static void StateFinalize(Vector &states, AggregateFinalizeInputData &finalize_input_data, Vector &result, + idx_t count, idx_t offset) { + AggregateExecutor::Finalize(states, finalize_input_data, result, count, offset); } template - static void StateVoidFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, - idx_t offset) { - AggregateExecutor::VoidFinalize(states, aggr_input_data, result, count, offset); + static void StateVoidFinalize(Vector &states, AggregateFinalizeInputData &finalize_input_data, Vector &result, + idx_t count, idx_t offset) { + AggregateExecutor::VoidFinalize(states, finalize_input_data, result, count, offset); } template @@ -694,9 +725,10 @@ class BoundAggregateFunction : public BaseAggregateFunction, public BoundSimpleF DUCKDB_API bool operator==(const BoundAggregateFunction &rhs) const; DUCKDB_API bool operator!=(const BoundAggregateFunction &rhs) const; - AggregateStateLayout GetStateType() const { + AggregateStateLayout GetStateType(optional_ptr bind_data) const { D_ASSERT(callbacks.get_state_type); - return callbacks.get_state_type(*this); + AggregateLayoutInput input(*this, bind_data); + return callbacks.get_state_type(input); } }; @@ -705,7 +737,8 @@ template inline void AggregateFunction::WireStructStateType(AggregateFunction &result) { if constexpr (HasStructStateType::value) { using ST = typename STATE::STATE_TYPE; - result.SetStructStateExport([](const BoundAggregateFunction &bound) { + result.SetStructStateExport([](AggregateLayoutInput &input) { + auto &bound = input.function; AggregateStateLayout layout; if (bound.GetReturnType().IsAggregateState()) { // the function has been modified for state export (see ExportAggregateFunction::SetStateExport) - @@ -720,7 +753,7 @@ inline void AggregateFunction::WireStructStateType(AggregateFunction &result) { return layout; }); } else if constexpr (HasPrimitiveLogicalType::value) { - result.SetStructStateExport([](const BoundAggregateFunction &) { + result.SetStructStateExport([](AggregateLayoutInput &) { return AggregateStateLayout(PrimitiveToLogicalType(), AlignValue(sizeof(STATE))); }); } diff --git a/src/duckdb/src/include/duckdb/function/aggregate_state.hpp b/src/duckdb/src/include/duckdb/function/aggregate_state.hpp index 0b0a9f6d8..3f8d08d63 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate_state.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate_state.hpp @@ -50,6 +50,42 @@ struct AggregateInputData { optional_ptr clustered; }; +//! Input to the get_state_type callback - bundles the bound aggregate function with its bind data so that the +//! callback can resolve the exported state layout (including any constant parameters stored in the bind data). +struct AggregateLayoutInput { + AggregateLayoutInput(const BoundAggregateFunction &function_p, optional_ptr bind_data_p) + : function(function_p), bind_data(bind_data_p) { + } + + const BoundAggregateFunction &function; + optional_ptr bind_data; +}; + +//! The input data provided to the finalize callback of an aggregate function. +//! If the function defines an "init_local_state_finalize" callback, the local state is initialized on construction. +//! Callers can instead pass in an externally-owned local state - this way the local state can be kept alive and +//! re-used across multiple finalize calls (e.g. for the duration of a hash table scan). +struct AggregateFinalizeInputData : public AggregateInputData { + DUCKDB_API AggregateFinalizeInputData(const BoundAggregateFunction &function_p, + optional_ptr bind_data_p, ArenaAllocator &allocator_p, + optional_ptr local_state_p = nullptr); + DUCKDB_API AggregateFinalizeInputData(const BoundAggregateExpression &expr, ArenaAllocator &allocator_p, + optional_ptr local_state_p = nullptr); + DUCKDB_API AggregateFinalizeInputData(const AggregateObject &aggr, ArenaAllocator &allocator_p, + optional_ptr local_state_p = nullptr); + + //! The local state of the finalize (set if the function defines an "init_local_state_finalize" callback) + optional_ptr local_state; + +private: + //! Initializes the local state when the caller did not pass in an external one + void InitializeLocalState(); + +private: + //! The local state owned by this input data - used when the caller does not pass in an external local state + unique_ptr owned_state; +}; + struct AggregateUnaryInput { AggregateUnaryInput(AggregateInputData &input_p, const ValidityMask &input_mask_p) : input(input_p), input_mask(input_mask_p), input_idx(0) { @@ -77,15 +113,14 @@ struct AggregateBinaryInput { }; struct AggregateFinalizeData { - AggregateFinalizeData(Vector &result_p, AggregateInputData &input_p, idx_t result_count_p = 1) + AggregateFinalizeData(Vector &result_p, AggregateFinalizeInputData &input_p, idx_t result_count_p = 1) : result(result_p), input(input_p), result_idx(0), result_count(result_count_p) { } Vector &result; - AggregateInputData &input; + AggregateFinalizeInputData &input; idx_t result_idx; idx_t result_count; - unique_ptr local_state; inline void ReturnNull() { switch (result.GetVectorType()) { diff --git a/src/duckdb/src/include/duckdb/function/aggregate_state_layout.hpp b/src/duckdb/src/include/duckdb/function/aggregate_state_layout.hpp index 134357b67..840a2db34 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate_state_layout.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate_state_layout.hpp @@ -9,8 +9,10 @@ #include "duckdb/common/enums/order_type.hpp" #include "duckdb/common/helper.hpp" +#include "duckdb/common/map.hpp" #include "duckdb/common/type_util.hpp" #include "duckdb/common/types/list_segment.hpp" +#include "duckdb/common/types/value.hpp" namespace duckdb { @@ -481,6 +483,8 @@ struct AggregateStateLayout { LogicalType type; AggregateStateField field; idx_t total_state_size = 0; + //! Constant values for arguments that must be re-bound with a specific constant rather than only the type + unordered_map constant_parameters; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/scalar/geometry_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/geometry_functions.hpp index 0e45677a3..e96916f10 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/geometry_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/geometry_functions.hpp @@ -88,7 +88,17 @@ struct StSetcrsFun { static constexpr const char *Parameters = "geom,crs"; static constexpr const char *Description = "Sets the Coordinate Reference System (CRS) identifier of the geometry"; static constexpr const char *Example = ""; - static constexpr const char *Categories = ""; + static constexpr const char *Categories = "geometry"; + + static ScalarFunction GetFunction(); +}; + +struct VertexExtractFun { + static constexpr const char *Name = "vertex_extract"; + static constexpr const char *Parameters = "geom,coordinate"; + static constexpr const char *Description = "Extracts the specified coordinate (X, Y, Z, M) from a point geometry"; + static constexpr const char *Example = "vertex_extract('POINT(1 2 3)', 'Z')"; + static constexpr const char *Categories = "geometry"; static ScalarFunction GetFunction(); }; diff --git a/src/duckdb/src/include/duckdb/function/scalar/system_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/system_functions.hpp index c4ae6b3a6..5fdd0609d 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/system_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/system_functions.hpp @@ -38,11 +38,11 @@ struct CombineFun { struct ToAggregateStateFun { static constexpr const char *Name = "to_aggregate_state"; static constexpr const char *Parameters = "data,name,signature"; - static constexpr const char *Description = "Converts a value into the aggregate state of the aggregate function with the given name and signature. The type of the value must exactly match the state layout of the aggregate function."; + static constexpr const char *Description = "Converts a value into the aggregate state of the aggregate function with the given name and signature. The type of the value must exactly match the state layout of the aggregate function. An optional fourth argument supplies constant parameter values (e.g. string_agg's separator) as a list with one entry per argument, using NULL for arguments that are not bound to a constant."; static constexpr const char *Example = "to_aggregate_state({'count': 1, 'value': 42.0}, 'avg', ['DOUBLE'])"; static constexpr const char *Categories = ""; - static ScalarFunction GetFunction(); + static ScalarFunctionSet GetFunctions(); }; struct WriteLogFun { diff --git a/src/duckdb/src/include/duckdb/main/database.hpp b/src/duckdb/src/include/duckdb/main/database.hpp index 55b0eb5f0..ff82f7755 100644 --- a/src/duckdb/src/include/duckdb/main/database.hpp +++ b/src/duckdb/src/include/duckdb/main/database.hpp @@ -150,6 +150,11 @@ class DuckDB { load_info->FinishLoad(install_info); } + // Function pointer type for the C API extension init function + typedef bool (*ext_init_c_api_fun_t)(duckdb_extension_info info, duckdb_extension_access *access); + // Load a statically compiled C API extension by calling its init function directly (no vtable needed) + DUCKDB_API void LoadStaticCAPIExtension(const string &name, ext_init_c_api_fun_t init_fun); + DUCKDB_API FileSystem &GetFileSystem(); DUCKDB_API idx_t NumberOfThreads(); diff --git a/src/duckdb/src/include/duckdb/planner/binder.hpp b/src/duckdb/src/include/duckdb/planner/binder.hpp index d3f704541..2fa157e18 100644 --- a/src/duckdb/src/include/duckdb/planner/binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/binder.hpp @@ -297,10 +297,12 @@ class Binder : public enable_shared_from_this { unique_ptr BindUpdateSet(LogicalOperator &op, unique_ptr root, UpdateSetInfo &set_info, TableCatalogEntry &table, + const vector> &bound_defaults, vector &columns, bool prioritize_table_when_binding = false); void BindUpdateSet(TableIndex proj_index, unique_ptr &root, UpdateSetInfo &set_info, TableCatalogEntry &table, vector &columns, + const vector> &bound_defaults, vector> &update_expressions, vector> &projection_expressions, bool prioritize_table_when_binding = false); @@ -586,6 +588,12 @@ class Binder : public enable_shared_from_this { void ExpandDefaultInValuesList(InsertQueryNode &node, TableCatalogEntry &table, optional_ptr values_list, const vector &named_column_map); + + unique_ptr ResolveInputProjection(LogicalInsert &insert, + const IndexVector &column_index_map, + unique_ptr root, + const vector &source_types); + unique_ptr BindMergeAction(LogicalMergeInto &merge_into, TableCatalogEntry &table, LogicalGet &get, TableIndex proj_index, vector> &expressions, MergeIntoAction &action, diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_insert.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_insert.hpp index 2ab45dc8f..715a65108 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_insert.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_insert.hpp @@ -54,7 +54,7 @@ class LogicalInsert : public LogicalOperator { LogicalInsert(TableCatalogEntry &table, TableIndex table_index); vector>> insert_values; - //! The insertion map ([table_index -> index in result, or DConstants::INVALID_INDEX if not specified]) + //! Deprecated: The insertion map ([table_index -> index in result, or DConstants::INVALID_INDEX if not specified]) physical_index_vector_t column_index_map; //! The expected types for the INSERT statement (obtained from the column types) vector expected_types; diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_merge_into.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_merge_into.hpp index d61f42f7b..cf6e9c037 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_merge_into.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_merge_into.hpp @@ -28,7 +28,7 @@ class BoundMergeIntoAction { vector columns; //! Set of expressions for INSERT or UPDATE vector> expressions; - //! Column index map (for INSERT) + //! Deprecated: Column index map (for INSERT) physical_index_vector_t column_index_map; //! Whether or not an UPDATE is a DELETE + INSERT bool update_is_del_and_insert = false; diff --git a/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp b/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp index 34d1be6f7..dd5833b3a 100644 --- a/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp +++ b/src/duckdb/src/include/duckdb/storage/statistics/geometry_stats.hpp @@ -75,6 +75,23 @@ class GeometryTypeSet { } } + bool Has(GeometryType geom_type) const { + const auto geom_idx = static_cast(geom_type); + D_ASSERT(geom_idx < PART_TYPES); + for (idx_t i = 0; i < VERT_TYPES; i++) { + if (sets[i] & (1 << geom_idx)) { + return true; + } + } + return false; + } + + bool Has(VertexType vert_type) const { + const auto vert_idx = static_cast(vert_type); + D_ASSERT(vert_idx < VERT_TYPES); + return sets[vert_idx] != 0; + } + //! Check if only the given geometry and vertex type is present //! (all others are absent) bool HasOnly(GeometryType geom_type, VertexType vert_type) const { @@ -98,6 +115,46 @@ class GeometryTypeSet { return true; } + //! Check if only the given geometry type is present (with any vertex type), + //! and that at least one such geometry is present. + bool HasOnly(GeometryType geom_type) const { + const auto geom_idx = static_cast(geom_type); + D_ASSERT(geom_idx < PART_TYPES); + bool found = false; + for (uint8_t v_idx = 0; v_idx < VERT_TYPES; v_idx++) { + for (uint8_t g_idx = 1; g_idx < PART_TYPES; g_idx++) { + if (!(sets[v_idx] & (1 << g_idx))) { + continue; + } + if (g_idx != geom_idx) { + return false; + } + found = true; + } + } + return found; + } + + //! Check if only the given vertex type is present (with any geometry type), + //! and that at least one such geometry is present. + bool HasOnly(VertexType vert_type) const { + const auto vert_idx = static_cast(vert_type); + D_ASSERT(vert_idx < VERT_TYPES); + bool found = false; + for (uint8_t v_idx = 0; v_idx < VERT_TYPES; v_idx++) { + for (uint8_t g_idx = 1; g_idx < PART_TYPES; g_idx++) { + if (!(sets[v_idx] & (1 << g_idx))) { + continue; + } + if (v_idx != vert_idx) { + return false; + } + found = true; + } + } + return found; + } + bool HasSingleType() const { idx_t type_count = 0; for (uint8_t v_idx = 0; v_idx < VERT_TYPES; v_idx++) { diff --git a/src/duckdb/src/include/duckdb_extension.h b/src/duckdb/src/include/duckdb_extension.h index c99af185a..2d0916976 100644 --- a/src/duckdb/src/include/duckdb_extension.h +++ b/src/duckdb/src/include/duckdb_extension.h @@ -793,6 +793,9 @@ typedef struct { //===--------------------------------------------------------------------===// // Typedefs mapping functions to struct entries //===--------------------------------------------------------------------===// +// When building as a static extension, DuckDB symbols are resolved directly at link time. +// The vtable (duckdb_ext_api) is not used - skip these macro redirections. +#ifndef DUCKDB_BUILD_STATIC_EXTENSION // Version v1.2.0 #define duckdb_open duckdb_ext_api.duckdb_open #define duckdb_open_ext duckdb_ext_api.duckdb_open_ext @@ -1393,9 +1396,16 @@ typedef struct { #define duckdb_destroy_selection_vector duckdb_ext_api.duckdb_destroy_selection_vector #define duckdb_selection_vector_get_data_ptr duckdb_ext_api.duckdb_selection_vector_get_data_ptr +#endif // DUCKDB_BUILD_STATIC_EXTENSION + //===--------------------------------------------------------------------===// // Struct Global Macros //===--------------------------------------------------------------------===// +#ifdef DUCKDB_BUILD_STATIC_EXTENSION +// No vtable global needed for static builds - DuckDB symbols are resolved directly at link time +#define DUCKDB_EXTENSION_GLOBAL +#define DUCKDB_EXTENSION_API_INIT(info, access, minimum_api_version) +#else // This goes in the c/c++ file containing the entrypoint (handle #define DUCKDB_EXTENSION_GLOBAL duckdb_ext_api_v1 duckdb_ext_api = {0}; // Initializes the C Extension API: First thing to call in the extension entrypoint @@ -1405,6 +1415,7 @@ typedef struct { return false; \ }; \ duckdb_ext_api = *res; +#endif // DUCKDB_BUILD_STATIC_EXTENSION // Place in global scope of any C/C++ file that needs to access the extension API #define DUCKDB_EXTENSION_EXTERN extern duckdb_ext_api_v1 duckdb_ext_api; diff --git a/src/duckdb/src/main/capi/aggregate_function-c.cpp b/src/duckdb/src/main/capi/aggregate_function-c.cpp index d3a3f7145..e9bb3ba49 100644 --- a/src/duckdb/src/main/capi/aggregate_function-c.cpp +++ b/src/duckdb/src/main/capi/aggregate_function-c.cpp @@ -122,7 +122,7 @@ void CAPIAggregateCombine(Vector &state, Vector &combined, AggregateInputData &a } } -void CAPIAggregateFinalize(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count, +void CAPIAggregateFinalize(Vector &state, AggregateFinalizeInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { state.Flatten(); auto &bind_data = aggr_input_data.bind_data->Cast(); diff --git a/src/duckdb/src/main/extension/extension_load.cpp b/src/duckdb/src/main/extension/extension_load.cpp index d1c8d8ee9..b11171ffb 100644 --- a/src/duckdb/src/main/extension/extension_load.cpp +++ b/src/duckdb/src/main/extension/extension_load.cpp @@ -137,6 +137,43 @@ struct ExtensionAccess { } }; +//===--------------------------------------------------------------------===// +// Static C API Extension Loading +//===--------------------------------------------------------------------===// +void DuckDB::LoadStaticCAPIExtension(const string &name, ext_init_c_api_fun_t init_fun) { + auto &manager = ExtensionManager::Get(*instance); + auto load_info = manager.BeginLoad({name}); + if (!load_info) { + // already loaded + return; + } + + ExtensionInitResult init_result; + init_result.filename = name; + init_result.filebase = name; + // Statically compiled extensions are always tied to the exact DuckDB version + init_result.abi_type = ExtensionABIType::C_STRUCT_UNSTABLE; + init_result.lib_hdl = nullptr; + + DuckDBExtensionLoadState load_state(*instance, init_result); + + // For static loading, get_api is null - the extension uses direct DuckDB symbols (no vtable needed) + duckdb_extension_access access; + access.set_error = ExtensionAccess::SetError; + access.get_database = ExtensionAccess::GetDatabase; + access.get_api = nullptr; + + if (!(*init_fun)(load_state.ToCStruct(), &access)) { + string msg = load_state.has_error ? load_state.error_data.Message() : "unknown error"; + load_info->LoadFail(ErrorData(msg)); + throw IOException("Failed to load static C API extension '%s': %s", name, msg); + } + + ExtensionInstallInfo install_info; + install_info.mode = ExtensionInstallMode::STATICALLY_LINKED; + load_info->FinishLoad(install_info); +} + //===--------------------------------------------------------------------===// // Load External Extension //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/optimizer/partial_aggregate_pushdown.cpp b/src/duckdb/src/optimizer/partial_aggregate_pushdown.cpp index fb616240b..1190dceb1 100644 --- a/src/duckdb/src/optimizer/partial_aggregate_pushdown.cpp +++ b/src/duckdb/src/optimizer/partial_aggregate_pushdown.cpp @@ -68,6 +68,10 @@ static bool IsSupportedAggregate(const BoundAggregateExpression &expr) { if (!expr.Function().HasGetStateTypeCallback()) { return false; } + if (expr.Function().GetOrderDependent() == AggregateOrderDependent::ORDER_DEPENDENT) { + // pushing down a partial aggregate changes the order in which values are combined + return false; + } return true; } diff --git a/src/duckdb/src/parser/peg/transformer/transform_expression.cpp b/src/duckdb/src/parser/peg/transformer/transform_expression.cpp index fbfaecfae..c361e8ccc 100644 --- a/src/duckdb/src/parser/peg/transformer/transform_expression.cpp +++ b/src/duckdb/src/parser/peg/transformer/transform_expression.cpp @@ -2514,12 +2514,10 @@ unique_ptr PEGTransformerFactory::TransformTypeLiteral(PEGTran throw ParserException("Cannot convert to type %s, requires exactly one type modifier", EnumUtil::ToString(type.id())); } - if (type == LogicalTypeId::UNBOUND || type.InternalType() == PhysicalType::INVALID) { - type = LogicalType::UNBOUND(make_uniq(colid, vector>())); - } auto string_literal = list_pr.Child(1).result; auto child = make_uniq(Value(string_literal)); - auto result = make_uniq(type, std::move(child)); + auto unbound_type = LogicalType::UNBOUND(make_uniq(colid, vector>())); + auto result = make_uniq(unbound_type, std::move(child)); return std::move(result); } diff --git a/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp index 5a146bb70..029a32120 100644 --- a/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp +++ b/src/duckdb/src/planner/binder/expression/bind_operator_expression.cpp @@ -176,10 +176,11 @@ BindResult ExpressionBinder::BindExpression(OperatorExpression &op, idx_t depth) const auto &extract_expr_type = extract_exp->GetReturnType(); if (extract_expr_type.id() != LogicalTypeId::STRUCT && extract_expr_type.id() != LogicalTypeId::UNION && extract_expr_type.id() != LogicalTypeId::MAP && extract_expr_type.id() != LogicalTypeId::SQLNULL && - !extract_expr_type.IsJSONType() && extract_expr_type.id() != LogicalTypeId::VARIANT) { - return BindResult(StringUtil::Format( - "Cannot extract field %s from expression \"%s\" because it is not a struct, union, map, or json", - name_exp->ToString(), extract_exp->ToString())); + !extract_expr_type.IsJSONType() && extract_expr_type.id() != LogicalTypeId::VARIANT && + extract_expr_type.id() != LogicalTypeId::GEOMETRY) { + return BindResult(StringUtil::Format("Cannot extract field %s from expression \"%s\" because it is not a " + "struct, union, map, json or geometry", + name_exp->ToString(), extract_exp->ToString())); } if (extract_expr_type.id() == LogicalTypeId::UNION) { function_name = "union_extract"; @@ -206,6 +207,8 @@ BindResult ExpressionBinder::BindExpression(OperatorExpression &op, idx_t depth) const_exp.SetReturnType(LogicalType::VARCHAR); } } + } else if (extract_expr_type.id() == LogicalTypeId::GEOMETRY) { + function_name = "vertex_extract"; } else { function_name = "struct_extract"; } diff --git a/src/duckdb/src/planner/binder/query_node/plan_setop.cpp b/src/duckdb/src/planner/binder/query_node/plan_setop.cpp index 96001104d..4c615a7c7 100644 --- a/src/duckdb/src/planner/binder/query_node/plan_setop.cpp +++ b/src/duckdb/src/planner/binder/query_node/plan_setop.cpp @@ -9,21 +9,42 @@ namespace duckdb { +static optional_ptr GetProjectionColumnRef(const Expression &expression) { + if (expression.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { + return expression.Cast(); + } + if (expression.GetExpressionClass() != ExpressionClass::BOUND_CAST) { + return nullptr; + } + auto &cast = expression.Cast(); + if (cast.IsTryCast() || cast.Child().GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return nullptr; + } + return cast.Child().Cast(); +} + // Optionally push a PROJECTION operator unique_ptr Binder::CastLogicalOperatorToTypes(const vector &source_types, const vector &target_types, unique_ptr op) { D_ASSERT(op); - // first check if we even need to cast D_ASSERT(source_types.size() == target_types.size()); + auto node = op.get(); if (source_types == target_types) { - // source and target types are equal: don't need to cast - return op; + bool has_cast = false; + if (node->type == LogicalOperatorType::LOGICAL_PROJECTION) { + for (auto &expression : node->expressions) { + if (expression->GetExpressionClass() == ExpressionClass::BOUND_CAST) { + has_cast = true; + break; + } + } + } + if (!has_cast) { + return op; + } } - // otherwise add casts - auto node = op.get(); if (node->type == LogicalOperatorType::LOGICAL_PROJECTION) { - // "node" is a projection; we can just do the casts in there D_ASSERT(node->expressions.size() == source_types.size()); if (node->children.size() == 1 && node->children[0]->type == LogicalOperatorType::LOGICAL_GET) { // If this projection only has one child and that child is a logical get we can try to pushdown types @@ -33,9 +54,9 @@ unique_ptr Binder::CastLogicalOperatorToTypes(const vector new_column_types; bool do_pushdown = true; for (idx_t i = 0; i < op->expressions.size(); i++) { - if (op->expressions[i]->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { - auto &col_ref = op->expressions[i]->Cast(); - auto column_id = column_ids[col_ref.Binding().column_index].GetPrimaryIndex(); + auto col_ref = GetProjectionColumnRef(*op->expressions[i]); + if (col_ref) { + auto column_id = column_ids[col_ref->Binding().column_index].GetPrimaryIndex(); if (new_column_types.find(column_id) != new_column_types.end()) { // Only one reference per column is accepted do_pushdown = false; @@ -57,6 +78,10 @@ unique_ptr Binder::CastLogicalOperatorToTypes(const vector Binder::CastLogicalOperatorToTypes(const vectorcatalog, stmt.info->schema); auto &table = Catalog::GetEntry(context, stmt.info->catalog, stmt.info->schema, stmt.info->table); + physical_index_vector_t column_index_map; + vector named_column_map; + vector expected_types; vector expected_names; - if (!bound_insert.column_index_map.empty()) { - expected_names.resize(bound_insert.expected_types.size()); - for (auto &col : table.GetColumns().Physical()) { - auto i = col.Physical(); - if (bound_insert.column_index_map[i] != DConstants::INVALID_INDEX) { - expected_names[bound_insert.column_index_map[i]] = col.Name().GetIdentifierName(); - } - } - } else { - expected_names.reserve(bound_insert.expected_types.size()); - for (auto &col : table.GetColumns().Physical()) { - expected_names.emplace_back(col.Name()); - } + BindInsertColumnList(table, stmt.info->select_list, false, named_column_map, expected_types, column_index_map); + D_ASSERT(expected_types == bound_insert.expected_types); + expected_names.reserve(named_column_map.size()); + for (auto &column_index : named_column_map) { + expected_names.push_back(table.GetColumn(column_index).Name().GetIdentifierName()); } + auto copy_from_function = function.copy_from_function; CopyFromFunctionBindInput input(*stmt.info, copy_from_function); - auto function_data = function.copy_from_bind(context, input, expected_names, bound_insert.expected_types); + auto function_data = function.copy_from_bind(context, input, expected_names, expected_types); auto get = make_uniq(GenerateTableIndex(), std::move(copy_from_function), std::move(function_data), - bound_insert.expected_types, StringsToIdentifiers(expected_names)); - for (idx_t i = 0; i < bound_insert.expected_types.size(); i++) { + expected_types, StringsToIdentifiers(expected_names)); + for (idx_t i = 0; i < expected_types.size(); i++) { get->AddColumnId(i); } - insert_statement.plan->children.push_back(std::move(get)); + auto root = ResolveInputProjection(bound_insert, column_index_map, std::move(get), expected_types); + insert_statement.plan->children.push_back(std::move(root)); result.plan = std::move(insert_statement.plan); return result; } diff --git a/src/duckdb/src/planner/binder/statement/bind_insert.cpp b/src/duckdb/src/planner/binder/statement/bind_insert.cpp index 6d1ce15ce..08ed07a24 100644 --- a/src/duckdb/src/planner/binder/statement/bind_insert.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_insert.cpp @@ -20,10 +20,14 @@ #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/planner/expression_binder/where_binder.hpp" #include "duckdb/planner/expression_binder/update_binder.hpp" +#include "duckdb/planner/expression/bound_columnref_expression.hpp" +#include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/parser/statement/update_statement.hpp" #include "duckdb/planner/expression/bound_default_expression.hpp" #include "duckdb/catalog/catalog_entry/index_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/execution/expression_executor.hpp" +#include "duckdb/function/scalar/struct_functions.hpp" #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/storage/table_storage_info.hpp" #include "duckdb/parser/tableref/basetableref.hpp" @@ -67,30 +71,84 @@ void Binder::ExpandDefaultInValuesList(InsertQueryNode &node, TableCatalogEntry idx_t expected_columns = node.columns.empty() ? table.GetColumns().PhysicalColumnCount() : node.columns.size(); // special case: check if we are inserting from a VALUES statement - if (values_list) { - auto &expr_list = values_list->Cast(); - expr_list.expected_types.resize(expected_columns); - expr_list.expected_names.resize(expected_columns); - - D_ASSERT(!expr_list.values.empty()); - CheckInsertColumnCountMismatch(expected_columns, expr_list.values[0].size(), !node.columns.empty(), table.name); - - // VALUES list! - for (idx_t col_idx = 0; col_idx < expected_columns; col_idx++) { - D_ASSERT(named_column_map.size() >= col_idx); - auto &table_col_idx = named_column_map[col_idx]; - - // set the expected types as the types for the INSERT statement - auto &column = table.GetColumn(table_col_idx); - expr_list.expected_types[col_idx] = column.Type(); - expr_list.expected_names[col_idx] = column.Name(); - - // now replace any DEFAULT values with the corresponding default expression - for (idx_t list_idx = 0; list_idx < expr_list.values.size(); list_idx++) { - TryReplaceDefaultExpression(expr_list.values[list_idx][col_idx], column); + auto &expr_list = values_list->Cast(); + expr_list.expected_types.resize(expected_columns); + expr_list.expected_names.resize(expected_columns); + + D_ASSERT(!expr_list.values.empty()); + CheckInsertColumnCountMismatch(expected_columns, expr_list.values[0].size(), !node.columns.empty(), table.name); + + // VALUES list! + for (idx_t col_idx = 0; col_idx < expected_columns; col_idx++) { + D_ASSERT(named_column_map.size() >= col_idx); + auto &table_col_idx = named_column_map[col_idx]; + + // set the expected types as the types for the INSERT statement + auto &column = table.GetColumn(table_col_idx); + expr_list.expected_types[col_idx] = table.GetExpectedTypeForInsert(column); + expr_list.expected_names[col_idx] = column.Name(); + + // now replace any DEFAULT values with the corresponding default expression + for (idx_t list_idx = 0; list_idx < expr_list.values.size(); list_idx++) { + TryReplaceDefaultExpression(expr_list.values[list_idx][col_idx], column); + } + } +} + +unique_ptr Binder::ResolveInputProjection(LogicalInsert &insert, + const IndexVector &column_index_map, + unique_ptr root, + const vector &source_types) { + auto &table = insert.table; + auto source_bindings = root->GetColumnBindings(); + vector> select_list; + for (auto &col : table.GetColumns().Physical()) { + auto storage_idx = col.StorageOid(); + auto mapped_index = column_index_map.empty() ? storage_idx : column_index_map[col.Physical()]; + if (mapped_index == DConstants::INVALID_INDEX) { + // Push default value + select_list.push_back(std::move(insert.bound_defaults[storage_idx])); + continue; + } + auto &original_type = source_types[mapped_index]; + auto source_binding = source_bindings[mapped_index]; + auto expression = table.GetDefaultExpressionForColumn(context, original_type, col.Type(), source_binding, + *insert.bound_defaults[storage_idx]); + if (!expression->HasQueryLocation() && root->type == LogicalOperatorType::LOGICAL_PROJECTION) { + expression->SetQueryLocation(root->expressions[mapped_index]->GetQueryLocation()); + } + select_list.push_back(std::move(expression)); + } + + bool can_inline_projection = root->type == LogicalOperatorType::LOGICAL_PROJECTION; + if (can_inline_projection) { + for (auto &expression : root->expressions) { + if (expression->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + can_inline_projection = false; + break; } } } + if (can_inline_projection) { + auto &child_projection = root->Cast(); + for (auto &expression : select_list) { + ExpressionIterator::EnumerateExpression(expression, [&](unique_ptr &child) { + if (child->GetExpressionType() != ExpressionType::BOUND_COLUMN_REF) { + return; + } + auto &column_ref = child->Cast(); + if (column_ref.Binding().table_index != child_projection.table_index) { + return; + } + child = child_projection.expressions[column_ref.Binding().column_index]->Copy(); + }); + } + root = std::move(child_projection.children[0]); + } + + auto projection = make_uniq(GenerateTableIndex(), std::move(select_list)); + projection->AddChild(std::move(root)); + return std::move(projection); } void DoUpdateSetQualify(unique_ptr &expr, const string &table_name, @@ -586,9 +644,10 @@ BoundStatement Binder::BindNode(InsertQueryNode &node) { node.columns = root_select.names; } + physical_index_vector_t column_index_map; vector named_column_map; BindInsertColumnList(table, node.columns, node.default_values, named_column_map, insert->expected_types, - insert->column_index_map); + column_index_map); // bind the default values auto &catalog_name = table.ParentCatalog().GetName(); @@ -615,9 +674,23 @@ BoundStatement Binder::BindNode(InsertQueryNode &node) { // inserting from a select - check if the column count matches CheckInsertColumnCountMismatch(expected_columns, root_select.types.size(), !node.columns.empty(), table.name); - root = CastLogicalOperatorToTypes(root_select.types, insert->expected_types, std::move(root_select.plan)); + auto source_types = root_select.types; + auto target_types = insert->expected_types; + root_select.plan = ResolveInputProjection(*insert, column_index_map, std::move(root_select.plan), source_types); + target_types.clear(); + for (auto &column : table.GetColumns().Physical()) { + target_types.push_back(column.Type()); + } + source_types.clear(); + for (auto &expr : root_select.plan->expressions) { + source_types.push_back(expr->GetReturnType()); + } + root = CastLogicalOperatorToTypes(source_types, target_types, std::move(root_select.plan)); } else { root = make_uniq(GenerateTableIndex()); + if (node.default_values) { + root = ResolveInputProjection(*insert, column_index_map, std::move(root), {}); + } } insert->AddChild(std::move(root)); diff --git a/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp b/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp index acc3cfdf6..22ecdc63e 100644 --- a/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp @@ -74,8 +74,8 @@ Binder::BindMergeAction(LogicalMergeInto &merge_into, TableCatalogEntry &table, } } unique_ptr fake_root; - BindUpdateSet(proj_index, fake_root, *action.update_info, table, result->columns, result->expressions, - expressions); + BindUpdateSet(proj_index, fake_root, *action.update_info, table, result->columns, merge_into.bound_defaults, + result->expressions, expressions); // bind any additional columns that need to be bound for update constraints // FIXME: this is pretty hacky @@ -111,10 +111,12 @@ Binder::BindMergeAction(LogicalMergeInto &merge_into, TableCatalogEntry &table, } vector named_column_map; vector expected_types; + physical_index_vector_t column_index_map; BindInsertColumnList(table, action.insert_columns, action.default_values, named_column_map, expected_types, - result->column_index_map); + column_index_map); - vector> insert_expressions; + vector insert_bindings; + vector insert_types; if (!action.default_values && action.expressions.empty()) { // no expressions: * // expand source bindings @@ -127,19 +129,26 @@ Binder::BindMergeAction(LogicalMergeInto &merge_into, TableCatalogEntry &table, auto &column = table.GetColumns().GetColumn(named_column_map[i]); InsertBinder insert_binder(*this, context); - insert_binder.target_type = column.Type(); + insert_binder.target_type = table.GetExpectedTypeForInsert(column); TryReplaceDefaultExpression(action.expressions[i], column); auto insert_expr = insert_binder.Bind(action.expressions[i]); - - insert_expressions.push_back(std::move(insert_expr)); - } - - for (auto &insert_expr : insert_expressions) { auto insert_type = insert_expr->GetReturnType(); auto expr_index = ColumnBinding::PushExpression(expressions, std::move(insert_expr)); - result->expressions.push_back( - make_uniq(insert_type, ColumnBinding(proj_index, expr_index))); + insert_bindings.emplace_back(proj_index, expr_index); + insert_types.push_back(std::move(insert_type)); + } + + for (auto &col : table.GetColumns().Physical()) { + auto storage_idx = col.StorageOid(); + auto mapped_index = column_index_map.empty() ? storage_idx : column_index_map[col.Physical()]; + if (mapped_index == DConstants::INVALID_INDEX) { + result->expressions.push_back(merge_into.bound_defaults[storage_idx]->Copy()); + } else { + result->expressions.push_back(table.GetDefaultExpressionForColumn( + context, insert_types[mapped_index], col.Type(), insert_bindings[mapped_index], + *merge_into.bound_defaults[storage_idx])); + } } break; } @@ -269,6 +278,14 @@ BoundStatement Binder::BindNode(MergeQueryNode &node) { auto proj_index = GenerateTableIndex(); vector> projection_expressions; + // bind table constraints/default values in case these are referenced + auto &catalog_name = table.ParentCatalog().GetName(); + auto &schema_name = table.ParentSchema().name; + BindDefaultValues(table.GetColumns(), merge_into->bound_defaults, catalog_name.GetIdentifierName(), + schema_name.GetIdentifierName()); + + merge_into->bound_constraints = BindConstraints(table); + for (auto &entry : node.actions) { if (entry.first == MergeActionCondition::WHEN_MATCHED) { continue; @@ -328,14 +345,6 @@ BoundStatement Binder::BindNode(MergeQueryNode &node) { merge_into->return_chunk = true; } - // bind table constraints/default values in case these are referenced - auto &catalog_name = table.ParentCatalog().GetName(); - auto &schema_name = table.ParentSchema().name; - BindDefaultValues(table.GetColumns(), merge_into->bound_defaults, catalog_name.GetIdentifierName(), - schema_name.GetIdentifierName()); - - merge_into->bound_constraints = BindConstraints(table); - // bind WHEN_MATCHED merge actions (can contain references to both source and target) for (auto &entry : node.actions) { if (entry.first != MergeActionCondition::WHEN_MATCHED) { diff --git a/src/duckdb/src/planner/binder/statement/bind_update.cpp b/src/duckdb/src/planner/binder/statement/bind_update.cpp index 7f0e7c0be..522b10305 100644 --- a/src/duckdb/src/planner/binder/statement/bind_update.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_update.cpp @@ -21,6 +21,7 @@ namespace duckdb { void Binder::BindUpdateSet(TableIndex proj_index, unique_ptr &root, UpdateSetInfo &set_info, TableCatalogEntry &table, vector &columns, + const vector> &bound_defaults, vector> &update_expressions, vector> &projection_expressions, bool prioritize_table_when_binding) { D_ASSERT(set_info.columns.size() == set_info.expressions.size()); @@ -56,10 +57,13 @@ void Binder::BindUpdateSet(TableIndex proj_index, unique_ptr &r } columns.push_back(column.Physical()); if (expr->GetExpressionType() == ExpressionType::VALUE_DEFAULT) { - update_expressions.push_back(make_uniq(column.Type())); + auto bound_default = bound_defaults[column.StorageOid()]->Copy(); + auto expr_index = ColumnBinding::PushExpression(projection_expressions, std::move(bound_default)); + update_expressions.push_back( + make_uniq(column.Type(), ColumnBinding(proj_index, expr_index))); } else { UpdateBinder binder(*expr_binder_ptr, context); - binder.target_type = column.Type(); + binder.target_type = table.GetExpectedTypeForInsert(column); auto bound_expr = binder.Bind(expr); if (root) { PlanSubqueries(bound_expr, root); @@ -67,9 +71,10 @@ void Binder::BindUpdateSet(TableIndex proj_index, unique_ptr &r auto bound_type = bound_expr->GetReturnType(); auto expr_index = ColumnBinding::PushExpression(projection_expressions, std::move(bound_expr)); + auto source_binding = ColumnBinding(proj_index, expr_index); - update_expressions.push_back( - make_uniq(bound_type, ColumnBinding(proj_index, expr_index))); + update_expressions.push_back(table.GetDefaultExpressionForColumn( + context, bound_type, column.Type(), source_binding, *bound_defaults[column.StorageOid()])); } } } @@ -78,11 +83,12 @@ void Binder::BindUpdateSet(TableIndex proj_index, unique_ptr &r // unless there are no expressions to project, in which case it just returns 'root' unique_ptr Binder::BindUpdateSet(LogicalOperator &op, unique_ptr root, UpdateSetInfo &set_info, TableCatalogEntry &table, + const vector> &bound_defaults, vector &columns, bool prioritize_table_when_binding) { auto proj_index = GenerateTableIndex(); vector> projection_expressions; - BindUpdateSet(proj_index, root, set_info, table, columns, op.expressions, projection_expressions, + BindUpdateSet(proj_index, root, set_info, table, columns, bound_defaults, op.expressions, projection_expressions, prioritize_table_when_binding); if (op.type != LogicalOperatorType::LOGICAL_UPDATE && projection_expressions.empty()) { return root; @@ -190,8 +196,8 @@ BoundStatement Binder::BindNode(UpdateQueryNode &node) { D_ASSERT(node.set_info); D_ASSERT(node.set_info->columns.size() == node.set_info->expressions.size()); - auto proj_tmp = BindUpdateSet(*update, std::move(root), *node.set_info, table, update->columns, - node.prioritize_table_when_binding); + auto proj_tmp = BindUpdateSet(*update, std::move(root), *node.set_info, table, update->bound_defaults, + update->columns, node.prioritize_table_when_binding); D_ASSERT(proj_tmp->type == LogicalOperatorType::LOGICAL_PROJECTION); auto proj = unique_ptr_cast(std::move(proj_tmp)); diff --git a/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp b/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp index dbc0fc7fd..40722ec1e 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_expressionlistref.cpp @@ -33,6 +33,8 @@ BoundStatement Binder::Bind(ExpressionListRef &expr) { if (!result.types.empty()) { D_ASSERT(result.types.size() == expression_list.size()); binder.target_type = result.types[val_idx]; + } else { + binder.target_type = LogicalType(LogicalTypeId::INVALID); } auto bound_expr = binder.Bind(expression_list[val_idx]); list.push_back(std::move(bound_expr)); @@ -40,16 +42,40 @@ BoundStatement Binder::Bind(ExpressionListRef &expr) { values.push_back(std::move(list)); this->SetCanContainNulls(prev_can_contain_nulls); } - if (result.types.empty() && !expr.values.empty()) { - // there are no types specified - // we have to figure out the result types + bool infer_types = result.types.empty(); + if (!infer_types) { + for (auto &type : result.types) { + if (!type.IsValid()) { + infer_types = true; + break; + } + } + } + if (infer_types && !expr.values.empty()) { + // there are no types specified, or some types were left invalid + // we have to figure out the result types for those columns // for each column, we iterate over all of the expressions and select the max logical type // we initialize all types to SQLNULL - result.types.resize(expr.values[0].size(), LogicalType::SQLNULL); + vector should_infer(expr.values[0].size(), true); + if (result.types.empty()) { + result.types.resize(expr.values[0].size(), LogicalType::SQLNULL); + } else { + for (idx_t i = 0; i < result.types.size(); i++) { + auto &type = result.types[i]; + if (!type.IsValid()) { + type = LogicalType::SQLNULL; + } else { + should_infer[i] = false; + } + } + } // now loop over the lists and select the max logical type for (idx_t list_idx = 0; list_idx < values.size(); list_idx++) { auto &list = values[list_idx]; for (idx_t val_idx = 0; val_idx < list.size(); val_idx++) { + if (!should_infer[val_idx]) { + continue; + } auto ¤t_type = result.types[val_idx]; auto next_type = ExpressionBinder::GetExpressionReturnType(*list[val_idx]); result.types[val_idx] = LogicalType::MaxLogicalType(context, current_type, next_type); @@ -62,6 +88,9 @@ BoundStatement Binder::Bind(ExpressionListRef &expr) { for (idx_t list_idx = 0; list_idx < values.size(); list_idx++) { auto &list = values[list_idx]; for (idx_t val_idx = 0; val_idx < list.size(); val_idx++) { + if (!should_infer[val_idx]) { + continue; + } list[val_idx] = BoundCastExpression::AddCastToType(context, std::move(list[val_idx]), result.types[val_idx]); } diff --git a/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp b/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp index eab655f37..734a7ba54 100644 --- a/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp +++ b/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp @@ -209,7 +209,11 @@ void BoundMergeIntoAction::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault>(201, "condition", condition); serializer.WritePropertyWithDefault>(202, "columns", columns); serializer.WritePropertyWithDefault>>(203, "expressions", expressions); - serializer.WriteProperty>(204, "column_index_map", column_index_map); + if (serializer.ShouldSerialize(StorageVersion::V2_0_0)) { + serializer.WritePropertyWithDefault>(204, "column_index_map", column_index_map, IndexVector()); + } else { + serializer.WriteProperty>(204, "column_index_map", column_index_map); + } serializer.WritePropertyWithDefault(205, "update_is_del_and_insert", update_is_del_and_insert); } @@ -219,7 +223,7 @@ unique_ptr BoundMergeIntoAction::Deserialize(Deserializer deserializer.ReadPropertyWithDefault>(201, "condition", result->condition); deserializer.ReadPropertyWithDefault>(202, "columns", result->columns); deserializer.ReadPropertyWithDefault>>(203, "expressions", result->expressions); - deserializer.ReadProperty>(204, "column_index_map", result->column_index_map); + deserializer.ReadPropertyWithExplicitDefault>(204, "column_index_map", result->column_index_map, IndexVector()); deserializer.ReadPropertyWithDefault(205, "update_is_del_and_insert", result->update_is_del_and_insert); return result; } @@ -549,7 +553,11 @@ void LogicalInsert::Serialize(Serializer &serializer) const { LogicalOperator::Serialize(serializer); serializer.WritePropertyWithDefault>(200, "table_info", table.GetInfo()); serializer.WritePropertyWithDefault>>>(201, "insert_values", insert_values); - serializer.WriteProperty>(202, "column_index_map", column_index_map); + if (serializer.ShouldSerialize(StorageVersion::V2_0_0)) { + serializer.WritePropertyWithDefault>(202, "column_index_map", column_index_map, IndexVector()); + } else { + serializer.WriteProperty>(202, "column_index_map", column_index_map); + } serializer.WritePropertyWithDefault>(203, "expected_types", expected_types); serializer.WritePropertyWithDefault(204, "table_index", table_index); serializer.WritePropertyWithDefault(205, "return_chunk", return_chunk); @@ -572,7 +580,7 @@ unique_ptr LogicalInsert::Deserialize(Deserializer &deserialize auto table_info = deserializer.ReadPropertyWithDefault>(200, "table_info"); auto result = duckdb::unique_ptr(new LogicalInsert(deserializer.Get(), std::move(table_info))); deserializer.ReadPropertyWithDefault>>>(201, "insert_values", result->insert_values); - deserializer.ReadProperty>(202, "column_index_map", result->column_index_map); + deserializer.ReadPropertyWithExplicitDefault>(202, "column_index_map", result->column_index_map, IndexVector()); deserializer.ReadPropertyWithDefault>(203, "expected_types", result->expected_types); deserializer.ReadPropertyWithDefault(204, "table_index", result->table_index); deserializer.ReadPropertyWithDefault(205, "return_chunk", result->return_chunk);