From cf004758da9ff16c26449a4145a446ce4c506e3d Mon Sep 17 00:00:00 2001 From: lkdvos Date: Tue, 28 Apr 2026 08:02:07 -0400 Subject: [PATCH 1/5] make `ind` functions consistent --- src/pullbacks/eig.jl | 4 ++-- src/pullbacks/eigh.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index 7b78121b..24671aa3 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -211,11 +211,11 @@ across eigenvectors associated with degenerate eigenvalues), so the correspondin `ΔV` are projected out. """ function remove_eig_gauge_dependence!( - ΔV, D, V, ind = axes(ΔV, 2); + ΔV, D, V, ind = Colon(); degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) ) - length(ind) == size(ΔV, 2) || throw(DimensionMismatch()) indV = axes(V, 2)[ind] + length(indV) == size(ΔV, 2) || throw(DimensionMismatch("Incompatible size of selected `ind` and `ΔV`")) Vp = view(V, :, indV) Ddiag = view(diagview(D), indV) gaugepart = Vp' * ΔV diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index 3b517b97..fcb84599 100644 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -201,11 +201,11 @@ across eigenvectors associated with degenerate eigenvalues), so the correspondin components of `V' * ΔV` are projected out. """ function remove_eigh_gauge_dependence!( - ΔV, D, V, ind = axes(ΔV, 2); + ΔV, D, V, ind = Colon(); degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) ) - length(ind) == size(ΔV, 2) || throw(DimensionMismatch()) indV = axes(V, 2)[ind] + length(indV) == size(ΔV, 2) || throw(DimensionMismatch("Incompatible size of selected `ind` and `ΔV`")) Vp = view(V, :, indV) Ddiag = view(diagview(D), indV) gaugepart = project_antihermitian!(Vp' * ΔV) From d4d6ce73ccebdd90e6800aff6f331cd7b819096b Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Fri, 1 May 2026 00:19:04 +0200 Subject: [PATCH 2/5] first set of fixes and consistencies --- src/pullbacks/eigh.jl | 172 ++++++++++++++++++++++++------------------ src/pullbacks/svd.jl | 24 +++--- 2 files changed, 113 insertions(+), 83 deletions(-) mode change 100644 => 100755 src/pullbacks/eigh.jl mode change 100644 => 100755 src/pullbacks/svd.jl diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl old mode 100644 new mode 100755 index fcb84599..6c40fe0a --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -1,3 +1,53 @@ +function check_and_prepare_eigh_cotangents( + D, V, ΔDmat, ΔV, ind = Colon(); + degeneracy_atol::Real = default_pullback_rank_atol(S), + gauge_atol::Real = default_pullback_gauge_atol(ΔDmat, ΔV) + ) + + n, p = size(V) + indD = axes(D, 1)[ind] + indV = axes(V, 2)[ind] + if !iszerotangent(ΔV) + n == size(ΔV, 1) || throw(DimensionMismatch()) + length(indV) == size(ΔV, 2) || throw(DimensionMismatch()) + if indV == 1:p + ΔV₁ = copy(ΔV) + else + ΔV₁ = zero(V) + for (j, i) in enumerate(indV) + ΔV₁[:, i] .= view(ΔV, :, j) + end + end + VᴴΔV₁ = V' * ΔV₁ + ΔV₊ = mul!(ΔV₁, V, VᴴΔV₁, -1, 1) + aVᴴΔV₁ = project_antihermitian!(VᴴΔV₁) + else + ΔV₊ = nothing + aVᴴΔV₁ = zero!(similar(V, (p, p))) + end + bc = Base.broadcasted(D', D, aVᴴΔV₁) do d₁, d₂, v + return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) + end + Δgauge = norm(bc, Inf) + + Δgauge ≤ gauge_atol || + @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + + aVᴴΔV₁ .*= inv_safe.(D' .- D, degeneracy_atol) + VᴴAΔV = aVᴴΔV₁ + + if !iszerotangent(ΔDmat) + ΔD = diagview(ΔDmat) + length(indD) == length(ΔD) || throw(DimensionMismatch()) + view(diagview(VᴴAΔV), indD) .+= real.(ΔD) + else + ΔD = nothing + end + + return VᴴAΔV, ΔV₊ +end + + function check_eigh_cotangents( D, aVᴴΔV; degeneracy_atol::Real = default_pullback_rank_atol(D), @@ -39,42 +89,17 @@ function eigh_pullback!( # Basic size checks and determination Dmat, V = DV - D = diagview(Dmat) - ΔDmat, ΔV = ΔDV n = LinearAlgebra.checksquare(V) + D = diagview(Dmat) n == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) + D = diagview(Dmat) - if !iszerotangent(ΔV) - n == size(ΔV, 1) || throw(DimensionMismatch()) - pV = size(ΔV, 2) - VᴴΔV = fill!(similar(V), 0) - indV = axes(V, 2)[ind] - length(indV) == pV || throw(DimensionMismatch()) - mul!(view(VᴴΔV, :, indV), V', ΔV) - aVᴴΔV = project_antihermitian(VᴴΔV) # can't use in-place or recycling doesn't work - - check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol) - - aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol) - - if !iszerotangent(ΔDmat) - ΔDvec = diagview(ΔDmat) - pD = length(ΔDvec) - indD = axes(D, 1)[ind] - length(indD) == pD || throw(DimensionMismatch()) - view(diagview(aVᴴΔV), indD) .+= real.(ΔDvec) - end - # recycle VdΔV space - ΔA = mul!(ΔA, mul!(VᴴΔV, V, aVᴴΔV), V', 1, 1) - elseif !iszerotangent(ΔDmat) - ΔDvec = diagview(ΔDmat) - pD = length(ΔDvec) - indD = axes(D, 1)[ind] - length(indD) == pD || throw(DimensionMismatch()) - Vp = view(V, :, indD) - ΔA = mul!(ΔA, Vp * Diagonal(real(ΔDvec)), Vp', 1, 1) - end + ΔDmat, ΔV = ΔDV + VᴴΔAV, = check_and_prepare_eigh_cotangents( + D, V, ΔDmat, ΔV, ind; degeneracy_atol, gauge_atol + ) + ΔA = mul!(ΔA, V, VᴴΔAV * V', 1, 1) return ΔA end function eigh_pullback!( @@ -113,47 +138,53 @@ not small compared to `gauge_atol`. function eigh_trunc_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV; degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), - gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]), + maxiter::Int = 100 ) # Basic size checks and determination Dmat, V = DV - D = diagview(Dmat) - ΔDmat, ΔV = ΔDV (n, p) = size(V) - p == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) + D = diagview(Dmat) + p == length(D) || throw(DimensionMismatch()) - if !iszerotangent(ΔV) - (n, p) == size(ΔV) || throw(DimensionMismatch()) - VᴴΔV = V' * ΔV - aVᴴΔV = project_antihermitian!(VᴴΔV) - - check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol) - - aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol) - - if !iszerotangent(ΔDmat) - ΔDvec = diagview(ΔDmat) - p == length(ΔDvec) || throw(DimensionMismatch()) - diagview(aVᴴΔV) .+= real.(ΔDvec) + ΔDmat, ΔV = ΔDV + VᴴΔAV, ΔV₊ = check_and_prepare_eigh_cotangents( + D, V, ΔDmat, ΔV; degeneracy_atol, gauge_atol + ) + ΔAV = V * VᴴΔAV + ΔA = mul!(ΔA, ΔAV, V', 1, 1) + + if !iszerotangent(ΔV₊) + X₀ = rdiv!(ΔV₊, Diagonal(D)) + VD = mul!(ΔAV, V, Dmat) # recycle ΔAV + AP = mul!(copy(A), VD, V', -1, 1) + dabsmax = maximum(abs, D) + AP ./= dabsmax + D⁻¹ = dabsmax ./ D + X₁ = rmul!(AP * X₀, Diagonal(D⁻¹)) + X₁ .+= X₀ + Xₖ, Xₖ₊₁ = X₁, X₀ + APₖ, APₖ₊₁ = AP * AP, AP + D⁻¹ₖ, D⁻¹ₖ₊₁ = D⁻¹ .^ 2, D⁻¹ + for k in 1:maxiter + Xₖ₊₁ = rmul!(mul!(Xₖ₊₁, APₖ, Xₖ), Diagonal(D⁻¹ₖ)) + if norm(Xₖ₊₁, Inf) < degeneracy_atol + break + end + Xₖ₊₁ .+= Xₖ + if k == maxiter + @warn "Sylvester iteration did not converge after $k iterations, final norm of X: $(norm(Xₖ₊₁, Inf)))" + break + end + D⁻¹ₖ₊₁ .= D⁻¹ₖ .^ 2 + APₖ₊₁ = mul!(APₖ₊₁, APₖ, APₖ) + Xₖ, Xₖ₊₁ = Xₖ₊₁, Xₖ + APₖ, APₖ₊₁ = APₖ₊₁, APₖ + D⁻¹ₖ, D⁻¹ₖ₊₁ = D⁻¹ₖ₊₁, D⁻¹ₖ end - - Z = V * aVᴴΔV - - # add contribution from orthogonal complement - W = qr_null(V) - WᴴΔV = W' * ΔV - X = _sylvester(W' * A * W, -Dmat, WᴴΔV) - Z = mul!(Z, W, X, 1, 1) - - # put everything together: symmetrize for hermitian case - ΔA = mul!(ΔA, Z, V', 1 // 2, 1) - ΔA = mul!(ΔA, V, Z', 1 // 2, 1) - elseif !iszerotangent(ΔDmat) - ΔDvec = diagview(ΔDmat) - p == length(ΔDvec) || throw(DimensionMismatch()) - ΔA = mul!(ΔA, V * Diagonal(real(ΔDvec)), V', 1, 1) + ΔA = project_hermitian!(mul!(ΔA, Xₖ, V', 1, 1)) end return ΔA end @@ -201,15 +232,12 @@ across eigenvectors associated with degenerate eigenvalues), so the correspondin components of `V' * ΔV` are projected out. """ function remove_eigh_gauge_dependence!( - ΔV, D, V, ind = Colon(); + ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) ) - indV = axes(V, 2)[ind] - length(indV) == size(ΔV, 2) || throw(DimensionMismatch("Incompatible size of selected `ind` and `ΔV`")) - Vp = view(V, :, indV) - Ddiag = view(diagview(D), indV) - gaugepart = project_antihermitian!(Vp' * ΔV) + Ddiag = diagview(D) + gaugepart = project_antihermitian!(V' * ΔV) gaugepart[abs.(transpose(Ddiag) .- Ddiag) .>= degeneracy_atol] .= 0 - mul!(ΔV, Vp, gaugepart, -1, 1) + mul!(ΔV, V, gaugepart, -1, 1) return ΔV end diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl old mode 100644 new mode 100755 index 4b352416..f989bf5f --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -82,7 +82,7 @@ function check_and_prepare_svd_cotangents( aVᴴΔV₁ = zero!(similar(V₁ᴴ, (r, r))) end bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s₁, s₂, u, v - return abs(s₁ - s₂) < degeneracy_atol ? zero(u) + zero(v) : u + v + return abs(s₁ - s₂) < degeneracy_atol ? u + v : zero(u) + zero(v) end Δgauge = max(Δgauge, norm(bc, Inf)) @@ -104,13 +104,13 @@ function check_and_prepare_svd_cotangents( Δgauge ≤ gauge_atol || @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - UdΔAV = (aUᴴΔU₁ .+ aVᴴΔV₁) .* inv_safe.(S₁' .- S₁, degeneracy_atol) .+ + UᴴΔAV = (aUᴴΔU₁ .+ aVᴴΔV₁) .* inv_safe.(S₁' .- S₁, degeneracy_atol) .+ (aUᴴΔU₁ .- aVᴴΔV₁) .* inv_safe.(S₁' .+ S₁, degeneracy_atol) if !iszerotangent(ΔS₁) - diagview(UdΔAV) .+= real.(ΔS₁) + diagview(UᴴΔAV) .+= real.(ΔS₁) end - return UdΔAV, ΔU₊, ΔV₊ᴴ + return UᴴΔAV, ΔU₊, ΔV₊ᴴ end """ @@ -155,10 +155,10 @@ function svd_pullback!( S₁ = view(S, 1:r) ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ - UdΔAV, ΔU₊, ΔV₊ᴴ = check_and_prepare_svd_cotangents( + UᴴΔAV, ΔU₊, ΔV₊ᴴ = check_and_prepare_svd_cotangents( U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, r, ind; degeneracy_atol, gauge_atol ) - ΔA = mul!(ΔA, U₁, UdΔAV * V₁ᴴ, 1, 1) # add the contribution to ΔA + ΔA = mul!(ΔA, U₁, UᴴΔAV * V₁ᴴ, 1, 1) # add the contribution to ΔA # Add the remaining contributions if m > r && !iszerotangent(ΔU₊) # ΔU₁ is already orthogonal to U₁ @@ -210,7 +210,7 @@ function svd_trunc_pullback!( rank_atol::Real = 0, degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ...), - maxiter::Int = 1000, + maxiter::Int = 100, ) # Extract the SVD components U, Smat, Vᴴ = USVᴴ @@ -223,17 +223,19 @@ function svd_trunc_pullback!( # Extract and check the cotangents ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ - UdΔAV, ΔU₊, ΔV₊ᴴ = check_and_prepare_svd_cotangents( + UᴴΔAV, ΔU₊, ΔV₊ᴴ = check_and_prepare_svd_cotangents( U, S, Vᴴ, ΔU, ΔSmat, ΔVᴴ, p; degeneracy_atol, gauge_atol ) - ΔA = mul!(ΔA, U, UdΔAV * Vᴴ, 1, 1) # add the contribution to ΔA + ΔAV = U * UᴴΔAV + ΔA = mul!(ΔA, ΔAV, Vᴴ, 1, 1) # add the contribution to ΔA # The contribtutions from the orthogonal complement need to be treated differently # ΔU and ΔVᴴ are already orthogonal to U and Vᴴ if !(iszerotangent(ΔU₊) && iszerotangent(ΔV₊ᴴ)) X₀ = iszerotangent(ΔU₊) ? zero(U) : rdiv!(ΔU₊, Diagonal(S)) Y₀ᴴ = iszerotangent(ΔV₊ᴴ) ? zero(Vᴴ) : ldiv!(Diagonal(S), ΔV₊ᴴ) - AP = mul!(copy(A), U, Smat * Vᴴ, -1, 1) + US = mul!(ΔAV, U, Smat) # recycle ΔAV + AP = mul!(copy(A), US, Vᴴ, -1, 1) AP ./= S[end] S⁻¹ = S[end] ./ S X₁ = rmul!(AP * Y₀ᴴ', Diagonal(S⁻¹)) @@ -254,7 +256,7 @@ function svd_trunc_pullback!( Xₖ₊₁ .+= Xₖ Yₖ₊₁ᴴ .+= Yₖᴴ if k == maxiter - @warn "Sylvester iteration did not converge after $k iterations, final norms: (X: $(norm(Xₖ₊₁, Inf)), Yᴴ: $(norm(Yₖ₊₁ᴴ, Inf)))" + @warn "Sylvester iteration did not converge after $k iterations, final norms of X: $(norm(Xₖ₊₁, Inf)), Yᴴ: $(norm(Yₖ₊₁ᴴ, Inf)))" break end S⁻¹ₖ₊₁ .= S⁻¹ₖ .^ 2 From e2b2cd519bf8f313675cf0a49c2983d1b0da867b Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Tue, 5 May 2026 01:20:47 +0200 Subject: [PATCH 3/5] more updates to pullbacks of eig and eigh --- src/pullbacks/eig.jl | 187 ++++++++++++++++++++--------------- src/pullbacks/eigh.jl | 46 ++++----- src/pullbacks/svd.jl | 2 +- test/testsuite/chainrules.jl | 14 +-- 4 files changed, 134 insertions(+), 115 deletions(-) mode change 100644 => 100755 src/pullbacks/eig.jl mode change 100644 => 100755 test/testsuite/chainrules.jl diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl old mode 100644 new mode 100755 index 24671aa3..bf410169 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -1,13 +1,53 @@ -function check_eig_cotangents( - D, VᴴΔV; - degeneracy_atol::Real = default_pullback_rank_atol(D), - gauge_atol::Real = default_pullback_gauge_atol(VᴴΔV) +function check_and_prepare_eig_cotangents( + D, V, ViG, ΔDmat, ΔV, ind = Colon(); + degeneracy_atol::Real = default_pullback_rank_atol(S), + gauge_atol::Real = default_pullback_gauge_atol(ΔDmat, ΔV) ) - mask = abs.(transpose(D) .- D) .< degeneracy_atol - Δgauge = norm(view(VᴴΔV, mask)) + + n, p = size(V) + indD = axes(D, 1)[ind] + indV = axes(V, 2)[ind] + if !iszerotangent(ΔV) + n == size(ΔV, 1) || throw(DimensionMismatch()) + length(indV) == size(ΔV, 2) || throw(DimensionMismatch()) + if indV == 1:p + ΔV₁ = copy(ΔV) + else + ΔV₁ = zero(V) + for (j, i) in enumerate(indV) + ΔV₁[:, i] .= view(ΔV, :, j) + end + end + VᴴΔV₁ = V' * ΔV₁ + if p == n + ΔV₊ = zero!(ΔV₁) + else + ΔV₊ = mul!(ΔV₁, ViG, VᴴΔV₁, -1, 1) + end + else + ΔV₊ = nothing + VᴴΔV₁ = zero!(similar(V, (p, p))) + end + bc = Base.broadcasted(transpose(D), D, VᴴΔV₁) do d₁, d₂, v + return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) + end + Δgauge = norm(bc, Inf) + Δgauge ≤ gauge_atol || @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - return + + VᴴΔV₁ .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) + VᴴAΔV = VᴴΔV₁ + + if !iszerotangent(ΔDmat) + ΔD = diagview(ΔDmat) + length(indD) == length(ΔD) || throw(DimensionMismatch()) + view(diagview(VᴴAΔV), indD) .+= ΔD + else + ΔD = nothing + end + + return VᴴAΔV, ΔV₊ end """ @@ -39,51 +79,24 @@ function eig_pullback!( # Basic size checks and determination Dmat, V = DV - D = diagview(Dmat) - ΔDmat, ΔV = ΔDV n = LinearAlgebra.checksquare(V) + D = diagview(Dmat) n == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) + ViG = inv(V)' - if !iszerotangent(ΔV) - n == size(ΔV, 1) || throw(DimensionMismatch()) - pV = size(ΔV, 2) - VᴴΔV = fill!(similar(V), 0) - indV = axes(V, 2)[ind] - length(indV) == pV || throw(DimensionMismatch()) - mul!(view(VᴴΔV, :, indV), V', ΔV) - - check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol) - - VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) - - if !iszerotangent(ΔDmat) - ΔDvec = diagview(ΔDmat) - pD = length(ΔDvec) - indD = axes(D, 1)[ind] - length(indD) == pD || throw(DimensionMismatch()) - view(diagview(VᴴΔV), indD) .+= ΔDvec - end - PΔV = V' \ VᴴΔV - if eltype(ΔA) <: Real - ΔAc = mul!(VᴴΔV, PΔV, V') # recycle VdΔV memory - ΔA .+= real.(ΔAc) - else - ΔA = mul!(ΔA, PΔV, V', 1, 1) - end - elseif !iszerotangent(ΔDmat) - ΔDvec = diagview(ΔDmat) - pD = length(ΔDvec) - indD = axes(D, 1)[ind] - length(indD) == pD || throw(DimensionMismatch()) - Vp = view(V, :, indD) - PΔV = Vp' \ Diagonal(ΔDvec) - if eltype(ΔA) <: Real - ΔAc = PΔV * Vp' - ΔA .+= real.(ΔAc) - else - ΔA = mul!(ΔA, PΔV, V', 1, 1) - end + ΔDmat, ΔV = ΔDV + VᴴΔAV, = check_and_prepare_eig_cotangents( + D, V, ViG, ΔDmat, ΔV, ind; degeneracy_atol, gauge_atol + ) + + if eltype(ΔA) <: Real + Z = ViG * VᴴΔAV + ΔAc = mul!(VᴴΔAV, Z, V') # recycle VᴴΔAV + ΔA .+= real.(ΔAc) + else + Z = ViG * VᴴΔAV + ΔA = mul!(ΔA, Z, V', 1, 1) end return ΔA end @@ -123,44 +136,56 @@ not small compared to `gauge_atol`. function eig_trunc_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV; degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), - gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]), + maxiter::Int = 100 # TODO: better default, depending on expected number of steps using quadratic convergence? ) # Basic size checks and determination Dmat, V = DV - D = diagview(Dmat) - ΔDmat, ΔV = ΔDV (n, p) = size(V) - p == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) + D = diagview(Dmat) + p == length(D) || throw(DimensionMismatch()) G = V' * V + ViG = V / LinearAlgebra.cholesky!(G) - if !iszerotangent(ΔV) - (n, p) == size(ΔV) || throw(DimensionMismatch()) - VᴴΔV = V' * ΔV - check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol) - - ΔVperp = ΔV - V * inv(G) * VᴴΔV - VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) - else - VᴴΔV = zero(G) - end - - if !iszerotangent(ΔDmat) - ΔDvec = diagview(ΔDmat) - p == length(ΔDvec) || throw(DimensionMismatch()) - diagview(VᴴΔV) .+= ΔDvec - end - Z = V' \ VᴴΔV + ΔDmat, ΔV = ΔDV + VᴴΔAV, ΔV₊ = check_and_prepare_eig_cotangents( + D, V, ViG, ΔDmat, ΔV; degeneracy_atol, gauge_atol + ) + Z = ViG * VᴴΔAV # add contribution from orthogonal complement - PA = A - (A * V) / V - Y = mul!(ΔVperp, PA', Z, 1, 1) - X = _sylvester(PA', -Dmat', Y) - Z .+= X - + AP = mul!(complex.(A), V * Dmat, ViG', -1, 1) + X₀ = iszerotangent(ΔV₊) ? AP' * Z : mul!(ΔV₊, AP', Z, 1, 1) + X₀ ./= D' + dabsmax = maximum(abs, D) + AP ./= dabsmax + D̄⁻¹ = dabsmax ./ conj.(D) + X₁ = rmul!(AP' * X₀, Diagonal(D̄⁻¹)) + X₁ .+= X₀ + Xₖ, Xₖ₊₁ = X₁, X₀ + APₖ, APₖ₊₁ = AP * AP, AP + D̄⁻¹ₖ, D̄⁻¹ₖ₊₁ = D̄⁻¹ .^ 2, D̄⁻¹ + for k in 1:maxiter + Xₖ₊₁ = rmul!(mul!(Xₖ₊₁, APₖ', Xₖ), Diagonal(D̄⁻¹ₖ)) + if norm(Xₖ₊₁, Inf) < degeneracy_atol + break + end + Xₖ₊₁ .+= Xₖ + if k == maxiter + @warn "Sylvester iteration did not converge after $k iterations, final norm of X: $(norm(Xₖ₊₁, Inf)))" + break + end + D̄⁻¹ₖ₊₁ .= D̄⁻¹ₖ .^ 2 + APₖ₊₁ = mul!(APₖ₊₁, APₖ, APₖ) + Xₖ, Xₖ₊₁ = Xₖ₊₁, Xₖ + APₖ, APₖ₊₁ = APₖ₊₁, APₖ + D̄⁻¹ₖ, D̄⁻¹ₖ₊₁ = D̄⁻¹ₖ₊₁, D̄⁻¹ₖ + end + Z .+= Xₖ if eltype(ΔA) <: Real - ΔAc = Z * V' + ΔAc = mul!(AP, Z, V') # recycle AP ΔA .+= real.(ΔAc) else ΔA = mul!(ΔA, Z, V', 1, 1) @@ -211,15 +236,13 @@ across eigenvectors associated with degenerate eigenvalues), so the correspondin `ΔV` are projected out. """ function remove_eig_gauge_dependence!( - ΔV, D, V, ind = Colon(); + ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) ) - indV = axes(V, 2)[ind] - length(indV) == size(ΔV, 2) || throw(DimensionMismatch("Incompatible size of selected `ind` and `ΔV`")) - Vp = view(V, :, indV) - Ddiag = view(diagview(D), indV) - gaugepart = Vp' * ΔV + Ddiag = diagview(D) + gaugepart = V' * ΔV gaugepart[abs.(transpose(Ddiag) .- Ddiag) .>= degeneracy_atol] .= 0 - mul!(ΔV, Vp / (Vp' * Vp), gaugepart, -1, 1) + ViG = V / LinearAlgebra.cholesky!(V' * V) + mul!(ΔV, ViG, gaugepart, -1, 1) return ΔV end diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index 6c40fe0a..fc8d86e2 100755 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -19,13 +19,17 @@ function check_and_prepare_eigh_cotangents( end end VᴴΔV₁ = V' * ΔV₁ - ΔV₊ = mul!(ΔV₁, V, VᴴΔV₁, -1, 1) + if p == n + ΔV₊ = zero!(ΔV₁) + else + ΔV₊ = mul!(ΔV₁, V, VᴴΔV₁, -1, 1) + end aVᴴΔV₁ = project_antihermitian!(VᴴΔV₁) else ΔV₊ = nothing aVᴴΔV₁ = zero!(similar(V, (p, p))) end - bc = Base.broadcasted(D', D, aVᴴΔV₁) do d₁, d₂, v + bc = Base.broadcasted(transpose(D), D, aVᴴΔV₁) do d₁, d₂, v return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) end Δgauge = norm(bc, Inf) @@ -47,19 +51,6 @@ function check_and_prepare_eigh_cotangents( return VᴴAΔV, ΔV₊ end - -function check_eigh_cotangents( - D, aVᴴΔV; - degeneracy_atol::Real = default_pullback_rank_atol(D), - gauge_atol::Real = default_pullback_gauge_atol(aVᴴΔV) - ) - mask = abs.(D' .- D) .< degeneracy_atol - Δgauge = norm(view(aVᴴΔV, mask)) - Δgauge ≤ gauge_atol || - @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - return -end - """ eigh_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, [ind]; @@ -93,13 +84,13 @@ function eigh_pullback!( D = diagview(Dmat) n == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) - D = diagview(Dmat) ΔDmat, ΔV = ΔDV VᴴΔAV, = check_and_prepare_eigh_cotangents( D, V, ΔDmat, ΔV, ind; degeneracy_atol, gauge_atol ) - ΔA = mul!(ΔA, V, VᴴΔAV * V', 1, 1) + + ΔA = mul!(ΔA, V * VᴴΔAV, V', 1, 1) return ΔA end function eigh_pullback!( @@ -139,27 +130,24 @@ function eigh_trunc_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV; degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]), - maxiter::Int = 100 + maxiter::Int = 100 # TODO: better default, depending on expected number of steps using quadratic convergence? ) # Basic size checks and determination Dmat, V = DV (n, p) = size(V) - (n, n) == size(ΔA) || throw(DimensionMismatch()) D = diagview(Dmat) p == length(D) || throw(DimensionMismatch()) + (n, n) == size(ΔA) || throw(DimensionMismatch()) ΔDmat, ΔV = ΔDV VᴴΔAV, ΔV₊ = check_and_prepare_eigh_cotangents( D, V, ΔDmat, ΔV; degeneracy_atol, gauge_atol ) - ΔAV = V * VᴴΔAV - ΔA = mul!(ΔA, ΔAV, V', 1, 1) - + Z = V * VᴴΔAV if !iszerotangent(ΔV₊) X₀ = rdiv!(ΔV₊, Diagonal(D)) - VD = mul!(ΔAV, V, Dmat) # recycle ΔAV - AP = mul!(copy(A), VD, V', -1, 1) + AP = mul!(copy(A), V * Dmat, V', -1, 1) dabsmax = maximum(abs, D) AP ./= dabsmax D⁻¹ = dabsmax ./ D @@ -184,7 +172,15 @@ function eigh_trunc_pullback!( APₖ, APₖ₊₁ = APₖ₊₁, APₖ D⁻¹ₖ, D⁻¹ₖ₊₁ = D⁻¹ₖ₊₁, D⁻¹ₖ end - ΔA = project_hermitian!(mul!(ΔA, Xₖ, V', 1, 1)) + Z .+= Xₖ + # we cannot directly multiply Z * V' into ΔA, because we have to + # take the Hermitian part, and cannot apply project_hermitian! to + # the current contents of ΔA + ΔA′ = project_hermitian!(mul!(AP, Z, V', 1, 1)) # recycle AP + ΔA .+= ΔA′ + else + # in this case, Z * V' is automatically Hermitian, so we can directly add it to ΔA + ΔA = mul!(ΔA, Z, V', 1, 1) end return ΔA end diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index f989bf5f..832d04a1 100755 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -210,7 +210,7 @@ function svd_trunc_pullback!( rank_atol::Real = 0, degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ...), - maxiter::Int = 100, + maxiter::Int = 100 # TODO: better default, depending on expected number of steps using quadratic convergence? ) # Extract the SVD components U, Smat, Vᴴ = USVᴴ diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl old mode 100644 new mode 100755 index 558afc83..25b49841 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -258,9 +258,13 @@ function test_chainrules_eig( output_tangent = ΔDVtrunc, atol = atol, rtol = rtol ) ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) - dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) - dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + Ddiag = diagview(DV[1]) + p = sortperm(Ddiag, by = abs, rev = true) + if abs(Ddiag[p[r + 1]]) < abs(Ddiag[p[r]]) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end end truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) @@ -273,10 +277,6 @@ function test_chainrules_eig( cr_copy_eig_trunc_no_error, A, truncalg ⊢ NoTangent(); output_tangent = ΔDVtrunc, atol = atol, rtol = rtol ) - ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) - dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) - dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) end end end From 7a462539bd7f3b315904619f0c3cac33c2870c00 Mon Sep 17 00:00:00 2001 From: Jutho Date: Wed, 6 May 2026 16:24:30 +0200 Subject: [PATCH 4/5] Apply suggestions from code review Co-authored-by: Lukas Devos Co-authored-by: Jutho --- src/pullbacks/eig.jl | 11 +++-------- src/pullbacks/eigh.jl | 12 ++++-------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index bf410169..f2ce2ec6 100755 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -10,14 +10,9 @@ function check_and_prepare_eig_cotangents( if !iszerotangent(ΔV) n == size(ΔV, 1) || throw(DimensionMismatch()) length(indV) == size(ΔV, 2) || throw(DimensionMismatch()) - if indV == 1:p - ΔV₁ = copy(ΔV) - else - ΔV₁ = zero(V) - for (j, i) in enumerate(indV) - ΔV₁[:, i] .= view(ΔV, :, j) - end - end + ΔV₁ = similar(V) + ΔV₁[:, indV] = ΔV + zero!(view(ΔV₁, :, (length(indV) + 1):p)) VᴴΔV₁ = V' * ΔV₁ if p == n ΔV₊ = zero!(ΔV₁) diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index fc8d86e2..d2a2998f 100755 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -10,14 +10,9 @@ function check_and_prepare_eigh_cotangents( if !iszerotangent(ΔV) n == size(ΔV, 1) || throw(DimensionMismatch()) length(indV) == size(ΔV, 2) || throw(DimensionMismatch()) - if indV == 1:p - ΔV₁ = copy(ΔV) - else - ΔV₁ = zero(V) - for (j, i) in enumerate(indV) - ΔV₁[:, i] .= view(ΔV, :, j) - end - end + ΔV₁ = similar(V) + ΔV₁[:, indV] = ΔV + zero!(view(ΔV₁, :, (length(indV) + 1):p)) VᴴΔV₁ = V' * ΔV₁ if p == n ΔV₊ = zero!(ΔV₁) @@ -176,6 +171,7 @@ function eigh_trunc_pullback!( # we cannot directly multiply Z * V' into ΔA, because we have to # take the Hermitian part, and cannot apply project_hermitian! to # the current contents of ΔA + # TODO: add an `add_project_hermitian!` ΔA′ = project_hermitian!(mul!(AP, Z, V', 1, 1)) # recycle AP ΔA .+= ΔA′ else From b7f968a31feb29d478979f157a57171a34bb043f Mon Sep 17 00:00:00 2001 From: lkdvos Date: Wed, 6 May 2026 13:39:49 -0400 Subject: [PATCH 5/5] bypass issue with `zero!` and `view` --- src/pullbacks/eig.jl | 3 +-- src/pullbacks/eigh.jl | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index f2ce2ec6..427aaf55 100755 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -10,9 +10,8 @@ function check_and_prepare_eig_cotangents( if !iszerotangent(ΔV) n == size(ΔV, 1) || throw(DimensionMismatch()) length(indV) == size(ΔV, 2) || throw(DimensionMismatch()) - ΔV₁ = similar(V) + ΔV₁ = zero(V) ΔV₁[:, indV] = ΔV - zero!(view(ΔV₁, :, (length(indV) + 1):p)) VᴴΔV₁ = V' * ΔV₁ if p == n ΔV₊ = zero!(ΔV₁) diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index d2a2998f..cc986907 100755 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -10,9 +10,8 @@ function check_and_prepare_eigh_cotangents( if !iszerotangent(ΔV) n == size(ΔV, 1) || throw(DimensionMismatch()) length(indV) == size(ΔV, 2) || throw(DimensionMismatch()) - ΔV₁ = similar(V) + ΔV₁ = zero(V) ΔV₁[:, indV] = ΔV - zero!(view(ΔV₁, :, (length(indV) + 1):p)) VᴴΔV₁ = V' * ΔV₁ if p == n ΔV₊ = zero!(ΔV₁)