From 888373bcb6b0c145b05e3912337f3638ae3e4f9c Mon Sep 17 00:00:00 2001
From: koreaygj
Date: Mon, 1 Dec 2025 18:54:58 +0900
Subject: [PATCH] feat: Add global fast_math flag to CudaBuilder

- Add fast_math field to CudaBuilder struct
- Enables ftz, fast_sqrt, fast_div and fma_contraction(fmad) internally
- Provides convenient parity with NVCC's --use_fast_math
- Does not forward `--use_fast_math` itself: it is an NVCC driver flag, not
  a libNVVM option, so the individual options above carry the whole effect
---
 crates/cuda_builder/src/lib.rs | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/crates/cuda_builder/src/lib.rs b/crates/cuda_builder/src/lib.rs
index f54775c9..4616a8e0 100644
--- a/crates/cuda_builder/src/lib.rs
+++ b/crates/cuda_builder/src/lib.rs
@@ -148,6 +148,12 @@ pub struct CudaBuilder {
     /// Enable FMA (fused multiply-add) contraction.
     /// `true` by default.
     pub fma_contraction: bool,
+    /// Enable fast math approximations globally (equivalent to NVCC's `--use_fast_math`).
+    /// This implies ftz=true, prec-div=false, prec-sqrt=false, and fmad=true.
+    /// Set it through [`CudaBuilder::fast_math`], which is what switches those options on;
+    /// this field only records the request.
+    /// `false` by default.
+    pub fast_math: bool,
     /// Whether to emit a certain IR. Emitting LLVM IR is useful to debug any codegen
     /// issues. If you are submitting a bug report try to include the LLVM IR file of
     /// the program that contains the offending function.
@@ -206,6 +212,7 @@ impl CudaBuilder {
             nvvm_opts: true,
             arch: NvvmArch::default(),
             ftz: false,
+            fast_math: false,
             fast_sqrt: false,
             fast_div: false,
             fma_contraction: true,
@@ -266,6 +273,23 @@ impl CudaBuilder {
         self
     }
 
+    /// Enable fast math approximations globally (equivalent to NVCC's `--use_fast_math`).
+    ///
+    /// Passing `true` turns on `ftz`, `fast_sqrt`, `fast_div` and `fma_contraction`. Passing
+    /// `false` only clears the `fast_math` field; the individual options keep their values.
+    /// libNVVM has no `--use_fast_math` option of its own, so the whole effect of this flag
+    /// comes from those individual options.
+    pub fn fast_math(mut self, fast_math: bool) -> Self {
+        self.fast_math = fast_math;
+        if fast_math {
+            self.ftz = true;
+            self.fast_sqrt = true;
+            self.fast_div = true;
+            self.fma_contraction = true;
+        }
+        self
+    }
+
     /// Use a fast approximation for single-precision floating point square root.
     pub fn fast_sqrt(mut self, fast_sqrt: bool) -> Self {
         self.fast_sqrt = fast_sqrt;