diff --git a/src/query/expression/src/aggregate/aggregate_hashtable.rs b/src/query/expression/src/aggregate/aggregate_hashtable.rs index 904dd571befb1..472d5745454ea 100644 --- a/src/query/expression/src/aggregate/aggregate_hashtable.rs +++ b/src/query/expression/src/aggregate/aggregate_hashtable.rs @@ -92,6 +92,7 @@ impl AggregateHashTable { arena: Arc, need_init_entry: bool, ) -> Self { + debug_assert!(capacity.is_power_of_two()); let entries = if need_init_entry { vec![Entry::default(); capacity] } else { @@ -110,6 +111,7 @@ impl AggregateHashTable { entries, count: 0, capacity, + capacity_mask: capacity - 1, }, config, } diff --git a/src/query/expression/src/aggregate/hash_index.rs b/src/query/expression/src/aggregate/hash_index.rs index baf37d447124d..e8d2516341339 100644 --- a/src/query/expression/src/aggregate/hash_index.rs +++ b/src/query/expression/src/aggregate/hash_index.rs @@ -22,33 +22,56 @@ pub(super) struct HashIndex { pub entries: Vec, pub count: usize, pub capacity: usize, + pub capacity_mask: usize, +} + +const INCREMENT_BITS: usize = 5; + +/// Derive an odd probing step from the high bits of the hash so the walk spans all slots. +/// +/// this will generate a step in the range [1, 2^INCREMENT_BITS) based on hash and always odd. +#[inline(always)] +fn step(hash: u64) -> usize { + ((hash >> (64 - INCREMENT_BITS)) as usize) | 1 +} + +/// Move to the next slot with wrap-around using the power-of-two capacity mask. +/// +/// soundness: capacity is always a power of two, so mask is capacity - 1 +#[inline(always)] +fn next_slot(slot: usize, hash: u64, mask: usize) -> usize { + (slot + step(hash)) & mask +} + +#[inline(always)] +fn init_slot(hash: u64, capacity_mask: usize) -> usize { + hash as usize & capacity_mask } impl HashIndex { pub fn with_capacity(capacity: usize) -> Self { + debug_assert!(capacity.is_power_of_two()); + let capacity_mask = capacity - 1; Self { entries: vec![Entry::default(); capacity], count: 0, capacity, + capacity_mask, } } - fn init_slot(&self, hash: u64) -> usize { - hash as usize & (self.capacity - 1) - } - - fn find_or_insert(&mut self, mut slot: usize, salt: u16) -> (usize, bool) { + fn find_or_insert(&mut self, mut slot: usize, hash: u64) -> (usize, bool) { + let salt = Entry::hash_to_salt(hash); let entries = self.entries.as_mut_slice(); loop { - let entry = &mut entries[slot]; + debug_assert!(entries.get(slot).is_some()); + // SAFETY: slot is always in range + let entry = unsafe { entries.get_unchecked_mut(slot) }; if entry.is_occupied() { if entry.get_salt() == salt { return (slot, false); } else { - slot += 1; - if slot >= self.capacity { - slot = 0; - } + slot = next_slot(slot, hash, self.capacity_mask); continue; } } else { @@ -59,13 +82,10 @@ impl HashIndex { } pub fn probe_slot(&mut self, hash: u64) -> usize { - let mut slot = self.init_slot(hash); let entries = self.entries.as_mut_slice(); + let mut slot = init_slot(hash, self.capacity_mask); while entries[slot].is_occupied() { - slot += 1; - if slot >= self.capacity { - slot = 0; - } + slot = next_slot(slot, hash, self.capacity_mask); } slot as _ } @@ -159,8 +179,9 @@ impl HashIndex { slots.extend( state.group_hashes[..row_count] .iter() - .map(|hash| self.init_slot(*hash)), + .map(|hash| init_slot(*hash, self.capacity_mask)), ); + let capacity_mask = self.capacity_mask; let mut new_group_count = 0; let mut remaining_entries = row_count; @@ -176,7 +197,7 @@ impl HashIndex { let hash = state.group_hashes[row]; let is_new; - (*slot, is_new) = self.find_or_insert(*slot, Entry::hash_to_salt(hash)); + (*slot, is_new) = self.find_or_insert(*slot, hash); if is_new { state.empty_vector[new_entry_count] = row; @@ -217,13 +238,11 @@ impl HashIndex { no_match_count = adapter.compare(state, need_compare_count, no_match_count); } - // 5. Linear probing, just increase iter_times + // 5. Linear probing with hash-derived step for row in state.no_match_vector[..no_match_count].iter().copied() { let slot = &mut slots[row]; - *slot += 1; - if *slot >= self.capacity { - *slot = 0; - } + let hash = state.group_hashes[row]; + *slot = next_slot(*slot, hash, capacity_mask); } remaining_entries = no_match_count; } @@ -262,6 +281,7 @@ impl<'a> TableAdapter for AdapterImpl<'a> { #[cfg(test)] mod tests { use std::collections::HashMap; + use std::collections::HashSet; use super::*; use crate::ProbeState; @@ -405,6 +425,38 @@ mod tests { } } + #[test] + fn test_probe_walk_covers_full_capacity() { + // This test make sure that we can always cover all slots in the table + let capacity = 16; + let capacity_mask = capacity - 1; + + for high_bits in 0u64..(1 << INCREMENT_BITS) { + let hash = high_bits << (64 - INCREMENT_BITS); + let mut slot = init_slot(hash, capacity_mask); + let mut visited = HashSet::with_capacity(capacity); + + for _ in 0..capacity { + assert!( + visited.insert(slot), + "hash {hash:#x} revisited slot {slot} before covering the table" + ); + slot = next_slot(slot, hash, capacity_mask); + } + + assert_eq!( + capacity, + visited.len(), + "hash {hash:#x} failed to cover every slot for capacity {capacity}" + ); + assert_eq!( + init_slot(hash, capacity_mask), + slot, + "hash {hash:#x} walk did not return to its start after {capacity} steps" + ); + } + } + #[test] fn test_hash_index() { TestCase {