diff --git a/src/const_choice.rs b/src/const_choice.rs index 67dbfa8b..e4579b63 100644 --- a/src/const_choice.rs +++ b/src/const_choice.rs @@ -241,21 +241,16 @@ impl ConstChoice { x & self.as_u32_mask() } + /// WARNING: this method should only be used in contexts that aren't constant-time critical! #[inline] - pub(crate) const fn is_true_vartime(&self) -> bool { - self.0 == ConstChoice::TRUE.0 + pub(crate) const fn to_bool_vartime(self) -> bool { + self.0 != 0 } #[inline] pub(crate) const fn to_u8(self) -> u8 { (self.0 as u8) & 1 } - - /// WARNING: this method should only be used in contexts that aren't constant-time critical! - #[inline] - pub(crate) const fn to_bool_vartime(self) -> bool { - self.to_u8() != 0 - } } /// `const` equivalent of `u32::max(a, b)`. @@ -284,7 +279,7 @@ impl From for ConstChoice { impl From for bool { fn from(choice: ConstChoice) -> Self { - choice.is_true_vartime() + choice.to_bool_vartime() } } @@ -351,7 +346,7 @@ impl ConstCtOption { #[track_caller] pub fn unwrap(self) -> T { assert!( - self.is_some.is_true_vartime(), + self.is_some.to_bool_vartime(), "called `ConstCtOption::unwrap()` on a `None` value" ); self.value @@ -403,7 +398,7 @@ impl ConstCtOption> { #[inline] #[track_caller] pub const fn expect(self, msg: &str) -> Uint { - assert!(self.is_some.is_true_vartime(), "{}", msg); + assert!(self.is_some.to_bool_vartime(), "{}", msg); self.value } @@ -424,7 +419,7 @@ impl ConstCtOption<(Uint, Uint)> { #[inline] #[track_caller] pub const fn expect(self, msg: &str) -> (Uint, Uint) { - assert!(self.is_some.is_true_vartime(), "{}", msg); + assert!(self.is_some.to_bool_vartime(), "{}", msg); self.value } } @@ -439,7 +434,7 @@ impl ConstCtOption>> { #[inline] #[track_caller] pub const fn expect(self, msg: &str) -> NonZero> { - assert!(self.is_some.is_true_vartime(), "{}", msg); + assert!(self.is_some.to_bool_vartime(), "{}", msg); self.value } } @@ -454,7 +449,7 @@ impl ConstCtOption>> { #[inline] #[track_caller] pub const fn expect(self, msg: &str) -> Odd> { - assert!(self.is_some.is_true_vartime(), "{}", msg); + assert!(self.is_some.to_bool_vartime(), "{}", msg); self.value } } @@ -475,7 +470,7 @@ impl ConstCtOption> { #[inline] #[track_caller] pub const fn expect(self, msg: &str) -> Int { - assert!(self.is_some.is_true_vartime(), "{}", msg); + assert!(self.is_some.to_bool_vartime(), "{}", msg); self.value } } @@ -490,7 +485,7 @@ impl ConstCtOption> { #[inline] #[track_caller] pub const fn expect(self, msg: &str) -> NonZeroInt { - assert!(self.is_some.is_true_vartime(), "{}", msg); + assert!(self.is_some.to_bool_vartime(), "{}", msg); self.value } } @@ -505,7 +500,7 @@ impl ConstCtOption> { #[inline] #[track_caller] pub const fn expect(self, msg: &str) -> OddInt { - assert!(self.is_some.is_true_vartime(), "{}", msg); + assert!(self.is_some.to_bool_vartime(), "{}", msg); self.value } } @@ -520,7 +515,7 @@ impl ConstCtOption> { #[inline] #[track_caller] pub const fn expect(self, msg: &str) -> NonZero { - assert!(self.is_some.is_true_vartime(), "{}", msg); + assert!(self.is_some.to_bool_vartime(), "{}", msg); self.value } } @@ -535,7 +530,7 @@ impl ConstCtOption> { #[inline] #[track_caller] pub const fn expect(self, msg: &str) -> SafeGcdInverter { - assert!(self.is_some.is_true_vartime(), "{}", msg); + assert!(self.is_some.to_bool_vartime(), "{}", msg); self.value } } @@ -555,7 +550,7 @@ impl, const LIMBS: usize> ConstCtOption ConstMontyForm { - assert!(self.is_some.is_true_vartime(), "{}", msg); + assert!(self.is_some.to_bool_vartime(), "{}", msg); self.value } } diff --git a/src/int/mod_symbol.rs b/src/int/mod_symbol.rs index f4d167ac..ba048991 100644 --- a/src/int/mod_symbol.rs +++ b/src/int/mod_symbol.rs @@ -25,7 +25,7 @@ impl Int { let (abs, sign) = self.abs_sign(); let jacobi = abs.jacobi_symbol_vartime(rhs); JacobiSymbol::from_i8( - if sign.is_true_vartime() && rhs.as_ref().limbs[0].0 & 3 == 3 { + if sign.to_bool_vartime() && rhs.as_ref().limbs[0].0 & 3 == 3 { -(jacobi as i8) } else { jacobi as i8 diff --git a/src/int/shl.rs b/src/int/shl.rs index 25d86931..b8caea6a 100644 --- a/src/int/shl.rs +++ b/src/int/shl.rs @@ -153,8 +153,8 @@ mod tests { #[test] fn shl256_const() { - assert!(N.overflowing_shl(256).is_none().is_true_vartime()); - assert!(N.overflowing_shl_vartime(256).is_none().is_true_vartime()); + assert!(N.overflowing_shl(256).is_none().to_bool_vartime()); + assert!(N.overflowing_shl_vartime(256).is_none().to_bool_vartime()); } #[test] diff --git a/src/int/shr.rs b/src/int/shr.rs index 333da21c..65634222 100644 --- a/src/int/shr.rs +++ b/src/int/shr.rs @@ -227,8 +227,8 @@ mod tests { #[test] fn shr256_const() { - assert!(N.overflowing_shr(256).is_none().is_true_vartime()); - assert!(N.overflowing_shr_vartime(256).is_none().is_true_vartime()); + assert!(N.overflowing_shr(256).is_none().to_bool_vartime()); + assert!(N.overflowing_shr_vartime(256).is_none().to_bool_vartime()); } #[test] diff --git a/src/modular/reduction.rs b/src/modular/reduction.rs index d2d62b32..2a6114e5 100644 --- a/src/modular/reduction.rs +++ b/src/modular/reduction.rs @@ -104,7 +104,7 @@ pub const fn montgomery_retrieve( modulus: &Odd>, mod_neg_inv: Limb, ) -> Uint { - debug_assert!(Uint::lt(montgomery_form, modulus.as_ref()).is_true_vartime()); + debug_assert!(Uint::lt(montgomery_form, modulus.as_ref()).to_bool_vartime()); let mut res = Uint::ZERO; montgomery_retrieve_inner( montgomery_form.as_limbs(), diff --git a/src/modular/safegcd.rs b/src/modular/safegcd.rs index edae287a..65dfd42c 100644 --- a/src/modular/safegcd.rs +++ b/src/modular/safegcd.rs @@ -508,7 +508,7 @@ impl fmt::Debug for SignedInt { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_fmt(format_args!( "{}0x{}", - if self.sign.is_true_vartime() { + if self.sign.to_bool_vartime() { "-" } else { "+" @@ -534,7 +534,7 @@ impl ConstCtOption>> { #[inline] #[track_caller] pub const fn expect(self, msg: &str) -> Odd> { - assert!(self.is_some().is_true_vartime(), "{}", msg); + assert!(self.is_some().to_bool_vartime(), "{}", msg); *self.components_ref().0 } } diff --git a/src/modular/safegcd/boxed.rs b/src/modular/safegcd/boxed.rs index ca6ff803..3c06f2dc 100644 --- a/src/modular/safegcd/boxed.rs +++ b/src/modular/safegcd/boxed.rs @@ -417,7 +417,7 @@ impl fmt::Debug for SignedBoxedInt { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_fmt(format_args!( "{}0x{}", - if self.sign.is_true_vartime() { + if self.sign.to_bool_vartime() { "-" } else { "+" @@ -443,7 +443,7 @@ impl ConstCtOption> { #[inline] #[track_caller] pub fn expect(self, msg: &str) -> Odd { - assert!(self.is_some().is_true_vartime(), "{}", msg); + assert!(self.is_some().to_bool_vartime(), "{}", msg); self.components_ref().0.clone() } } diff --git a/src/non_zero.rs b/src/non_zero.rs index 9524f79d..80c64a46 100644 --- a/src/non_zero.rs +++ b/src/non_zero.rs @@ -152,7 +152,7 @@ impl NonZero { /// `NonZero::new(…).unwrap()` // TODO: Remove when `Self::new` and `CtOption::unwrap` support `const fn` pub const fn new_unwrap(n: Limb) -> Self { - if n.is_nonzero().is_true_vartime() { + if n.is_nonzero().to_bool_vartime() { Self(n) } else { panic!("Invalid value: zero") @@ -195,7 +195,7 @@ impl NonZeroUint { /// - if the value is zero. // TODO: Remove when `Self::new` and `CtOption::unwrap` support `const fn` pub const fn new_unwrap(n: Uint) -> Self { - if n.is_nonzero().is_true_vartime() { + if n.is_nonzero().to_bool_vartime() { Self(n) } else { panic!("Invalid value: zero") diff --git a/src/odd.rs b/src/odd.rs index 2e694767..3cebda34 100644 --- a/src/odd.rs +++ b/src/odd.rs @@ -91,7 +91,7 @@ impl Odd> { /// Panics if the hex is malformed or not zero-padded accordingly for the size, or if the value is even. pub const fn from_be_hex(hex: &str) -> Self { let uint = Uint::::from_be_hex(hex); - assert!(uint.is_odd().is_true_vartime(), "number must be odd"); + assert!(uint.is_odd().to_bool_vartime(), "number must be odd"); Odd(uint) } @@ -100,7 +100,7 @@ impl Odd> { /// Panics if the hex is malformed or not zero-padded accordingly for the size, or if the value is even. pub const fn from_le_hex(hex: &str) -> Self { let uint = Uint::::from_be_hex(hex); - assert!(uint.is_odd().is_true_vartime(), "number must be odd"); + assert!(uint.is_odd().to_bool_vartime(), "number must be odd"); Odd(uint) } diff --git a/src/uint/bits.rs b/src/uint/bits.rs index dc519f71..e2ed6200 100644 --- a/src/uint/bits.rs +++ b/src/uint/bits.rs @@ -162,13 +162,13 @@ mod tests { #[test] fn bit() { let u = uint_with_bits_at(&[16, 48, 112, 127, 255]); - assert!(!u.bit(0).is_true_vartime()); - assert!(!u.bit(1).is_true_vartime()); - assert!(u.bit(16).is_true_vartime()); - assert!(u.bit(127).is_true_vartime()); - assert!(u.bit(255).is_true_vartime()); - assert!(!u.bit(256).is_true_vartime()); - assert!(!u.bit(260).is_true_vartime()); + assert!(!u.bit(0).to_bool_vartime()); + assert!(!u.bit(1).to_bool_vartime()); + assert!(u.bit(16).to_bool_vartime()); + assert!(u.bit(127).to_bool_vartime()); + assert!(u.bit(255).to_bool_vartime()); + assert!(!u.bit(256).to_bool_vartime()); + assert!(!u.bit(260).to_bool_vartime()); } #[test] diff --git a/src/uint/invert_mod.rs b/src/uint/invert_mod.rs index d4f797dd..afeef7f2 100644 --- a/src/uint/invert_mod.rs +++ b/src/uint/invert_mod.rs @@ -291,10 +291,10 @@ mod tests { // An inverse of an even number does not exist. let a = U256::from(10u64).invert_mod2k(4); - assert!(a.is_none().is_true_vartime()); + assert!(a.is_none().to_bool_vartime()); let a = U256::from(10u64).invert_mod2k_vartime(4); - assert!(a.is_none().is_true_vartime()); + assert!(a.is_none().to_bool_vartime()); // A degenerate case. An inverse mod 2^0 == 1 always exists even for even numbers. @@ -346,7 +346,7 @@ mod tests { // `m` is a multiple of `p1`, so no inverse exists let res = p1.invert_odd_mod(&m); - assert!(res.is_none().is_true_vartime()); + assert!(res.is_none().to_bool_vartime()); } #[test] @@ -391,7 +391,7 @@ mod tests { let m = U64::from(49u64).to_odd().unwrap(); let res = a.invert_odd_mod(&m); - assert!(res.is_none().is_true_vartime()); + assert!(res.is_none().to_bool_vartime()); } #[test] diff --git a/src/uint/ref_type/div.rs b/src/uint/ref_type/div.rs index ba71b320..d8c42758 100644 --- a/src/uint/ref_type/div.rs +++ b/src/uint/ref_type/div.rs @@ -277,7 +277,7 @@ impl UintRef { // This loop is a no-op once xi is smaller than the number of words in the divisor let done = ConstChoice::from_u32_lt(xi as u32, ywords - 1); - if vartime.and(done).is_true_vartime() { + if vartime.and(done).to_bool_vartime() { break; } quo = done.select_word(quo, 0); @@ -450,7 +450,7 @@ impl UintRef { // This loop is a no-op once xi is smaller than the number of words in the divisor let done = ConstChoice::from_u32_lt(xi as u32, ywords - 1); - if vartime.and(done).is_true_vartime() { + if vartime.and(done).to_bool_vartime() { break; } quo = done.select_word(quo, 0); diff --git a/src/uint/shl.rs b/src/uint/shl.rs index 8a65d63b..8b46bba0 100644 --- a/src/uint/shl.rs +++ b/src/uint/shl.rs @@ -321,8 +321,8 @@ mod tests { #[test] fn shl256_const() { - assert!(N.overflowing_shl(256).is_none().is_true_vartime()); - assert!(N.overflowing_shl_vartime(256).is_none().is_true_vartime()); + assert!(N.overflowing_shl(256).is_none().to_bool_vartime()); + assert!(N.overflowing_shl_vartime(256).is_none().to_bool_vartime()); } #[test] @@ -361,7 +361,7 @@ mod tests { assert!( Uint::overflowing_shl_vartime_wide((U128::MAX, U128::MAX), 256) .is_none() - .is_true_vartime(), + .to_bool_vartime(), ); } diff --git a/src/uint/shr.rs b/src/uint/shr.rs index f39e4d57..32355b76 100644 --- a/src/uint/shr.rs +++ b/src/uint/shr.rs @@ -306,8 +306,8 @@ mod tests { #[test] fn shr256_const() { - assert!(N.overflowing_shr(256).is_none().is_true_vartime()); - assert!(N.overflowing_shr_vartime(256).is_none().is_true_vartime()); + assert!(N.overflowing_shr(256).is_none().to_bool_vartime()); + assert!(N.overflowing_shr_vartime(256).is_none().to_bool_vartime()); } #[test] @@ -337,7 +337,7 @@ mod tests { assert!( Uint::overflowing_shr_vartime_wide((U128::MAX, U128::MAX), 256) .is_none() - .is_true_vartime() + .to_bool_vartime() ); }