# ── src/lartg.jl ─────────────────────────────────────────────────────────────
# Loaded from src/NextLA.jl through `include("lartg.jl")`, placed between
# `include("larfb.jl")` and `include("unmqr.jl")`.

"""
    lartg(f::R, g::S) where {R,S}

Generate a plane (Givens) rotation that annihilates `g`:

    [  c        s ] [ f ]   [ r ]
    [ -conj(s)  c ] [ g ] = [ 0 ]

where `c` is real and `s` may be real or complex. The scalar `r` has the
phase of `f` when `f` is nonzero.

# Arguments
- `f`: Scalar element (real or complex)
- `g`: Scalar element (real or complex)

Both arguments are converted to `T = promote_type(R, S)` before use.

# Returns
- `c`: Real cosine of the rotation, of type `real(T)`
- `s`: Sine of the rotation, of type `T`
- `r`: Resulting scalar after applying the rotation, of type `T`

# Algorithm
The implementation follows LAPACK's scaling strategy to avoid
over/underflow when computing norms. For complex inputs, `c` is always
real and `s` carries the phase so that `r` aligns with `f`.

Special cases:
- If `g == 0`, then `c = 1`, `s = 0`, `r = f`
- If `f == 0`, then `c = 0`, `s = conj(g)/abs(g)`, `r = abs(g)`

# Note
This is a low-level LAPACK-style computational routine. Input validation
should be performed by higher-level interfaces.
"""
function lartg(f::R, g::S) where {R,S}
    T = promote_type(R, S)
    RT = real(T)

    # Work on promoted copies; separate names keep every local at one type.
    ft = convert(T, f)
    gt = convert(T, g)

    # Scaling thresholds derived from the safe minimum reported by `lamch`.
    sfmin = lamch(RT, 'S')
    sfmax = one(RT) / sfmin
    rtmin = sqrt(sfmin)
    rtmax = one(RT) / rtmin

    # g == 0: nothing to annihilate, identity rotation.
    if iszero(gt)
        return one(RT), zero(T), ft
    end

    # f == 0: the rotation swaps the components. `s` must be conj(g)/|g| so
    # that c*f + s*g = |g| = r; without the conjugate a complex `g` would
    # give s*g = g^2/|g|, which is not equal to `r`.
    if iszero(ft)
        c = zero(RT)
        gmax = max(abs(real(gt)), abs(imag(gt)))
        if rtmin < gmax < rtmax
            d = abs(gt)
            s = conj(gt) / d
            r = convert(T, d)
            return c, s, r
        else
            # Rescale g into a safe range before taking its modulus.
            u = min(sfmax, max(sfmin, gmax))
            gs = gt / u
            d = abs(gs)
            s = conj(gs) / d
            r = convert(T, d * u)
            return c, s, r
        end
    end

    fmax = max(abs(real(ft)), abs(imag(ft)))
    gmax = max(abs(real(gt)), abs(imag(gt)))

    if (rtmin < fmax < rtmax) && (rtmin < gmax < rtmax)
        # unscaled algorithm
        f2 = abs2(ft)
        g2 = abs2(gt)
        h2 = f2 + g2

        # d = |f| * sqrt(|f|^2 + |g|^2); take a single sqrt only when the
        # product f2 * h2 stays within the thresholds.
        d = (f2 > rtmin && h2 < rtmax) ? sqrt(f2 * h2) : sqrt(f2) * sqrt(h2)
        p = inv(d)

        c = convert(RT, f2 * p)
        s = conj(gt) * (ft * p)
        r = ft * (h2 * p)
        return c, s, r
    else
        # scaled algorithm
        u = min(sfmax, max(sfmin, fmax, gmax))
        gs = gt / u
        g2 = abs2(gs)

        if fmax / u < rtmin
            # different scalings for f and g
            v = min(sfmax, max(sfmin, fmax))
            w = v / u
            fs = ft / v
            f2 = abs2(fs)
            h2 = f2 * (w * w) + g2
        else
            # same scaling for f and g
            w = one(RT)
            fs = ft / u
            f2 = abs2(fs)
            h2 = f2 + g2
        end

        d = (f2 > rtmin && h2 < rtmax) ? sqrt(f2 * h2) : sqrt(f2) * sqrt(h2)
        p = inv(d)

        # Undo the scalings: `w` restores the f/g ratio in c, `u` restores r.
        c = convert(RT, (f2 * p) * w)
        s = conj(gs) * (fs * p)
        r = (fs * (h2 * p)) * u
        return c, s, r
    end
end

# ── test/lapack_helpers.jl ───────────────────────────────────────────────────
# ── xLARTG — Givens rotation generation ─────────────────────────────────────
# Reference for lartg comparison tests.
# Define `lapack_lartg(f, g)` for each of the four LAPACK element types. Every
# method forwards to the matching xLARTG routine through libblastrampoline and
# returns the tuple `(c, s, r)` read back from the output references.
for (T, RT, routine) in ((Float64,    Float64, :dlartg_),
                         (Float32,    Float32, :slartg_),
                         (ComplexF64, Float64, :zlartg_),
                         (ComplexF32, Float32, :clartg_))
    @eval function lapack_lartg(f::$T, g::$T)
        # Inputs are passed by reference; outputs start out zero-initialised.
        f_in = Ref{$T}(f)
        g_in = Ref{$T}(g)
        c_out = Ref{$RT}(0)
        s_out = Ref{$T}(0)
        r_out = Ref{$T}(0)
        ccall((@blasfunc($routine), libblastrampoline), Cvoid,
              (Ref{$T}, Ref{$T}, Ref{$RT}, Ref{$T}, Ref{$T}),
              f_in, g_in, c_out, s_out, r_out)
        return c_out[], s_out[], r_out[]
    end
end

# (test/lapack_helpers.jl continues with the xTPQRT reference wrappers.)

# ── test/lartg.jl ────────────────────────────────────────────────────────────
# Included from test/runtests.jl via `include("lartg.jl")`, placed between
# `include("larfb.jl")` and `include("geqrt.jl")`.

@testset "LARTG" begin
    @testset "$T" for T in TEST_TYPES
        RT = real(T)
        tol = test_rtol(T)

        # Random inputs: the rotation must map (f, g) onto (r, 0) and
        # preserve the 2-norm.
        for _ in 1:50
            f = randn(T)
            g = randn(T)
            c, s, r = NextLA.lartg(f, g)

            top = c * f + s * g
            bottom = -conj(s) * f + c * g

            @test r ≈ top rtol=tol
            @test abs(bottom) <= tol * max(one(RT), abs(r)) + eps(RT)
            @test abs2(r) ≈ (abs2(f) + abs2(g)) rtol=tol
            @test isfinite(r)
        end

        # Special-case branches.
        z = zero(T)
        unit = one(RT)

        # f == 0 and g == 0
        c, s, r = NextLA.lartg(z, z)
        @test c == unit
        @test s == z
        @test r == z

        # g == 0 only
        f = randn(T)
        c, s, r = NextLA.lartg(f, z)
        @test c == unit
        @test s == z
        @test r == f

        # f == 0 only
        g = randn(T)
        c, s, r = NextLA.lartg(z, g)
        @test c == zero(RT)
        @test abs(r) ≈ abs(g) rtol=tol
        @test abs(s) ≈ unit rtol=tol
    end
end

# Compare against the reference LAPACK routine on random inputs.
for T in (ComplexF32, ComplexF64, Float32, Float64)
    @testset "LARTG LAPACK $T" begin
        tol = test_rtol(T)
        for _ in 1:50
            f = randn(T)
            g = randn(T)
            ours = NextLA.lartg(f, g)
            ref = lapack_lartg(f, g)

            @test ours[1] ≈ ref[1] rtol=tol
            @test ours[2] ≈ ref[2] rtol=tol
            @test ours[3] ≈ ref[3] rtol=tol
        end
    end
end