Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions velox/experimental/cudf/exec/CudfGroupby.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ void addDecimalSumCountRequestsAfterDecode(
uint32_t& countIdx,
std::unique_ptr<cudf::column>& decodedSum,
std::unique_ptr<cudf::column>& decodedCount) {
auto sumAndCount = cudf_velox::deserializeDecimalSumState(
encodedColumn, scale, stream, cudf_velox::get_output_mr());
auto sumAndCount =
cudf_velox::deserializeDecimalSumState(encodedColumn, scale, stream);
decodedSum.swap(sumAndCount.sum);
decodedCount.swap(sumAndCount.count);

Expand Down Expand Up @@ -190,7 +190,7 @@ void addDecimalFinalSumOnlyRequest(
auto& request = requests.emplace_back();
sumIdx = requests.size() - 1;
auto sumAndCount = cudf_velox::deserializeDecimalSumState(
tbl.column(inputIndex), scale, stream, cudf_velox::get_output_mr());
tbl.column(inputIndex), scale, stream);
decodedSum.swap(sumAndCount.sum);
request.values = decodedSum->view();
request.aggregations.push_back(
Expand Down
12 changes: 6 additions & 6 deletions velox/experimental/cudf/exec/CudfReduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,8 @@ std::unique_ptr<cudf::column> intermediateDecimalMergeSerializedString(
int32_t scale,
rmm::cuda_stream_view stream) {
auto const sumAgg = cudf::make_sum_aggregation<cudf::reduce_aggregation>();
auto sumAndCount = cudf_velox::deserializeDecimalSumState(
inputCol, scale, stream, get_output_mr());
auto sumAndCount =
cudf_velox::deserializeDecimalSumState(inputCol, scale, stream);
auto sumScalar = cudf::reduce(
sumAndCount.sum->view(),
*sumAgg,
Expand All @@ -307,8 +307,8 @@ std::unique_ptr<cudf::column> finalDecimalAvgFromSerializedString(
TypePtr const& resultType,
rmm::cuda_stream_view stream) {
auto const sumAgg = cudf::make_sum_aggregation<cudf::reduce_aggregation>();
auto sumAndCount = cudf_velox::deserializeDecimalSumState(
inputCol, scale, stream, get_output_mr());
auto sumAndCount =
cudf_velox::deserializeDecimalSumState(inputCol, scale, stream);
auto sumScalar = cudf::reduce(
sumAndCount.sum->view(),
*sumAgg,
Expand Down Expand Up @@ -387,8 +387,8 @@ std::unique_ptr<cudf::column> reduceFinalDecimalSumFromSerializedColumn(
rmm::cuda_stream_view stream) {
validateIntermediateColumnType(inputCol);
auto scale = getDecimalPrecisionScale(*outputType).second;
auto sumAndCount = cudf_velox::deserializeDecimalSumState(
inputCol, scale, stream, get_output_mr());
auto sumAndCount =
cudf_velox::deserializeDecimalSumState(inputCol, scale, stream);
return singleOrRawDecimalSumWithCast(
sumAndCount.sum->view(), outputType, stream);
}
Expand Down
6 changes: 2 additions & 4 deletions velox/experimental/cudf/exec/DecimalAggregationCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ std::unique_ptr<cudf::column> serializeDecimalPartialOrIntermediateState(
std::unique_ptr<cudf::column> count,
rmm::cuda_stream_view stream) {
count = castCountColumnToInt64(std::move(count), stream);
return serializeDecimalSumState(
sum->view(), count->view(), stream, get_output_mr());
return serializeDecimalSumState(sum->view(), count->view(), stream);
}

std::unique_ptr<cudf::column> finalizeDecimalAverage(
Expand All @@ -60,8 +59,7 @@ std::unique_ptr<cudf::column> finalizeDecimalAverage(
const TypePtr& resultType,
rmm::cuda_stream_view stream) {
count = castCountColumnToInt64(std::move(count), stream);
auto avgCol = computeDecimalAverage(
sum->view(), count->view(), stream, get_output_mr());
auto avgCol = computeDecimalAverage(sum->view(), count->view(), stream);
auto const cudfOutType = veloxToCudfDataType(resultType);
if (avgCol->type() != cudfOutType) {
avgCol = cudf::cast(avgCol->view(), cudfOutType, stream, get_output_mr());
Expand Down
49 changes: 30 additions & 19 deletions velox/experimental/cudf/exec/DecimalAggregationKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "velox/experimental/cudf/exec/DecimalAggregationKernels.h"
#include "velox/experimental/cudf/exec/DecimalAggregationKernelsGpu.h"
#include "velox/experimental/cudf/exec/GpuResources.h"

#include "velox/common/base/Exceptions.h"

Expand All @@ -31,8 +32,7 @@ namespace facebook::velox::cudf_velox {
DecimalSumStateColumns deserializeDecimalSumState(
const cudf::column_view& stateCol,
int32_t scale,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr) {
rmm::cuda_stream_view stream) {
VELOX_CHECK(
stateCol.type().id() == cudf::type_id::STRING,
"Decimal sum state requires STRING/VARBINARY column (type is {})",
Expand All @@ -44,12 +44,14 @@ DecimalSumStateColumns deserializeDecimalSumState(
cudf::data_type{cudf::type_id::DECIMAL128, -scale},
0,
cudf::mask_state::UNALLOCATED,
stream);
stream,
get_output_mr());
empty.count = cudf::make_fixed_width_column(
cudf::data_type{cudf::type_id::INT64},
0,
cudf::mask_state::UNALLOCATED,
stream);
stream,
get_output_mr());
return empty;
}

Expand All @@ -61,12 +63,14 @@ DecimalSumStateColumns deserializeDecimalSumState(
cudf::data_type{cudf::type_id::DECIMAL128, -scale},
numRows,
cudf::mask_state::ALL_NULL,
stream);
stream,
get_output_mr());
allNull.count = cudf::make_fixed_width_column(
cudf::data_type{cudf::type_id::INT64},
numRows,
cudf::mask_state::ALL_NULL,
stream);
stream,
get_output_mr());
return allNull;
}

Expand All @@ -81,12 +85,14 @@ DecimalSumStateColumns deserializeDecimalSumState(
cudf::data_type{cudf::type_id::DECIMAL128, -scale},
numRows,
cudf::mask_state::UNALLOCATED,
stream);
stream,
get_output_mr());
auto countCol = cudf::make_fixed_width_column(
cudf::data_type{cudf::type_id::INT64},
numRows,
cudf::mask_state::UNALLOCATED,
stream);
stream,
get_output_mr());

auto sumView = sumCol->mutable_view();
auto countView = countCol->mutable_view();
Expand All @@ -110,10 +116,10 @@ DecimalSumStateColumns deserializeDecimalSumState(
}

if (stateCol.nullable()) {
auto nullMask = cudf::copy_bitmask(stateCol, stream, mr);
auto nullMask = cudf::copy_bitmask(stateCol, stream, get_output_mr());
auto nullCount = stateCol.null_count();
sumCol->set_null_mask(std::move(nullMask), nullCount);
auto countMask = cudf::copy_bitmask(stateCol, stream, mr);
auto countMask = cudf::copy_bitmask(stateCol, stream, get_output_mr());
countCol->set_null_mask(std::move(countMask), nullCount);
}

Expand All @@ -126,8 +132,7 @@ DecimalSumStateColumns deserializeDecimalSumState(
std::unique_ptr<cudf::column> serializeDecimalSumState(
const cudf::column_view& sumCol,
const cudf::column_view& countCol,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr) {
rmm::cuda_stream_view stream) {
VELOX_CHECK(
countCol.type().id() == cudf::type_id::INT64,
"Decimal sum state requires INT64 count column (type is {})",
Expand Down Expand Up @@ -165,11 +170,14 @@ std::unique_ptr<cudf::column> serializeDecimalSumState(
cudf::data_type{offsetsType},
numRows + 1,
cudf::mask_state::UNALLOCATED,
stream);
stream,
get_output_mr());
auto offsetsView = offsetsCol->mutable_view();

rmm::device_buffer charsBuf(
static_cast<size_t>(numRows) * detail::kDecimalSumStateSize, stream);
static_cast<size_t>(numRows) * detail::kDecimalSumStateSize,
stream,
get_output_mr());

detail::fillOffsetsForDecimalSumState(
useLargeOffsets,
Expand Down Expand Up @@ -204,7 +212,7 @@ std::unique_ptr<cudf::column> serializeDecimalSumState(
}

auto [nullMask, nullCount] =
detail::buildStateValidityMask(sumCol, countCol, stream, mr);
detail::buildStateValidityMask(sumCol, countCol, stream, get_output_mr());
return cudf::make_strings_column(
static_cast<cudf::size_type>(numRows),
std::move(offsetsCol),
Expand All @@ -216,8 +224,7 @@ std::unique_ptr<cudf::column> serializeDecimalSumState(
std::unique_ptr<cudf::column> computeDecimalAverage(
const cudf::column_view& sumCol,
const cudf::column_view& countCol,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr) {
rmm::cuda_stream_view stream) {
VELOX_CHECK(
countCol.type().id() == cudf::type_id::INT64,
"Decimal average requires INT64 count column (type is {})",
Expand All @@ -236,7 +243,11 @@ std::unique_ptr<cudf::column> computeDecimalAverage(

auto numRows = sumCol.size();
auto out = cudf::make_fixed_width_column(
sumCol.type(), numRows, cudf::mask_state::UNALLOCATED, stream);
sumCol.type(),
numRows,
cudf::mask_state::UNALLOCATED,
stream,
get_output_mr());

if (numRows > 0) {
auto const rowCount = static_cast<int32_t>(numRows);
Expand All @@ -252,7 +263,7 @@ std::unique_ptr<cudf::column> computeDecimalAverage(
}

auto [nullMask, nullCount] =
detail::buildStateValidityMask(sumCol, countCol, stream, mr);
detail::buildStateValidityMask(sumCol, countCol, stream, get_output_mr());
if (nullCount > 0) {
out->set_null_mask(std::move(nullMask), nullCount);
}
Expand Down
9 changes: 3 additions & 6 deletions velox/experimental/cudf/exec/DecimalAggregationKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ struct DecimalSumStateColumns {
DecimalSumStateColumns deserializeDecimalSumState(
const cudf::column_view& stateCol,
int32_t scale,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr);
rmm::cuda_stream_view stream);

// Encodes partial decimal SUM state (DECIMAL64 or DECIMAL128 sums plus
// INT64 counts) into a single STRING column (later converted to Velox
Expand All @@ -50,8 +49,7 @@ DecimalSumStateColumns deserializeDecimalSumState(
std::unique_ptr<cudf::column> serializeDecimalSumState(
const cudf::column_view& sumCol,
const cudf::column_view& countCol,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr);
rmm::cuda_stream_view stream);

// Finalizes AVG from intermediate SUM state: divides each sum by its count
// on device with decimal-specific rounding (see averageRoundDecimalSum),
Expand All @@ -61,7 +59,6 @@ std::unique_ptr<cudf::column> serializeDecimalSumState(
std::unique_ptr<cudf::column> computeDecimalAverage(
const cudf::column_view& sumCol,
const cudf::column_view& countCol,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr);
rmm::cuda_stream_view stream);

} // namespace facebook::velox::cudf_velox
Loading
Loading