diff --git a/src/utils.rs b/src/utils.rs index 371d2d9a6..1b582354d 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -426,17 +426,43 @@ impl SimdAngularInertia for SdpMatrix3 { // to zero, and automatically resetting previous flags once it is dropped. #[derive(Clone, Debug, PartialEq, Eq)] pub(crate) struct FlushToZeroDenormalsAreZeroFlags { + #[cfg(any( + feature = "enhanced-determinism", + not(any( + target_arch = "aarch64", + all( + any(target_arch = "x86", target_arch = "x86_64"), + target_feature = "sse" + ) + )) + ))] + original_flags: (), + + #[cfg(all( + not(feature = "enhanced-determinism"), + any(target_arch = "x86", target_arch = "x86_64"), + target_feature = "sse" + ))] original_flags: u32, + + #[cfg(all(not(feature = "enhanced-determinism"), target_arch = "aarch64"))] + original_flags: u64, } +// Flush denormals & underflows to zero as this has a significant impact on the solver's performance. impl FlushToZeroDenormalsAreZeroFlags { - #[cfg(not(all( - not(feature = "enhanced-determinism"), - any(target_arch = "x86_64", target_arch = "x86"), - target_feature = "sse" - )))] + #[cfg(any( + feature = "enhanced-determinism", + not(any( + target_arch = "aarch64", + all( + any(target_arch = "x86", target_arch = "x86_64"), + target_feature = "sse" + ) + )) + ))] pub fn flush_denormal_to_zero() -> Self { - Self { original_flags: 0 } + Self { original_flags: () } } #[cfg(all( @@ -452,7 +478,6 @@ impl FlushToZeroDenormalsAreZeroFlags { #[cfg(target_arch = "x86_64")] use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _mm_getcsr, _mm_setcsr}; - // Flush denormals & underflows to zero as this as a significant impact on the solver's performances. // To enable this we need to set the bit 15 (given by _MM_FLUSH_ZERO_ON) and the bit 6 (for denormals-are-zero). 
// See https://software.intel.com/content/www/us/en/develop/articles/x87-and-sse-floating-point-assists-in-ia-32-flush-to-zero-ftz-and-denormals-are-zero-daz.html let original_flags = _mm_getcsr(); @@ -460,6 +485,20 @@ impl FlushToZeroDenormalsAreZeroFlags { Self { original_flags } } } + + #[cfg(all(not(feature = "enhanced-determinism"), target_arch = "aarch64"))] + pub fn flush_denormal_to_zero() -> Self { + let mut original_flags: u64; + unsafe { + std::arch::asm!("mrs {}, fpcr", out(reg) original_flags); + // This sets the following bits of FPCR (Floating-point Control Register): + // FZ, bit 24 - Flushing denormalized numbers to zero + // FZ16, bit 19 - Enable flushing for half-precision (f16) numbers + // See https://developer.arm.com/documentation/ddi0601/2025-06/AArch64-Registers/FPCR--Floating-point-Control-Register + std::arch::asm!("msr fpcr, {}", in(reg) original_flags | (1 << 24) | (1 << 19)); + } + Self { original_flags } + } } #[cfg(all( @@ -481,6 +520,13 @@ impl Drop for FlushToZeroDenormalsAreZeroFlags { } } +#[cfg(all(not(feature = "enhanced-determinism"), target_arch = "aarch64"))] +impl Drop for FlushToZeroDenormalsAreZeroFlags { + fn drop(&mut self) { + unsafe { std::arch::asm!("msr fpcr, {}", in(reg) self.original_flags) } + } +} + /// This is an RAII structure that disables floating point exceptions while /// it is alive, so that operations which generate NaNs and infinite values /// intentionally will not trip an exception when debugging problematic