Skip to content

Commit 963cd88

Browse files
committed
x
1 parent b6e1fef commit 963cd88

File tree

1 file changed

+59
-40
lines changed

1 file changed

+59
-40
lines changed

src/query/expression/src/aggregate/partitioned_payload.rs

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,28 @@ use crate::ProjectedBlock;
2828
use crate::StatesLayout;
2929
use crate::BATCH_SIZE;
3030

31+
#[derive(Debug, Clone, Copy)]
32+
struct PartitionMask {
33+
mask: u64,
34+
shift: u64,
35+
}
36+
37+
impl PartitionMask {
38+
fn new(partition_count: u64) -> Self {
39+
let radix_bits = partition_count.trailing_zeros() as u64;
40+
debug_assert_eq!(1 << radix_bits, partition_count);
41+
42+
let shift = 48 - radix_bits;
43+
let mask = ((1 << radix_bits) - 1) << shift;
44+
45+
Self { mask, shift }
46+
}
47+
48+
pub fn index(&self, hash: u64) -> usize {
49+
((hash & self.mask) >> self.shift) as _
50+
}
51+
}
52+
3153
pub struct PartitionedPayload {
3254
pub payloads: Vec<Payload>,
3355
pub group_types: Vec<DataType>,
@@ -37,9 +59,7 @@ pub struct PartitionedPayload {
3759

3860
pub arenas: Vec<Arc<Bump>>,
3961

40-
partition_count: u64,
41-
mask_v: u64,
42-
shift_v: u64,
62+
partition_mask: PartitionMask,
4363
}
4464

4565
unsafe impl Send for PartitionedPayload {}
@@ -52,9 +72,6 @@ impl PartitionedPayload {
5272
partition_count: u64,
5373
arenas: Vec<Arc<Bump>>,
5474
) -> Self {
55-
let radix_bits = partition_count.trailing_zeros() as u64;
56-
debug_assert_eq!(1 << radix_bits, partition_count);
57-
5875
let states_layout = if !aggrs.is_empty() {
5976
Some(get_states_layout(&aggrs).unwrap())
6077
} else {
@@ -72,7 +89,7 @@ impl PartitionedPayload {
7289
})
7390
.collect_vec();
7491

75-
let offsets = RowLayout {
92+
let row_layout = RowLayout {
7693
states_layout,
7794
..payloads[0].row_layout.clone()
7895
};
@@ -81,12 +98,10 @@ impl PartitionedPayload {
8198
payloads,
8299
group_types,
83100
aggrs,
84-
row_layout: offsets,
85-
partition_count,
101+
row_layout,
86102

87103
arenas,
88-
mask_v: mask(radix_bits),
89-
shift_v: shift(radix_bits),
104+
partition_mask: PartitionMask::new(partition_count),
90105
}
91106
}
92107

@@ -119,7 +134,7 @@ impl PartitionedPayload {
119134
state.reset_partitions(self.partition_count());
120135
for &row in &state.empty_vector[..new_group_rows] {
121136
let hash = state.group_hashes[row];
122-
let partition_idx = ((hash & self.mask_v) >> self.shift_v) as usize;
137+
let partition_idx = self.partition_mask.index(hash);
123138
let (count, sel) = &mut state.partition_entries[partition_idx];
124139

125140
sel[*count as usize] = row;
@@ -149,19 +164,27 @@ impl PartitionedPayload {
149164
return self;
150165
}
151166

152-
let mut new_partition_payload = PartitionedPayload::new(
153-
self.group_types.clone(),
154-
self.aggrs.clone(),
155-
new_partition_count as u64,
156-
self.arenas.clone(),
157-
);
167+
let PartitionedPayload {
168+
payloads,
169+
group_types,
170+
aggrs,
171+
arenas,
172+
..
173+
} = self;
174+
175+
let mut new_partition_payload =
176+
PartitionedPayload::new(group_types, aggrs, new_partition_count as u64, arenas);
177+
178+
state.clear();
179+
for payload in payloads.into_iter() {
180+
new_partition_payload.combine_single(payload, state, None)
181+
}
158182

159-
new_partition_payload.combine(self, state);
160183
new_partition_payload
161184
}
162185

163186
pub fn combine(&mut self, other: PartitionedPayload, state: &mut PayloadFlushState) {
164-
if other.partition_count == self.partition_count {
187+
if other.partition_count() == self.partition_count() {
165188
for (l, r) in self.payloads.iter_mut().zip(other.payloads.into_iter()) {
166189
l.combine(r);
167190
}
@@ -184,7 +207,7 @@ impl PartitionedPayload {
184207
return;
185208
}
186209

187-
if self.partition_count == 1 {
210+
if self.partition_count() == 1 {
188211
self.payloads[0].combine(other);
189212
} else {
190213
flush_state.clear();
@@ -194,13 +217,19 @@ impl PartitionedPayload {
194217
// copy rows
195218
let state = &*flush_state.probe_state;
196219

197-
for partition in (0..self.partition_count as usize)
198-
.filter(|x| only_bucket.is_none() || only_bucket == Some(*x))
199-
{
200-
let (count, sel) = &state.partition_entries[partition];
201-
if *count > 0 {
202-
let payload = &mut self.payloads[partition];
203-
payload.copy_rows(&sel[..*count as _], &flush_state.addresses);
220+
match only_bucket {
221+
Some(i) => {
222+
let (count, sel) = &state.partition_entries[i];
223+
self.payloads[i].copy_rows(&sel[..*count as _], &flush_state.addresses);
224+
}
225+
None => {
226+
for ((count, sel), payload) in
227+
state.partition_entries.iter().zip(self.payloads.iter_mut())
228+
{
229+
if *count > 0 {
230+
payload.copy_rows(&sel[..*count as _], &flush_state.addresses);
231+
}
232+
}
204233
}
205234
}
206235
}
@@ -236,7 +265,7 @@ impl PartitionedPayload {
236265
flush_state.addresses[idx] = row_ptr;
237266

238267
let hash = row_ptr.hash(&self.row_layout);
239-
let partition_idx = ((hash & self.mask_v) >> self.shift_v) as usize;
268+
let partition_idx = self.partition_mask.index(hash);
240269

241270
let (count, sel) = &mut state.partition_entries[partition_idx];
242271
sel[*count as usize] = idx.into();
@@ -253,7 +282,7 @@ impl PartitionedPayload {
253282

254283
#[inline]
255284
pub fn partition_count(&self) -> usize {
256-
self.partition_count as usize
285+
self.payloads.len()
257286
}
258287

259288
#[allow(dead_code)]
@@ -266,13 +295,3 @@ impl PartitionedPayload {
266295
self.payloads.iter().map(|x| x.memory_size()).sum()
267296
}
268297
}
269-
270-
#[inline]
271-
fn shift(radix_bits: u64) -> u64 {
272-
48 - radix_bits
273-
}
274-
275-
#[inline]
276-
fn mask(radix_bits: u64) -> u64 {
277-
((1 << radix_bits) - 1) << shift(radix_bits)
278-
}

0 commit comments

Comments
 (0)