Skip to content

Commit b6e1fef

Browse files
committed
x
1 parent ad5a4a8 commit b6e1fef

File tree

4 files changed

+83
-106
lines changed

4 files changed

+83
-106
lines changed

src/query/expression/src/aggregate/hash_index.rs

Lines changed: 36 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
use std::fmt::Debug;
1616

1717
use super::payload_row::CompareState;
18-
use super::CompareItem;
1918
use super::PartitionedPayload;
2019
use super::ProbeState;
2120
use super::RowPtr;
@@ -164,14 +163,9 @@ impl HashIndex {
164163
row_count: usize,
165164
mut adapter: impl TableAdapter,
166165
) -> usize {
167-
for (i, item) in state.no_match_vector[..row_count].iter_mut().enumerate() {
168-
let hash = state.group_hashes[i];
169-
*item = CompareItem {
170-
row: i.into(),
171-
salt: Entry::hash_to_salt(hash),
172-
slot: self.init_slot(hash),
173-
row_ptr: RowPtr::null(),
174-
};
166+
for (i, row) in state.no_match_vector[..row_count].iter_mut().enumerate() {
167+
*row = i.into();
168+
state.slots[i] = self.init_slot(state.group_hashes[i]);
175169
}
176170

177171
let mut new_group_count = 0;
@@ -183,21 +177,18 @@ impl HashIndex {
183177
let mut no_match_count = 0;
184178

185179
// 1. inject new_group_count, new_entry_count, need_compare_count, no_match_count
186-
for item in state.no_match_vector[..remaining_entries].iter_mut() {
187-
let (slot, is_new) = self.find_or_insert(item.slot, item.salt);
188-
item.slot = slot;
180+
for row in state.no_match_vector[..remaining_entries].iter().copied() {
181+
let slot = &mut state.slots[row];
182+
let is_new;
183+
184+
let salt = Entry::hash_to_salt(state.group_hashes[row]);
185+
(*slot, is_new) = self.find_or_insert(*slot, salt);
189186

190187
if is_new {
191-
state.empty_vector[new_entry_count] = item.row;
192-
state.slots[new_entry_count] = slot;
188+
state.empty_vector[new_entry_count] = row;
193189
new_entry_count += 1;
194190
} else {
195-
state.group_compare_vector[need_compare_count] = CompareItem {
196-
row: item.row,
197-
slot: item.slot,
198-
salt: item.salt,
199-
row_ptr: self.mut_entry(slot).get_pointer(),
200-
};
191+
state.group_compare_vector[need_compare_count] = row;
201192
need_compare_count += 1;
202193
}
203194
}
@@ -208,40 +199,36 @@ impl HashIndex {
208199

209200
adapter.append_rows(state, new_entry_count);
210201

211-
for (i, row) in state.empty_vector[..new_entry_count]
212-
.iter()
213-
.copied()
214-
.enumerate()
215-
{
216-
let entry = self.mut_entry(state.slots[i]);
202+
for row in state.empty_vector[..new_entry_count].iter().copied() {
203+
let entry = self.mut_entry(state.slots[row]);
217204
entry.set_pointer(state.addresses[row]);
218205
debug_assert_eq!(entry.get_pointer(), state.addresses[row]);
219206
}
220207
}
221208

222209
// 3. set address of compare vector
223210
if need_compare_count > 0 {
224-
for item in &mut state.group_compare_vector[..need_compare_count] {
225-
let entry = self.mut_entry(item.slot);
211+
for row in state.group_compare_vector[..need_compare_count]
212+
.iter()
213+
.copied()
214+
{
215+
let entry = self.mut_entry(state.slots[row]);
226216

227217
debug_assert!(entry.is_occupied());
228-
debug_assert_eq!(
229-
entry.get_salt(),
230-
(state.group_hashes[item.row] >> 48) as u16
231-
);
232-
item.row_ptr = entry.get_pointer();
233-
state.addresses[item.row] = item.row_ptr;
218+
debug_assert_eq!(entry.get_salt(), (state.group_hashes[row] >> 48) as u16);
219+
state.addresses[row] = entry.get_pointer();
234220
}
235221

236222
// 4. compare
237223
no_match_count = adapter.compare(state, need_compare_count, no_match_count);
238224
}
239225

240226
// 5. Linear probing, just increase iter_times
241-
for item in &mut state.no_match_vector[..no_match_count] {
242-
item.slot += 1;
243-
if item.slot >= self.capacity {
244-
item.slot = 0;
227+
for row in state.no_match_vector[..no_match_count].iter().copied() {
228+
let slot = &mut state.slots[row];
229+
*slot += 1;
230+
if *slot >= self.capacity {
231+
*slot = 0;
245232
}
246233
}
247234
remaining_entries = no_match_count;
@@ -269,7 +256,9 @@ impl<'a> TableAdapter for AdapterImpl<'a> {
269256
need_compare_count: usize,
270257
no_match_count: usize,
271258
) -> usize {
259+
// todo: compare hash first if NECESSARY
272260
CompareState {
261+
address: &state.addresses,
273262
compare: &mut state.group_compare_vector,
274263
matched: &mut state.match_vector,
275264
no_matched: &mut state.no_match_vector,
@@ -349,12 +338,11 @@ mod tests {
349338
impl TableAdapter for &mut TestTableAdapter {
350339
fn append_rows(&mut self, state: &mut ProbeState, new_entry_count: usize) {
351340
for row in state.empty_vector[..new_entry_count].iter() {
352-
let row_index = row.to_index();
353-
let (key, hash) = self.incoming[row_index];
341+
let (key, hash) = self.incoming[*row];
354342
let value = key + 20;
355343

356344
self.payload.push((key, hash, value));
357-
state.addresses[*row] = self.get_row_ptr(true, row_index);
345+
state.addresses[*row] = self.get_row_ptr(true, row.to_usize());
358346
}
359347
}
360348

@@ -364,18 +352,21 @@ mod tests {
364352
need_compare_count: usize,
365353
mut no_match_count: usize,
366354
) -> usize {
367-
for item in &state.group_compare_vector[..need_compare_count] {
368-
let incoming = self.incoming[item.row.to_index()];
355+
for row in state.group_compare_vector[..need_compare_count]
356+
.iter()
357+
.copied()
358+
{
359+
let incoming = self.incoming[row];
369360

370-
let (key, hash, _) = self.get_payload(item.row_ptr);
361+
let (key, hash, _) = self.get_payload(state.addresses[row]);
371362

372363
const POINTER_MASK: u64 = 0x0000FFFFFFFFFFFF;
373364
assert_eq!(incoming.1 | POINTER_MASK, hash | POINTER_MASK);
374365
if incoming.0 == key {
375366
continue;
376367
}
377368

378-
state.no_match_vector[no_match_count] = item.clone();
369+
state.no_match_vector[no_match_count] = row;
379370
no_match_count += 1;
380371
}
381372

src/query/expression/src/aggregate/payload.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ impl Payload {
250250
for row in select_vector {
251251
unsafe {
252252
address[*row]
253-
.write_u8(write_offset, bitmap.get_bit(row.to_index()) as u8);
253+
.write_u8(write_offset, bitmap.get_bit(row.to_usize()) as u8);
254254
}
255255
}
256256
}
@@ -291,15 +291,15 @@ impl Payload {
291291
let (array_layout, padded_size) = layout.repeat(select_vector.len()).unwrap();
292292
// Bump only allocates but does not drop, so there is no use after free for any item.
293293
let place = self.arena.alloc_layout(array_layout);
294-
for (idx, place) in select_vector
294+
for (row, place) in select_vector
295295
.iter()
296296
.copied()
297297
.enumerate()
298-
.map(|(i, row)| (row.to_index(), unsafe { place.add(padded_size * i) }))
298+
.map(|(i, row)| (row, unsafe { place.add(padded_size * i) }))
299299
{
300300
let place = StateAddr::from(place);
301-
address[idx].set_state_addr(&self.row_layout, &place);
302-
let page = &mut self.pages[page_index[idx]];
301+
address[row].set_state_addr(&self.row_layout, &place);
302+
let page = &mut self.pages[page_index[row]];
303303
for (aggr, loc) in self.aggrs.iter().zip(states_loc.iter()) {
304304
aggr.init_state(AggrState::new(place, loc));
305305
page.state_offsets += 1;

src/query/expression/src/aggregate/payload_row.rs

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ use databend_common_column::bitmap::Bitmap;
1717
use databend_common_io::prelude::bincode_deserialize_from_slice;
1818
use databend_common_io::prelude::bincode_serialize_into_buf;
1919

20-
use super::CompareItem;
2120
use super::RowID;
2221
use super::RowLayout;
2322
use super::RowPtr;
@@ -117,19 +116,19 @@ pub(super) unsafe fn serialize_column_to_rowformat(
117116
}
118117
} else {
119118
for row in select_vector {
120-
address[*row].write_u8(offset, v.get_bit(row.to_index()) as u8);
119+
address[*row].write_u8(offset, v.get_bit(row.to_usize()) as u8);
121120
}
122121
}
123122
}
124123
Column::Binary(v) | Column::Bitmap(v) | Column::Variant(v) | Column::Geometry(v) => {
125124
for row in select_vector {
126-
let data = arena.alloc_slice_copy(v.index_unchecked(row.to_index()));
125+
let data = arena.alloc_slice_copy(v.index_unchecked(row.to_usize()));
127126
address[*row].write_bytes(offset, data);
128127
}
129128
}
130129
Column::String(v) => {
131130
for row in select_vector {
132-
let data = arena.alloc_str(v.index_unchecked(row.to_index()));
131+
let data = arena.alloc_str(v.index_unchecked(row.to_usize()));
133132
address[*row].write_bytes(offset, data.as_bytes());
134133
}
135134
}
@@ -150,7 +149,7 @@ pub(super) unsafe fn serialize_column_to_rowformat(
150149
// for complex column
151150
other => {
152151
for row in select_vector {
153-
let s = other.index_unchecked(row.to_index()).to_owned();
152+
let s = other.index_unchecked(row.to_usize()).to_owned();
154153
scratch.clear();
155154
bincode_serialize_into_buf(scratch, &s).unwrap();
156155

@@ -170,15 +169,16 @@ unsafe fn serialize_fixed_size_column_to_rowformat<T>(
170169
T: AccessType<Scalar: Copy>,
171170
{
172171
for row in select_vector {
173-
let val = T::index_column_unchecked_scalar(column, row.to_index());
172+
let val = T::index_column_unchecked_scalar(column, row.to_usize());
174173
address[*row].write(offset, &val);
175174
}
176175
}
177176

178177
pub struct CompareState<'a> {
179-
pub(super) compare: &'a mut [CompareItem; BATCH_SIZE],
180-
pub(super) matched: &'a mut [CompareItem; BATCH_SIZE],
181-
pub(super) no_matched: &'a mut [CompareItem; BATCH_SIZE],
178+
pub(super) address: &'a [RowPtr; BATCH_SIZE],
179+
pub(super) compare: &'a mut [RowID; BATCH_SIZE],
180+
pub(super) matched: &'a mut [RowID; BATCH_SIZE],
181+
pub(super) no_matched: &'a mut [RowID; BATCH_SIZE],
182182
}
183183

184184
impl<'s> CompareState<'s> {
@@ -207,6 +207,7 @@ impl<'s> CompareState<'s> {
207207
);
208208

209209
self = CompareState::<'s> {
210+
address: self.address,
210211
compare: self.matched,
211212
matched: self.compare,
212213
no_matched: self.no_matched,
@@ -355,32 +356,32 @@ impl<'s> CompareState<'s> {
355356
let mut match_count = 0;
356357
if let Some(validity) = validity {
357358
let is_all_set = validity.null_count() == 0;
358-
for item in &self.compare[..count] {
359-
let row = item.row.to_index();
360-
let is_set2 = unsafe { item.row_ptr.read::<u8>(validity_offset) != 0 };
361-
let is_set = is_all_set || unsafe { validity.get_bit_unchecked(row) };
359+
for row in self.compare[..count].iter().copied() {
360+
let row_index = row.to_usize();
361+
let is_set2 = unsafe { self.address[row].read::<u8>(validity_offset) != 0 };
362+
let is_set = is_all_set || unsafe { validity.get_bit_unchecked(row_index) };
362363

363364
let equal = if is_set && is_set2 {
364-
compare_fn(row, &item.row_ptr)
365+
compare_fn(row_index, &self.address[row])
365366
} else {
366367
is_set == is_set2
367368
};
368369

369370
if equal {
370-
self.matched[match_count] = item.clone();
371+
self.matched[match_count] = row;
371372
match_count += 1;
372373
} else {
373-
self.no_matched[no_match_count] = item.clone();
374+
self.no_matched[no_match_count] = row;
374375
no_match_count += 1;
375376
}
376377
}
377378
} else {
378-
for item in &self.compare[..count] {
379-
if compare_fn(item.row.to_index(), &item.row_ptr) {
380-
self.matched[match_count] = item.clone();
379+
for row in self.compare[..count].iter().copied() {
380+
if compare_fn(row.to_usize(), &self.address[row]) {
381+
self.matched[match_count] = row;
381382
match_count += 1;
382383
} else {
383-
self.no_matched[no_match_count] = item.clone();
384+
self.no_matched[no_match_count] = row;
384385
no_match_count += 1;
385386
}
386387
}
@@ -396,16 +397,17 @@ impl<'s> CompareState<'s> {
396397
(count, mut no_match_count): (usize, usize),
397398
) -> (usize, usize) {
398399
let mut match_count = 0;
399-
for item in &self.compare[..count] {
400-
let value = unsafe { AnyType::index_column_unchecked(col, item.row.to_index()) };
401-
let scalar = unsafe { item.row_ptr.read_bytes(col_offset) };
400+
for row in self.compare[..count].iter().copied() {
401+
let row_index = row.to_usize();
402+
let value = unsafe { AnyType::index_column_unchecked(col, row_index) };
403+
let scalar = unsafe { self.address[row].read_bytes(col_offset) };
402404
let scalar: Scalar = bincode_deserialize_from_slice(scalar).unwrap();
403405

404406
if scalar.as_ref() == value {
405-
self.matched[match_count] = item.clone();
407+
self.matched[match_count] = row;
406408
match_count += 1;
407409
} else {
408-
self.no_matched[no_match_count] = item.clone();
410+
self.no_matched[no_match_count] = row;
409411
no_match_count += 1;
410412
}
411413
}

0 commit comments

Comments
 (0)