@@ -28,6 +28,28 @@ use crate::ProjectedBlock;
2828use crate :: StatesLayout ;
2929use crate :: BATCH_SIZE ;
3030
/// Maps a 64-bit hash to a radix partition index.
///
/// For `partition_count == 2^radix_bits`, the index is read from the
/// `radix_bits` hash bits located directly below bit 48, i.e. bits
/// `[48 - radix_bits, 48)`. Hash bits at or above bit 48 and below
/// `48 - radix_bits` do not influence the result.
// NOTE(review): the choice of bit 48 as the upper bound is not explained in
// this file — presumably the top 16 hash bits are reserved for another use;
// confirm against the hash table implementation.
#[derive(Debug, Clone, Copy)]
struct PartitionMask {
    /// Bit mask selecting the hash bits that encode the partition index.
    mask: u64,
    /// Right shift that moves the selected bits down to bit 0.
    shift: u64,
}

impl PartitionMask {
    /// Exclusive upper bound of the hash bits used for partitioning.
    const HASH_BITS: u64 = 48;

    /// Builds the mask/shift pair for `partition_count` partitions.
    ///
    /// `partition_count` must be a non-zero power of two no larger than
    /// `2^48`. Both requirements are checked with debug assertions only, so
    /// release builds behave exactly as before for valid inputs.
    fn new(partition_count: u64) -> Self {
        // `is_power_of_two` also rejects 0, for which `trailing_zeros` would
        // return 64 and a `1 << 64` style check would itself overflow.
        debug_assert!(
            partition_count.is_power_of_two(),
            "partition_count must be a non-zero power of two, got {}",
            partition_count
        );

        let radix_bits = u64::from(partition_count.trailing_zeros());
        // Guards the subtraction below against underflow.
        debug_assert!(
            radix_bits <= Self::HASH_BITS,
            "partition_count uses {} radix bits, at most {} are supported",
            radix_bits,
            Self::HASH_BITS
        );

        let shift = Self::HASH_BITS - radix_bits;
        let mask = ((1u64 << radix_bits) - 1) << shift;

        Self { mask, shift }
    }

    /// Returns the partition index encoded in `hash`.
    ///
    /// The result is always in `0..partition_count`.
    pub fn index(&self, hash: u64) -> usize {
        ((hash & self.mask) >> self.shift) as usize
    }
}
52+
3153pub struct PartitionedPayload {
3254 pub payloads : Vec < Payload > ,
3355 pub group_types : Vec < DataType > ,
@@ -37,9 +59,7 @@ pub struct PartitionedPayload {
3759
3860 pub arenas : Vec < Arc < Bump > > ,
3961
40- partition_count : u64 ,
41- mask_v : u64 ,
42- shift_v : u64 ,
62+ partition_mask : PartitionMask ,
4363}
4464
4565unsafe impl Send for PartitionedPayload { }
@@ -52,9 +72,6 @@ impl PartitionedPayload {
5272 partition_count : u64 ,
5373 arenas : Vec < Arc < Bump > > ,
5474 ) -> Self {
55- let radix_bits = partition_count. trailing_zeros ( ) as u64 ;
56- debug_assert_eq ! ( 1 << radix_bits, partition_count) ;
57-
5875 let states_layout = if !aggrs. is_empty ( ) {
5976 Some ( get_states_layout ( & aggrs) . unwrap ( ) )
6077 } else {
@@ -72,7 +89,7 @@ impl PartitionedPayload {
7289 } )
7390 . collect_vec ( ) ;
7491
75- let offsets = RowLayout {
92+ let row_layout = RowLayout {
7693 states_layout,
7794 ..payloads[ 0 ] . row_layout . clone ( )
7895 } ;
@@ -81,12 +98,10 @@ impl PartitionedPayload {
8198 payloads,
8299 group_types,
83100 aggrs,
84- row_layout : offsets,
85- partition_count,
101+ row_layout,
86102
87103 arenas,
88- mask_v : mask ( radix_bits) ,
89- shift_v : shift ( radix_bits) ,
104+ partition_mask : PartitionMask :: new ( partition_count) ,
90105 }
91106 }
92107
@@ -119,7 +134,7 @@ impl PartitionedPayload {
119134 state. reset_partitions ( self . partition_count ( ) ) ;
120135 for & row in & state. empty_vector [ ..new_group_rows] {
121136 let hash = state. group_hashes [ row] ;
122- let partition_idx = ( ( hash & self . mask_v ) >> self . shift_v ) as usize ;
137+ let partition_idx = self . partition_mask . index ( hash ) ;
123138 let ( count, sel) = & mut state. partition_entries [ partition_idx] ;
124139
125140 sel[ * count as usize ] = row;
@@ -149,19 +164,27 @@ impl PartitionedPayload {
149164 return self ;
150165 }
151166
152- let mut new_partition_payload = PartitionedPayload :: new (
153- self . group_types . clone ( ) ,
154- self . aggrs . clone ( ) ,
155- new_partition_count as u64 ,
156- self . arenas . clone ( ) ,
157- ) ;
167+ let PartitionedPayload {
168+ payloads,
169+ group_types,
170+ aggrs,
171+ arenas,
172+ ..
173+ } = self ;
174+
175+ let mut new_partition_payload =
176+ PartitionedPayload :: new ( group_types, aggrs, new_partition_count as u64 , arenas) ;
177+
178+ state. clear ( ) ;
179+ for payload in payloads. into_iter ( ) {
180+ new_partition_payload. combine_single ( payload, state, None )
181+ }
158182
159- new_partition_payload. combine ( self , state) ;
160183 new_partition_payload
161184 }
162185
163186 pub fn combine ( & mut self , other : PartitionedPayload , state : & mut PayloadFlushState ) {
164- if other. partition_count == self . partition_count {
187+ if other. partition_count ( ) == self . partition_count ( ) {
165188 for ( l, r) in self . payloads . iter_mut ( ) . zip ( other. payloads . into_iter ( ) ) {
166189 l. combine ( r) ;
167190 }
@@ -184,7 +207,7 @@ impl PartitionedPayload {
184207 return ;
185208 }
186209
187- if self . partition_count == 1 {
210+ if self . partition_count ( ) == 1 {
188211 self . payloads [ 0 ] . combine ( other) ;
189212 } else {
190213 flush_state. clear ( ) ;
@@ -194,13 +217,19 @@ impl PartitionedPayload {
194217 // copy rows
195218 let state = & * flush_state. probe_state ;
196219
197- for partition in ( 0 ..self . partition_count as usize )
198- . filter ( |x| only_bucket. is_none ( ) || only_bucket == Some ( * x) )
199- {
200- let ( count, sel) = & state. partition_entries [ partition] ;
201- if * count > 0 {
202- let payload = & mut self . payloads [ partition] ;
203- payload. copy_rows ( & sel[ ..* count as _ ] , & flush_state. addresses ) ;
220+ match only_bucket {
221+ Some ( i) => {
222+ let ( count, sel) = & state. partition_entries [ i] ;
223+ self . payloads [ i] . copy_rows ( & sel[ ..* count as _ ] , & flush_state. addresses ) ;
224+ }
225+ None => {
226+ for ( ( count, sel) , payload) in
227+ state. partition_entries . iter ( ) . zip ( self . payloads . iter_mut ( ) )
228+ {
229+ if * count > 0 {
230+ payload. copy_rows ( & sel[ ..* count as _ ] , & flush_state. addresses ) ;
231+ }
232+ }
204233 }
205234 }
206235 }
@@ -236,7 +265,7 @@ impl PartitionedPayload {
236265 flush_state. addresses [ idx] = row_ptr;
237266
238267 let hash = row_ptr. hash ( & self . row_layout ) ;
239- let partition_idx = ( ( hash & self . mask_v ) >> self . shift_v ) as usize ;
268+ let partition_idx = self . partition_mask . index ( hash ) ;
240269
241270 let ( count, sel) = & mut state. partition_entries [ partition_idx] ;
242271 sel[ * count as usize ] = idx. into ( ) ;
@@ -253,7 +282,7 @@ impl PartitionedPayload {
253282
254283 #[ inline]
255284 pub fn partition_count ( & self ) -> usize {
256- self . partition_count as usize
285+ self . payloads . len ( )
257286 }
258287
259288 #[ allow( dead_code) ]
@@ -266,13 +295,3 @@ impl PartitionedPayload {
266295 self . payloads . iter ( ) . map ( |x| x. memory_size ( ) ) . sum ( )
267296 }
268297}
269-
270- #[ inline]
271- fn shift ( radix_bits : u64 ) -> u64 {
272- 48 - radix_bits
273- }
274-
275- #[ inline]
276- fn mask ( radix_bits : u64 ) -> u64 {
277- ( ( 1 << radix_bits) - 1 ) << shift ( radix_bits)
278- }
0 commit comments