diff --git a/compiler/rustc_index/src/bit_set.rs b/compiler/rustc_index/src/bit_set.rs index f833acf6824e4..6804d4c288a6c 100644 --- a/compiler/rustc_index/src/bit_set.rs +++ b/compiler/rustc_index/src/bit_set.rs @@ -3,14 +3,16 @@ use std::marker::PhantomData; use std::mem; use std::ops::{Bound, Range, RangeBounds}; use std::rc::Rc; -use std::{fmt, iter, slice}; +use std::{fmt, iter}; use Chunk::*; #[cfg(feature = "nightly")] use rustc_macros::{Decodable_NoContext, Encodable_NoContext}; +use crate::bit_set::raw::RawBitIter; use crate::{Idx, IndexVec}; +mod raw; #[cfg(test)] mod tests; @@ -183,48 +185,18 @@ impl DenseBitSet { self.words.iter().all(|a| *a == 0) } - /// Insert `elem`. Returns whether the set has changed. + /// Inserts `value` into the set, and returns true if the set has changed + /// (i.e. the set did not contain the value). #[inline] - pub fn insert(&mut self, elem: T) -> bool { - assert!( - elem.index() < self.domain_size, - "inserting element at index {} but domain size is {}", - elem.index(), - self.domain_size, - ); - let (word_index, mask) = word_index_and_mask(elem); - let word_ref = &mut self.words[word_index]; - let word = *word_ref; - let new_word = word | mask; - *word_ref = new_word; - new_word != word + pub fn insert(&mut self, value: T) -> bool { + raw::insert(self.domain_size, &mut self.words, value.index()) } #[inline] - pub fn insert_range(&mut self, elems: impl RangeBounds) { - let Some((start, end)) = inclusive_start_end(elems, self.domain_size) else { - return; - }; - - let (start_word_index, start_mask) = word_index_and_mask(start); - let (end_word_index, end_mask) = word_index_and_mask(end); - - // Set all words in between start and end (exclusively of both). - for word_index in (start_word_index + 1)..end_word_index { - self.words[word_index] = !0; - } - - if start_word_index != end_word_index { - // Start and end are in different words, so we handle each in turn. - // - // We set all leading bits. 
This includes the start_mask bit. - self.words[start_word_index] |= !(start_mask - 1); - // And all trailing bits (i.e. from 0..=end) in the end word, - // including the end. - self.words[end_word_index] |= end_mask | (end_mask - 1); - } else { - self.words[start_word_index] |= end_mask | (end_mask - start_mask); - } + pub fn insert_range(&mut self, range: impl RangeBounds) { + let start = range.start_bound().map(|i| i.index()); + let end = range.end_bound().map(|i| i.index()); + raw::insert_range(self.domain_size, &mut self.words, (start, end)); } /// Sets all bits to true. @@ -235,40 +207,17 @@ impl DenseBitSet { /// Checks whether any bit in the given range is a 1. #[inline] - pub fn contains_any(&self, elems: impl RangeBounds) -> bool { - let Some((start, end)) = inclusive_start_end(elems, self.domain_size) else { - return false; - }; - let (start_word_index, start_mask) = word_index_and_mask(start); - let (end_word_index, end_mask) = word_index_and_mask(end); - - if start_word_index == end_word_index { - self.words[start_word_index] & (end_mask | (end_mask - start_mask)) != 0 - } else { - if self.words[start_word_index] & !(start_mask - 1) != 0 { - return true; - } - - let remaining = start_word_index + 1..end_word_index; - if remaining.start <= remaining.end { - self.words[remaining].iter().any(|&w| w != 0) - || self.words[end_word_index] & (end_mask | (end_mask - 1)) != 0 - } else { - false - } - } + pub fn contains_any(&self, range: impl RangeBounds) -> bool { + let start = range.start_bound().map(|i| i.index()); + let end = range.end_bound().map(|i| i.index()); + raw::contains_any(self.domain_size, &self.words, (start, end)) } - /// Returns `true` if the set has changed. + /// Removes `value` from the set, and returns true if the set has changed + /// (i.e. the set did contain the value). 
#[inline] - pub fn remove(&mut self, elem: T) -> bool { - assert!(elem.index() < self.domain_size); - let (word_index, mask) = word_index_and_mask(elem); - let word_ref = &mut self.words[word_index]; - let word = *word_ref; - let new_word = word & !mask; - *word_ref = new_word; - new_word != word + pub fn remove(&mut self, value: T) -> bool { + raw::remove(self.domain_size, &mut self.words, value.index()) } /// Iterates over the indices of set bits in a sorted order. @@ -278,33 +227,9 @@ impl DenseBitSet { } pub fn last_set_in(&self, range: impl RangeBounds) -> Option { - let (start, end) = inclusive_start_end(range, self.domain_size)?; - let (start_word_index, _) = word_index_and_mask(start); - let (end_word_index, end_mask) = word_index_and_mask(end); - - let end_word = self.words[end_word_index] & (end_mask | (end_mask - 1)); - if end_word != 0 { - let pos = max_bit(end_word) + WORD_BITS * end_word_index; - if start <= pos { - return Some(T::new(pos)); - } - } - - // We exclude end_word_index from the range here, because we don't want - // to limit ourselves to *just* the last word: the bits set it in may be - // after `end`, so it may not work out. - if let Some(offset) = - self.words[start_word_index..end_word_index].iter().rposition(|&w| w != 0) - { - let word_idx = start_word_index + offset; - let start_word = self.words[word_idx]; - let pos = max_bit(start_word) + WORD_BITS * word_idx; - if start <= pos { - return Some(T::new(pos)); - } - } - - None + let start = range.start_bound().map(|i| i.index()); + let end = range.end_bound().map(|i| i.index()); + raw::last_set_in(self.domain_size, &self.words, (start, end)).map(T::new) } bit_relations_inherent_impls! {} @@ -410,54 +335,22 @@ impl ToString for DenseBitSet { } pub struct BitIter<'a, T: Idx> { - /// A copy of the current word, but with any already-visited bits cleared. - /// (This lets us use `trailing_zeros()` to find the next set bit.) When it - /// is reduced to 0, we move onto the next word. 
- word: Word, - - /// The offset (measured in bits) of the current word. - offset: usize, - - /// Underlying iterator over the words. - iter: slice::Iter<'a, Word>, - + raw: RawBitIter<'a>, marker: PhantomData, } impl<'a, T: Idx> BitIter<'a, T> { - #[inline] - fn new(words: &'a [Word]) -> BitIter<'a, T> { - // We initialize `word` and `offset` to degenerate values. On the first - // call to `next()` we will fall through to getting the first word from - // `iter`, which sets `word` to the first word (if there is one) and - // `offset` to 0. Doing it this way saves us from having to maintain - // additional state about whether we have started. - BitIter { - word: 0, - offset: usize::MAX - (WORD_BITS - 1), - iter: words.iter(), - marker: PhantomData, - } + #[inline(always)] + fn new(words: &'a [Word]) -> Self { + BitIter { raw: RawBitIter::new(words), marker: PhantomData } } } impl<'a, T: Idx> Iterator for BitIter<'a, T> { type Item = T; - fn next(&mut self) -> Option { - loop { - if self.word != 0 { - // Get the position of the next set bit in the current word, - // then clear the bit. - let bit_pos = self.word.trailing_zeros() as usize; - self.word ^= 1 << bit_pos; - return Some(T::new(bit_pos + self.offset)); - } - // Move onto the next word. `wrapping_add()` is needed to handle - // the degenerate initial value given to `offset` in `new()`. 
- self.word = *self.iter.next()?; - self.offset = self.offset.wrapping_add(WORD_BITS); - } + fn next(&mut self) -> Option { + self.raw.next().map(T::new) } } @@ -740,7 +633,7 @@ impl ChunkedBitSet { Some(Ones) => ChunkIter::Ones(0..chunk_domain_size as usize), Some(Mixed { ones_count: _, words }) => { let num_words = num_words(chunk_domain_size as usize); - ChunkIter::Mixed(BitIter::new(&words[0..num_words])) + ChunkIter::Mixed(RawBitIter::new(&words[0..num_words])) } None => ChunkIter::Finished, } @@ -1058,7 +951,7 @@ impl Chunk { enum ChunkIter<'a> { Zeros, Ones(Range), - Mixed(BitIter<'a, usize>), + Mixed(RawBitIter<'a>), Finished, } @@ -1700,12 +1593,9 @@ fn num_chunks(domain_size: T) -> usize { domain_size.index().div_ceil(CHUNK_BITS) } -#[inline] +#[inline(always)] fn word_index_and_mask(elem: T) -> (usize, Word) { - let elem = elem.index(); - let word_index = elem / WORD_BITS; - let mask = 1 << (elem % WORD_BITS); - (word_index, mask) + raw::word_index_and_mask(elem.index()) } #[inline] diff --git a/compiler/rustc_index/src/bit_set/raw.rs b/compiler/rustc_index/src/bit_set/raw.rs new file mode 100644 index 0000000000000..3fa0e9715776e --- /dev/null +++ b/compiler/rustc_index/src/bit_set/raw.rs @@ -0,0 +1,192 @@ +use std::ops::Bound; +use std::slice; + +use crate::bit_set::{WORD_BITS, Word, inclusive_start_end, max_bit}; + +#[inline] +pub(crate) fn contains_any( + domain_size: usize, + words: &[Word], + range: (Bound, Bound), +) -> bool { + let Some((start, end)) = inclusive_start_end(range, domain_size) else { + return false; + }; + + let (start_word_index, start_mask) = word_index_and_mask(start); + let (end_word_index, end_mask) = word_index_and_mask(end); + + if start_word_index == end_word_index { + words[start_word_index] & (end_mask | (end_mask - start_mask)) != 0 + } else { + if words[start_word_index] & !(start_mask - 1) != 0 { + return true; + } + + let remaining = start_word_index + 1..end_word_index; + if remaining.start <= remaining.end { + 
words[remaining].iter().any(|&w| w != 0) + || words[end_word_index] & (end_mask | (end_mask - 1)) != 0 + } else { + false + } + } +} + +#[inline] +pub(crate) fn last_set_in( + domain_size: usize, + words: &[Word], + range: (Bound, Bound), +) -> Option { + let (start, end) = inclusive_start_end(range, domain_size)?; + + let (start_word_index, _) = word_index_and_mask(start); + let (end_word_index, end_mask) = word_index_and_mask(end); + + let end_word = words[end_word_index] & (end_mask | (end_mask - 1)); + if end_word != 0 { + let pos = max_bit(end_word) + WORD_BITS * end_word_index; + if start <= pos { + return Some(pos); + } + } + + // We exclude end_word_index from the range here, because we don't want + // to limit ourselves to *just* the last word: the bits set in it may be + // after `end`, so it may not work out. + if let Some(offset) = words[start_word_index..end_word_index].iter().rposition(|&w| w != 0) { + let word_idx = start_word_index + offset; + let start_word = words[word_idx]; + let pos = max_bit(start_word) + WORD_BITS * word_idx; + if start <= pos { + return Some(pos); + } + } + + None +} + +#[inline(always)] +pub(crate) fn insert(domain_size: usize, words: &mut [Word], index: usize) -> bool { + if index >= domain_size { + index_not_in_domain("inserting", index, domain_size); + } + + let (word_index, mask) = word_index_and_mask(index); + modify_word(&mut words[word_index], |w| w | mask) +} + +#[inline(always)] +pub(crate) fn remove(domain_size: usize, words: &mut [Word], index: usize) -> bool { + if index >= domain_size { + index_not_in_domain("removing", index, domain_size); + } + + let (word_index, mask) = word_index_and_mask(index); + modify_word(&mut words[word_index], |w| w & !mask) +} + +#[cold] +#[inline(never)] +fn index_not_in_domain(verb: &'static str, index: usize, domain_size: usize) -> ! 
{ + panic!("{verb} at index {index} but domain size is {domain_size}"); +} + +/// Updates a `&mut Word` using the given function, and returns true if the value was changed. +#[inline(always)] +fn modify_word(word: &mut Word, modify_fn: impl Fn(Word) -> Word) -> bool { + let old = *word; + *word = modify_fn(old); + old != *word +} + +#[inline] +pub(crate) fn insert_range( + domain_size: usize, + words: &mut [Word], + range: (Bound, Bound), +) { + let Some((start, end)) = inclusive_start_end(range, domain_size) else { + return; + }; + + let (start_word_index, start_mask) = word_index_and_mask(start); + let (end_word_index, end_mask) = word_index_and_mask(end); + + // Set all words in between start and end (exclusively of both). + for word_index in (start_word_index + 1)..end_word_index { + words[word_index] = !0; + } + + if start_word_index != end_word_index { + // Start and end are in different words, so we handle each in turn. + // + // We set all leading bits. This includes the start_mask bit. + words[start_word_index] |= !(start_mask - 1); + // And all trailing bits (i.e. from 0..=end) in the end word, + // including the end. + words[end_word_index] |= end_mask | (end_mask - 1); + } else { + words[start_word_index] |= end_mask | (end_mask - start_mask); + } +} + +pub(crate) struct RawBitIter<'a> { + /// A copy of the current word, but with any already-visited bits cleared. + /// (This lets us use `trailing_zeros()` to find the next set bit.) When it + /// is reduced to 0, we move onto the next word. + word: Word, + + /// The offset (measured in bits) of the current word. + offset: usize, + + /// Underlying iterator over the words. + iter: slice::Iter<'a, Word>, +} + +impl<'a> RawBitIter<'a> { + #[inline(always)] + pub(crate) fn new(words: &'a [Word]) -> Self { + // Initialize `offset` to `0 - WORD_BITS`, so that the first iteration + // will see `word == 0` and increase the offset to its starting value of 0. 
+ // + // This avoids having to explicitly track whether the iterator has started. + RawBitIter { + word: 0, + offset: const { (0usize).wrapping_sub(WORD_BITS) }, + iter: words.iter(), + } + } +} + +impl<'a> Iterator for RawBitIter<'a> { + type Item = usize; + + #[inline] + fn next(&mut self) -> Option { + // Keep looping until we find a non-empty word, or run out of words. + loop { + if self.word != 0 { + // Get the position of the next set bit in the current word, + // then clear the bit. + let bit_pos = self.word.trailing_zeros() as usize; + self.word ^= 1 << bit_pos; + return Some(bit_pos + self.offset); + } + + // Move onto the next word, or stop if there isn't one. + self.word = self.iter.next().copied()?; + // This needs to be a wrapping add so that the first iteration will + // correctly overflow to a starting offset of 0. + self.offset = self.offset.wrapping_add(WORD_BITS); + } + } +} + +#[inline(always)] +pub(crate) fn word_index_and_mask(bit_index: usize) -> (usize, Word) { + let word_index = bit_index / WORD_BITS; + let mask = 1 << (bit_index % WORD_BITS); + (word_index, mask) +}