From 4ff9d324fe1518104bd2bfda8951dc493fead407 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Sat, 30 May 2026 16:23:28 +0300 Subject: [PATCH 1/4] Implement pointwise relative error bound for BitRound --- codecs/bit-round/Cargo.toml | 7 +- codecs/bit-round/docs | 1 + codecs/bit-round/src/lib.rs | 305 ++++++++++++++++++++++++----- codecs/bit-round/tests/schema.json | 57 ++++++ codecs/bit-round/tests/schema.rs | 20 ++ 5 files changed, 340 insertions(+), 50 deletions(-) create mode 120000 codecs/bit-round/docs create mode 100644 codecs/bit-round/tests/schema.json create mode 100644 codecs/bit-round/tests/schema.rs diff --git a/codecs/bit-round/Cargo.toml b/codecs/bit-round/Cargo.toml index 19975dda6..2006b421b 100644 --- a/codecs/bit-round/Cargo.toml +++ b/codecs/bit-round/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "numcodecs-bit-round" -version = "0.4.0" +version = "0.5.0" edition = { workspace = true } authors = { workspace = true } repository = { workspace = true } @@ -12,6 +12,8 @@ readme = "README.md" categories = ["compression", "encoding"] keywords = ["bit-rounding", "numcodecs", "compression", "encoding"] +include = ["/src", "/LICENSE", "/docs/katex.html"] + # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] @@ -24,5 +26,8 @@ thiserror = { workspace = true } [lints] workspace = true +[package.metadata.docs.rs] +rustdoc-args = ["--html-in-header", "./docs/katex.html"] + [package.metadata.numcodecs-wasm] version = "0.2.2" # wasi 0.2.6 diff --git a/codecs/bit-round/docs b/codecs/bit-round/docs new file mode 120000 index 000000000..713700f74 --- /dev/null +++ b/codecs/bit-round/docs @@ -0,0 +1 @@ +../../docs/rs/ \ No newline at end of file diff --git a/codecs/bit-round/src/lib.rs b/codecs/bit-round/src/lib.rs index 5b9fa2364..aa110e033 100644 --- a/codecs/bit-round/src/lib.rs +++ b/codecs/bit-round/src/lib.rs @@ -17,13 +17,15 @@ //! //! Bit rounding codec implementation for the [`numcodecs`] API. +use std::{borrow::Cow, num::FpCategory}; + use ndarray::{Array, ArrayBase, Data, Dimension}; use numcodecs::{ AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray, Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion, }; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; +use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; use thiserror::Error; #[derive(Clone, Serialize, Deserialize, JsonSchema)] @@ -38,16 +40,37 @@ use thiserror::Error; /// The approach is based on the paper by Klöwer et al. 2021 /// (). pub struct BitRoundCodec { - /// The number of bits of the mantissa to keep. - /// - /// The valid range depends on the dtype of the input data. - /// - /// If keepbits is equal to the bitlength of the dtype's mantissa, no - /// transformation is performed. - pub keepbits: u8, + /// Bit rounding mode. + #[serde(flatten)] + pub mode: BitRoundMode, /// The codec's encoding format version. Do not provide this parameter explicitly. #[serde(default, rename = "_version")] - pub version: StaticCodecVersion<1, 0, 0>, + pub version: StaticCodecVersion<2, 0, 0>, +} + +#[derive(Clone, Copy, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "mode")] +#[serde(deny_unknown_fields)] +/// Bit rounding mode +pub enum BitRoundMode { + /// Directly specify the number of bits of the mantissa to keep. + Keepbits { + /// The number of bits of the mantissa to keep. + /// + /// The valid range depends on the dtype of the input data. + /// + /// If keepbits is equal to the bitlength of the dtype's mantissa, no + /// transformation is performed. + keepbits: u8, + }, + /// Pointwise relative error. + RelativeError { + /// The pointwise relative error bound to preserve. + /// + /// This error bound guarantees that + /// `$|x - \hat{x}| \leq |x| \cdot \epsilon_{rel}$`. + eb_rel: NonNegative, + }, } impl Codec for BitRoundCodec { @@ -55,8 +78,8 @@ impl Codec for BitRoundCodec { fn encode(&self, data: AnyCowArray) -> Result { match data { - AnyCowArray::F32(data) => Ok(AnyArray::F32(bit_round(data, self.keepbits)?)), - AnyCowArray::F64(data) => Ok(AnyArray::F64(bit_round(data, self.keepbits)?)), + AnyCowArray::F32(data) => Ok(AnyArray::F32(bit_round(data, &self.mode)?)), + AnyCowArray::F64(data) => Ok(AnyArray::F64(bit_round(data, &self.mode)?)), encoded => Err(BitRoundCodecError::UnsupportedDtype(encoded.dtype())), } } @@ -119,6 +142,49 @@ pub enum BitRoundCodecError { }, } +#[expect(clippy::derive_partial_eq_without_eq)] // floats are not Eq +#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)] +/// Non-negative floating point number +pub struct NonNegative(T); + +impl Serialize for NonNegative { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_f64(self.0) + } +} + +impl<'de> Deserialize<'de> for NonNegative { + fn deserialize>(deserializer: D) -> Result { + let x = f64::deserialize(deserializer)?; + + if x >= 0.0 { + Ok(Self(x)) + } else { + Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Float(x), + &"a non-negative value", + )) + } + } +} + +impl JsonSchema for NonNegative { + fn schema_name() -> Cow<'static, str> { + Cow::Borrowed("NonNegativeF64") + } + + fn schema_id() -> Cow<'static, str> { + Cow::Borrowed(concat!(module_path!(), "::", "NonNegative")) + } + + fn json_schema(_gen: &mut SchemaGenerator) -> Schema { + json_schema!({ + "type": "number", + "minimum": 0.0 + }) + } +} + /// Floating-point bit rounding, which drops the specified number of bits from /// the floating point mantissa. /// @@ -131,31 +197,52 @@ pub enum BitRoundCodecError { /// [`T::MANITSSA_BITS`][`Float::MANITSSA_BITS`]. pub fn bit_round, D: Dimension>( data: ArrayBase, - keepbits: u8, + mode: &BitRoundMode, ) -> Result, BitRoundCodecError> { - if u32::from(keepbits) > T::MANITSSA_BITS { - return Err(BitRoundCodecError::ExcessiveKeepBits { - keepbits, - dtype: T::TY, - }); - } + let (keepbits, keep_non_normal) = match mode { + BitRoundMode::Keepbits { keepbits } => { + let keepbits = *keepbits; + if u32::from(keepbits) > T::MANITSSA_BITS { + return Err(BitRoundCodecError::ExcessiveKeepBits { + keepbits, + dtype: T::TY, + }); + } + (u32::from(keepbits), false) + } + BitRoundMode::RelativeError { eb_rel } => { + #[expect(clippy::cast_possible_truncation)] + // no truncation since the exponent is in [-1074, 1024] + let keepbits = (-eb_rel.0.log2().floor() as i64) - 1; + // keepbits must be within the range of the mantissa bits of single precision. + #[expect(clippy::cast_sign_loss, clippy::cast_possible_truncation)] + // no sign loss or truncation since we clamp to between 0 and a u32 + let keepbits = keepbits.clamp(0, i64::from(T::MANITSSA_BITS)) as u32; + (keepbits, true) + } + }; let mut encoded = data.into_owned(); // Early return if no bit rounding needs to happen // - required since the ties to even impl does not work in this case - if u32::from(keepbits) == T::MANITSSA_BITS { + if keepbits == T::MANITSSA_BITS { return Ok(encoded); } // half of unit in last place (ulp) - let ulp_half = T::MANTISSA_MASK >> (u32::from(keepbits) + 1); + let ulp_half = T::MANTISSA_MASK >> (keepbits + 1); // mask to zero out trailing mantissa bits - let keep_mask = !(T::MANTISSA_MASK >> u32::from(keepbits)); + let keep_mask = !(T::MANTISSA_MASK >> keepbits); // shift to extract the least significant bit of the exponent - let shift = T::MANITSSA_BITS - u32::from(keepbits); + let shift = T::MANITSSA_BITS - keepbits; encoded.mapv_inplace(|x| { + // subnormal, infinite, and NaN values are + if keep_non_normal && !x.is_normal_or_zero() { + return x; + } + let mut bits = T::to_binary(x); // add ulp/2 with ties to even @@ -195,6 +282,9 @@ pub trait Float: Sized + Copy { fn to_binary(self) -> Self::Binary; /// Bit-cast the binary representation into a floating point value fn from_binary(u: Self::Binary) -> Self; + + /// Returns true if the number is neither infinite, subnormal, or NaN + fn is_normal_or_zero(self) -> bool; } impl Float for f32 { @@ -212,6 +302,10 @@ impl Float for f32 { fn from_binary(u: Self::Binary) -> Self { Self::from_bits(u) } + + fn is_normal_or_zero(self) -> bool { + matches!(self.classify(), FpCategory::Normal | FpCategory::Zero) + } } impl Float for f64 { @@ -229,6 +323,10 @@ impl Float for f64 { fn from_binary(u: Self::Binary) -> Self { Self::from_bits(u) } + + fn is_normal_or_zero(self) -> bool { + matches!(self.classify(), FpCategory::Normal | FpCategory::Zero) + } } #[cfg(test)] @@ -239,108 +337,205 @@ mod tests { use super::*; #[test] + #[expect(clippy::too_many_lines)] fn no_mantissa() { assert_eq!( - bit_round(ArrayView1::from(&[0.0_f32]), 0).unwrap(), + bit_round( + ArrayView1::from(&[0.0_f32]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![0.0_f32]) ); assert_eq!( - bit_round(ArrayView1::from(&[1.0_f32]), 0).unwrap(), + bit_round( + ArrayView1::from(&[1.0_f32]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![1.0_f32]) ); // tie to even rounds up as the offset exponent is odd assert_eq!( - bit_round(ArrayView1::from(&[1.5_f32]), 0).unwrap(), + bit_round( + ArrayView1::from(&[1.5_f32]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![2.0_f32]) ); assert_eq!( - bit_round(ArrayView1::from(&[2.0_f32]), 0).unwrap(), + bit_round( + ArrayView1::from(&[2.0_f32]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![2.0_f32]) ); assert_eq!( - bit_round(ArrayView1::from(&[2.5_f32]), 0).unwrap(), + bit_round( + ArrayView1::from(&[2.5_f32]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![2.0_f32]) ); // tie to even rounds down as the offset exponent is even assert_eq!( - bit_round(ArrayView1::from(&[3.0_f32]), 0).unwrap(), + bit_round( + ArrayView1::from(&[3.0_f32]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![2.0_f32]) ); assert_eq!( - bit_round(ArrayView1::from(&[3.5_f32]), 0).unwrap(), + bit_round( + ArrayView1::from(&[3.5_f32]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![4.0_f32]) ); assert_eq!( - bit_round(ArrayView1::from(&[4.0_f32]), 0).unwrap(), + bit_round( + ArrayView1::from(&[4.0_f32]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![4.0_f32]) ); assert_eq!( - bit_round(ArrayView1::from(&[5.0_f32]), 0).unwrap(), + bit_round( + ArrayView1::from(&[5.0_f32]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![4.0_f32]) ); // tie to even rounds up as the offset exponent is odd assert_eq!( - bit_round(ArrayView1::from(&[6.0_f32]), 0).unwrap(), + bit_round( + ArrayView1::from(&[6.0_f32]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![8.0_f32]) ); assert_eq!( - bit_round(ArrayView1::from(&[7.0_f32]), 0).unwrap(), + bit_round( + ArrayView1::from(&[7.0_f32]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![8.0_f32]) ); assert_eq!( - bit_round(ArrayView1::from(&[8.0_f32]), 0).unwrap(), + bit_round( + ArrayView1::from(&[8.0_f32]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![8.0_f32]) ); assert_eq!( - bit_round(ArrayView1::from(&[0.0_f64]), 0).unwrap(), + bit_round( + ArrayView1::from(&[0.0_f64]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![0.0_f64]) ); assert_eq!( - bit_round(ArrayView1::from(&[1.0_f64]), 0).unwrap(), + bit_round( + ArrayView1::from(&[1.0_f64]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![1.0_f64]) ); // tie to even rounds up as the offset exponent is odd assert_eq!( - bit_round(ArrayView1::from(&[1.5_f64]), 0).unwrap(), + bit_round( + ArrayView1::from(&[1.5_f64]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![2.0_f64]) ); assert_eq!( - bit_round(ArrayView1::from(&[2.0_f64]), 0).unwrap(), + bit_round( + ArrayView1::from(&[2.0_f64]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![2.0_f64]) ); assert_eq!( - bit_round(ArrayView1::from(&[2.5_f64]), 0).unwrap(), + bit_round( + ArrayView1::from(&[2.5_f64]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![2.0_f64]) ); // tie to even rounds down as the offset exponent is even assert_eq!( - bit_round(ArrayView1::from(&[3.0_f64]), 0).unwrap(), + bit_round( + ArrayView1::from(&[3.0_f64]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![2.0_f64]) ); assert_eq!( - bit_round(ArrayView1::from(&[3.5_f64]), 0).unwrap(), + bit_round( + ArrayView1::from(&[3.5_f64]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![4.0_f64]) ); assert_eq!( - bit_round(ArrayView1::from(&[4.0_f64]), 0).unwrap(), + bit_round( + ArrayView1::from(&[4.0_f64]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![4.0_f64]) ); assert_eq!( - bit_round(ArrayView1::from(&[5.0_f64]), 0).unwrap(), + bit_round( + ArrayView1::from(&[5.0_f64]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![4.0_f64]) ); // tie to even rounds up as the offset exponent is odd assert_eq!( - bit_round(ArrayView1::from(&[6.0_f64]), 0).unwrap(), + bit_round( + ArrayView1::from(&[6.0_f64]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![8.0_f64]) ); assert_eq!( - bit_round(ArrayView1::from(&[7.0_f64]), 0).unwrap(), + bit_round( + ArrayView1::from(&[7.0_f64]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![8.0_f64]) ); assert_eq!( - bit_round(ArrayView1::from(&[8.0_f64]), 0).unwrap(), + bit_round( + ArrayView1::from(&[8.0_f64]), + &BitRoundMode::Keepbits { keepbits: 0 } + ) + .unwrap(), Array1::from_vec(vec![8.0_f64]) ); } @@ -354,14 +549,26 @@ mod tests { for v in [0.0_f32, 1.0_f32, 2.0_f32, 3.0_f32, 4.0_f32] { assert_eq!( - bit_round(ArrayView1::from(&[full(v)]), f32::MANITSSA_BITS as u8).unwrap(), + bit_round( + ArrayView1::from(&[full(v)]), + &BitRoundMode::Keepbits { + keepbits: f32::MANITSSA_BITS as u8 + } + ) + .unwrap(), Array1::from_vec(vec![full(v)]) ); } for v in [0.0_f64, 1.0_f64, 2.0_f64, 3.0_f64, 4.0_f64] { assert_eq!( - bit_round(ArrayView1::from(&[full(v)]), f64::MANITSSA_BITS as u8).unwrap(), + bit_round( + ArrayView1::from(&[full(v)]), + &BitRoundMode::Keepbits { + keepbits: f64::MANITSSA_BITS as u8 + } + ) + .unwrap(), Array1::from_vec(vec![full(v)]) ); } diff --git a/codecs/bit-round/tests/schema.json b/codecs/bit-round/tests/schema.json new file mode 100644 index 000000000..a516039bf --- /dev/null +++ b/codecs/bit-round/tests/schema.json @@ -0,0 +1,57 @@ +{ + "type": "object", + "unevaluatedProperties": false, + "oneOf": [ + { + "type": "object", + "description": "Directly specify the number of bits of the mantissa to keep.", + "properties": { + "keepbits": { + "type": "integer", + "format": "uint8", + "minimum": 0, + "maximum": 255, + "description": "The number of bits of the mantissa to keep.\n\nThe valid range depends on the dtype of the input data.\n\nIf keepbits is equal to the bitlength of the dtype's mantissa, no\ntransformation is performed." + }, + "mode": { + "type": "string", + "const": "Keepbits" + } + }, + "required": [ + "mode", + "keepbits" + ] + }, + { + "type": "object", + "description": "Pointwise relative error.", + "properties": { + "eb_rel": { + "type": "number", + "minimum": 0.0, + "description": "The pointwise relative error bound to preserve.\n\nThis error bound guarantees that\n`$|x - \\hat{x}| \\leq |x| \\cdot \\epsilon_{rel}$`." + }, + "mode": { + "type": "string", + "const": "RelativeError" + } + }, + "required": [ + "mode", + "eb_rel" + ] + } + ], + "description": "Codec providing floating-point bit rounding.\n\nDrops the specified number of bits from the floating point mantissa,\nleaving an array that is more amenable to compression. The number of\nbits to keep should be determined by information analysis of the data\nto be compressed.\n\nThe approach is based on the paper by Klöwer et al. 2021\n().", + "properties": { + "_version": { + "type": "string", + "pattern": "^(0|[1-9]\\d*)\\.(0|[1-9]\\d*)\\.(0|[1-9]\\d*)(?:-((?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\\.(?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\\+([0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*))?$", + "description": "The codec's encoding format version. Do not provide this parameter explicitly.", + "default": "2.0.0" + } + }, + "title": "BitRoundCodec", + "$schema": "https://json-schema.org/draft/2020-12/schema" +} \ No newline at end of file diff --git a/codecs/bit-round/tests/schema.rs b/codecs/bit-round/tests/schema.rs new file mode 100644 index 000000000..da345a568 --- /dev/null +++ b/codecs/bit-round/tests/schema.rs @@ -0,0 +1,20 @@ +#![expect(missing_docs)] + +use ::{ndarray as _, schemars as _, serde as _, thiserror as _}; + +use numcodecs::{DynCodecType, StaticCodecType}; +use numcodecs_bit_round::BitRoundCodec; + +#[test] +fn schema() { + let schema = format!( + "{:#}", + StaticCodecType::::of() + .codec_config_schema() + .to_value() + ); + + if schema != include_str!("schema.json") { + panic!("BitRound schema has changed\n===\n{schema}\n==="); + } +} From e496779a42ce6080ea71fe445c34cc33c0bd8fff Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Sat, 30 May 2026 21:06:44 +0300 Subject: [PATCH 2/4] Add absolute error bound --- codecs/bit-round/src/lib.rs | 151 +++++++++++++++++++++++------ codecs/bit-round/tests/schema.json | 19 ++++ 2 files changed, 140 insertions(+), 30 deletions(-) diff --git a/codecs/bit-round/src/lib.rs b/codecs/bit-round/src/lib.rs index aa110e033..3b379cb2a 100644 --- a/codecs/bit-round/src/lib.rs +++ b/codecs/bit-round/src/lib.rs @@ -17,7 +17,7 @@ //! //! Bit rounding codec implementation for the [`numcodecs`] API. -use std::{borrow::Cow, num::FpCategory}; +use std::borrow::Cow; use ndarray::{Array, ArrayBase, Data, Dimension}; use numcodecs::{ @@ -63,6 +63,14 @@ pub enum BitRoundMode { /// transformation is performed. keepbits: u8, }, + /// Pointwise absolute error. + AbsoluteError { + /// The pointwise absolute error bound to preserve. + /// + /// This error bound guarantees that + /// `$|x - \hat{x}| \leq \epsilon_{abs}$`. + eb_abs: NonNegative, + }, /// Pointwise relative error. RelativeError { /// The pointwise relative error bound to preserve. @@ -210,16 +218,27 @@ pub fn bit_round, D: Dimension>( } (u32::from(keepbits), false) } - BitRoundMode::RelativeError { eb_rel } => { - #[expect(clippy::cast_possible_truncation)] - // no truncation since the exponent is in [-1074, 1024] - let keepbits = (-eb_rel.0.log2().floor() as i64) - 1; - // keepbits must be within the range of the mantissa bits of single precision. - #[expect(clippy::cast_sign_loss, clippy::cast_possible_truncation)] - // no sign loss or truncation since we clamp to between 0 and a u32 - let keepbits = keepbits.clamp(0, i64::from(T::MANITSSA_BITS)) as u32; - (keepbits, true) + BitRoundMode::AbsoluteError { eb_abs } => { + let eb_abs = T::from_f64(eb_abs.0); + + let mut encoded = data.into_owned(); + + encoded.mapv_inplace(|x| { + // subnormal, infinite, and NaN values are hard so just keep + // them as is + if !x.is_normal() { + return x; + } + + let keepbits = BitRounder::keepbits_from_eb_rel(NonNegative(eb_abs / x.abs())); + let bit_round = BitRounder::new(keepbits); + + bit_round.apply(x) + }); + + return Ok(encoded); } + BitRoundMode::RelativeError { eb_rel } => (BitRounder::keepbits_from_eb_rel(*eb_rel), true), }; let mut encoded = data.into_owned(); @@ -230,35 +249,68 @@ pub fn bit_round, D: Dimension>( return Ok(encoded); } - // half of unit in last place (ulp) - let ulp_half = T::MANTISSA_MASK >> (keepbits + 1); - // mask to zero out trailing mantissa bits - let keep_mask = !(T::MANTISSA_MASK >> keepbits); - // shift to extract the least significant bit of the exponent - let shift = T::MANITSSA_BITS - keepbits; + let bit_round = BitRounder::new(keepbits); encoded.mapv_inplace(|x| { - // subnormal, infinite, and NaN values are - if keep_non_normal && !x.is_normal_or_zero() { + // subnormal, infinite, and NaN values are hard so just keep them as is + if keep_non_normal && !x.is_normal() { return x; } + bit_round.apply(x) + }); + + Ok(encoded) +} + +struct BitRounder { + ulp_half: T::Binary, + keep_mask: T::Binary, + shift: u32, +} + +impl BitRounder { + #[inline] + fn new(keepbits: u32) -> Self { + // half of unit in last place (ulp) + let ulp_half = T::MANTISSA_MASK >> (keepbits + 1); + // mask to zero out trailing mantissa bits + let keep_mask = !(T::MANTISSA_MASK >> keepbits); + // shift to extract the least significant bit of the exponent + let shift = T::MANITSSA_BITS - keepbits; + + Self { + ulp_half, + keep_mask, + shift, + } + } + + fn keepbits_from_eb_rel(eb_rel: NonNegative) -> u32 { + let keepbits = -eb_rel.0.log2_floor() - 1; + // keepbits must be within the range of the mantissa bits of single precision. + #[expect(clippy::cast_sign_loss, clippy::cast_possible_truncation)] + // no sign loss or truncation since we clamp to between 0 and a u32 + let keepbits = keepbits.clamp(0, i64::from(T::MANITSSA_BITS)) as u32; + keepbits + } + + #[inline] + fn apply(&self, x: T) -> T { let mut bits = T::to_binary(x); // add ulp/2 with ties to even - bits += ulp_half + ((bits >> shift) & T::BINARY_ONE); + bits += self.ulp_half + ((bits >> self.shift) & T::BINARY_ONE); // set the trailing bits to zero - bits &= keep_mask; + bits &= self.keep_mask; T::from_binary(bits) - }); - - Ok(encoded) + } } /// Floating point types. -pub trait Float: Sized + Copy { +pub trait Float: Sized + Copy + std::ops::Div { /// Number of significant digits in base 2 const MANITSSA_BITS: u32; /// Binary mask to extract only the mantissa bits @@ -283,8 +335,18 @@ pub trait Float: Sized + Copy { /// Bit-cast the binary representation into a floating point value fn from_binary(u: Self::Binary) -> Self; - /// Returns true if the number is neither infinite, subnormal, or NaN - fn is_normal_or_zero(self) -> bool; + /// Returns the floating point category of the number + fn is_normal(self) -> bool; + + /// Returns the floor of the base-2 logarithm as a signed integer + fn log2_floor(self) -> i64; + + /// Computes the absolute value + #[must_use] + fn abs(self) -> Self; + + /// Convert from an [`f64`] value + fn from_f64(x: f64) -> Self; } impl Float for f32 { @@ -303,8 +365,23 @@ impl Float for f32 { Self::from_bits(u) } - fn is_normal_or_zero(self) -> bool { - matches!(self.classify(), FpCategory::Normal | FpCategory::Zero) + fn is_normal(self) -> bool { + self.is_normal() + } + + #[expect(clippy::cast_possible_truncation)] + fn log2_floor(self) -> i64 { + // no truncation since the exponent is in [-149, 128] + -self.log2().floor() as i64 + } + + fn abs(self) -> Self { + self.abs() + } + + #[expect(clippy::cast_possible_truncation)] + fn from_f64(x: f64) -> Self { + x as Self } } @@ -324,8 +401,22 @@ impl Float for f64 { Self::from_bits(u) } - fn is_normal_or_zero(self) -> bool { - matches!(self.classify(), FpCategory::Normal | FpCategory::Zero) + fn is_normal(self) -> bool { + self.is_normal() + } + + #[expect(clippy::cast_possible_truncation)] + fn log2_floor(self) -> i64 { + // no truncation since the exponent is in [-1074, 1024] + self.log2().floor() as i64 + } + + fn abs(self) -> Self { + self.abs() + } + + fn from_f64(x: f64) -> Self { + x } } diff --git a/codecs/bit-round/tests/schema.json b/codecs/bit-round/tests/schema.json index a516039bf..6e8b2ad02 100644 --- a/codecs/bit-round/tests/schema.json +++ b/codecs/bit-round/tests/schema.json @@ -23,6 +23,25 @@ "keepbits" ] }, + { + "type": "object", + "description": "Pointwise absolute error.", + "properties": { + "eb_abs": { + "type": "number", + "minimum": 0.0, + "description": "The pointwise absolute error bound to preserve.\n\nThis error bound guarantees that\n`$|x - \\hat{x}| \\leq \\epsilon_{abs}$`." + }, + "mode": { + "type": "string", + "const": "AbsoluteError" + } + }, + "required": [ + "mode", + "eb_abs" + ] + }, { "type": "object", "description": "Pointwise relative error.", From 7b4de89cacad846f65f7b42e9a54b3d1ba9472e5 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Sat, 30 May 2026 22:12:18 +0300 Subject: [PATCH 3/4] small cleanup --- codecs/bit-round/src/lib.rs | 5 ++++- codecs/bit-round/tests/schema.json | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/codecs/bit-round/src/lib.rs b/codecs/bit-round/src/lib.rs index 3b379cb2a..a07497d8e 100644 --- a/codecs/bit-round/src/lib.rs +++ b/codecs/bit-round/src/lib.rs @@ -29,7 +29,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use thiserror::Error; #[derive(Clone, Serialize, Deserialize, JsonSchema)] -#[serde(deny_unknown_fields)] +#[schemars(deny_unknown_fields)] /// Codec providing floating-point bit rounding. /// /// Drops the specified number of bits from the floating point mantissa, @@ -54,6 +54,7 @@ pub struct BitRoundCodec { /// Bit rounding mode pub enum BitRoundMode { /// Directly specify the number of bits of the mantissa to keep. + #[serde(rename = "keepbits")] Keepbits { /// The number of bits of the mantissa to keep. /// @@ -64,6 +65,7 @@ pub enum BitRoundMode { keepbits: u8, }, /// Pointwise absolute error. + #[serde(rename = "abs")] AbsoluteError { /// The pointwise absolute error bound to preserve. /// @@ -72,6 +74,7 @@ pub enum BitRoundMode { eb_abs: NonNegative, }, /// Pointwise relative error. + #[serde(rename = "rel")] RelativeError { /// The pointwise relative error bound to preserve. /// diff --git a/codecs/bit-round/tests/schema.json b/codecs/bit-round/tests/schema.json index 6e8b2ad02..56e10364d 100644 --- a/codecs/bit-round/tests/schema.json +++ b/codecs/bit-round/tests/schema.json @@ -15,7 +15,7 @@ }, "mode": { "type": "string", - "const": "Keepbits" + "const": "keepbits" } }, "required": [ @@ -34,7 +34,7 @@ }, "mode": { "type": "string", - "const": "AbsoluteError" + "const": "abs" } }, "required": [ @@ -53,7 +53,7 @@ }, "mode": { "type": "string", - "const": "RelativeError" + "const": "rel" } }, "required": [ From ec39e16b7339c5679ce45326d87891b6dc2dc588 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Sat, 30 May 2026 23:33:48 +0300 Subject: [PATCH 4/4] use binary ops for normal log2 floor --- codecs/bit-round/src/lib.rs | 54 +++++++++++++++++++++++++++++-------- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/codecs/bit-round/src/lib.rs b/codecs/bit-round/src/lib.rs index a07497d8e..8bbb482ca 100644 --- a/codecs/bit-round/src/lib.rs +++ b/codecs/bit-round/src/lib.rs @@ -290,11 +290,11 @@ impl BitRounder { } fn keepbits_from_eb_rel(eb_rel: NonNegative) -> u32 { - let keepbits = -eb_rel.0.log2_floor() - 1; + let keepbits = -(eb_rel.0.normal_log2_floor()) - 1; // keepbits must be within the range of the mantissa bits of single precision. #[expect(clippy::cast_sign_loss, clippy::cast_possible_truncation)] // no sign loss or truncation since we clamp to between 0 and a u32 - let keepbits = keepbits.clamp(0, i64::from(T::MANITSSA_BITS)) as u32; + let keepbits = i64::from(keepbits).clamp(0, i64::from(T::MANITSSA_BITS)) as u32; keepbits } @@ -342,7 +342,7 @@ pub trait Float: Sized + Copy + std::ops::Div { fn is_normal(self) -> bool; /// Returns the floor of the base-2 logarithm as a signed integer - fn log2_floor(self) -> i64; + fn normal_log2_floor(self) -> i16; /// Computes the absolute value #[must_use] @@ -372,10 +372,8 @@ impl Float for f32 { self.is_normal() } - #[expect(clippy::cast_possible_truncation)] - fn log2_floor(self) -> i64 { - // no truncation since the exponent is in [-149, 128] - -self.log2().floor() as i64 + fn normal_log2_floor(self) -> i16 { + (((self.to_bits() >> 23) & 0xff) as i16) - 127 } fn abs(self) -> Self { @@ -408,10 +406,8 @@ impl Float for f64 { self.is_normal() } - #[expect(clippy::cast_possible_truncation)] - fn log2_floor(self) -> i64 { - // no truncation since the exponent is in [-1074, 1024] - self.log2().floor() as i64 + fn normal_log2_floor(self) -> i16 { + (((self.to_bits() >> 52) & 0x7ff) as i16) - 1023 } fn abs(self) -> Self { @@ -667,4 +663,40 @@ mod tests { ); } } + + #[test] + fn normal_log2_floor_f32() { + for e in -100_i16..100 { + let b = f32::from(e).exp2(); + for f in [0.55, 0.75, 0.9, 1.0, 1.1, 1.5, 1.95] { + let x = b * f; + + #[expect(clippy::cast_possible_truncation)] + let math = x.log2().floor() as i16; + let binary = x.normal_log2_floor(); + + assert_eq!(math, binary, "{x}"); + } + } + + assert_eq!(i32::from(0.0_f32.normal_log2_floor()), f32::MIN_EXP - 2); + } + + #[test] + fn normal_log2_floor_f64() { + for e in -100_i32..100 { + let b = f64::from(e).exp2(); + for f in [0.55, 0.75, 0.9, 1.0, 1.1, 1.5, 1.95] { + let x = b * f; + + #[expect(clippy::cast_possible_truncation)] + let math = x.log2().floor() as i16; + let binary = x.normal_log2_floor(); + + assert_eq!(math, binary, "{x}"); + } + } + + assert_eq!(i32::from(0.0_f64.normal_log2_floor()), f64::MIN_EXP - 2); + } }