Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 99 additions & 82 deletions src/pullbacks/eig.jl
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,13 +1,47 @@
function check_eig_cotangents(
D, VᴴΔV;
degeneracy_atol::Real = default_pullback_rank_atol(D),
gauge_atol::Real = default_pullback_gauge_atol(VᴴΔV)
function check_and_prepare_eig_cotangents(
D, V, ViG, ΔDmat, ΔV, ind = Colon();
degeneracy_atol::Real = default_pullback_rank_atol(S),
gauge_atol::Real = default_pullback_gauge_atol(ΔDmat, ΔV)
)
mask = abs.(transpose(D) .- D) .< degeneracy_atol
Δgauge = norm(view(VᴴΔV, mask))

n, p = size(V)
indD = axes(D, 1)[ind]
indV = axes(V, 2)[ind]
if !iszerotangent(ΔV)
n == size(ΔV, 1) || throw(DimensionMismatch())
length(indV) == size(ΔV, 2) || throw(DimensionMismatch())
ΔV₁ = zero(V)
ΔV₁[:, indV] = ΔV
Comment on lines +13 to +14
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would still like to keep some of the old behaviour, so maybe this is a compromise

Suggested change
ΔV₁ = zero(V)
ΔV₁[:, indV] = ΔV
if indV == 1:p
ΔV₁ = copy(ΔV)
else
ΔV₁ = zero(V)
ΔV₁[:, indV] .= ΔV
end

VᴴΔV₁ = V' * ΔV₁
if p == n
ΔV₊ = zero!(ΔV₁)
else
ΔV₊ = mul!(ΔV₁, ViG, VᴴΔV₁, -1, 1)
end
else
ΔV₊ = nothing
VᴴΔV₁ = zero!(similar(V, (p, p)))
end
bc = Base.broadcasted(transpose(D), D, VᴴΔV₁) do d₁, d₂, v
return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v)
end
Δgauge = norm(bc, Inf)

Δgauge ≤ gauge_atol ||
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return

VᴴΔV₁ .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
VᴴAΔV = VᴴΔV₁

if !iszerotangent(ΔDmat)
ΔD = diagview(ΔDmat)
length(indD) == length(ΔD) || throw(DimensionMismatch())
view(diagview(VᴴAΔV), indD) .+= ΔD
else
ΔD = nothing
end

return VᴴAΔV, ΔV₊
end

"""
Expand Down Expand Up @@ -39,51 +73,24 @@ function eig_pullback!(

# Basic size checks and determination
Dmat, V = DV
D = diagview(Dmat)
ΔDmat, ΔV = ΔDV
n = LinearAlgebra.checksquare(V)
D = diagview(Dmat)
n == length(D) || throw(DimensionMismatch())
(n, n) == size(ΔA) || throw(DimensionMismatch())
ViG = inv(V)'

if !iszerotangent(ΔV)
n == size(ΔV, 1) || throw(DimensionMismatch())
pV = size(ΔV, 2)
VᴴΔV = fill!(similar(V), 0)
indV = axes(V, 2)[ind]
length(indV) == pV || throw(DimensionMismatch())
mul!(view(VᴴΔV, :, indV), V', ΔV)

check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)

VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))

if !iszerotangent(ΔDmat)
ΔDvec = diagview(ΔDmat)
pD = length(ΔDvec)
indD = axes(D, 1)[ind]
length(indD) == pD || throw(DimensionMismatch())
view(diagview(VᴴΔV), indD) .+= ΔDvec
end
PΔV = V' \ VᴴΔV
if eltype(ΔA) <: Real
ΔAc = mul!(VᴴΔV, PΔV, V') # recycle VdΔV memory
ΔA .+= real.(ΔAc)
else
ΔA = mul!(ΔA, PΔV, V', 1, 1)
end
elseif !iszerotangent(ΔDmat)
ΔDvec = diagview(ΔDmat)
pD = length(ΔDvec)
indD = axes(D, 1)[ind]
length(indD) == pD || throw(DimensionMismatch())
Vp = view(V, :, indD)
PΔV = Vp' \ Diagonal(ΔDvec)
if eltype(ΔA) <: Real
ΔAc = PΔV * Vp'
ΔA .+= real.(ΔAc)
else
ΔA = mul!(ΔA, PΔV, V', 1, 1)
end
ΔDmat, ΔV = ΔDV
VᴴΔAV, = check_and_prepare_eig_cotangents(
D, V, ViG, ΔDmat, ΔV, ind; degeneracy_atol, gauge_atol
)

if eltype(ΔA) <: Real
Z = ViG * VᴴΔAV
ΔAc = mul!(VᴴΔAV, Z, V') # recycle VᴴΔAV
ΔA .+= real.(ΔAc)
else
Z = ViG * VᴴΔAV
ΔA = mul!(ΔA, Z, V', 1, 1)
end
return ΔA
end
Expand Down Expand Up @@ -123,44 +130,56 @@ not small compared to `gauge_atol`.
function eig_trunc_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV;
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]),
maxiter::Int = 100 # TODO: better default, depending on expected number of steps using quadratic convergence?
)

# Basic size checks and determination
Dmat, V = DV
D = diagview(Dmat)
ΔDmat, ΔV = ΔDV
(n, p) = size(V)
p == length(D) || throw(DimensionMismatch())
(n, n) == size(ΔA) || throw(DimensionMismatch())
D = diagview(Dmat)
p == length(D) || throw(DimensionMismatch())
G = V' * V
ViG = V / LinearAlgebra.cholesky!(G)

if !iszerotangent(ΔV)
(n, p) == size(ΔV) || throw(DimensionMismatch())
VᴴΔV = V' * ΔV
check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)

ΔVperp = ΔV - V * inv(G) * VᴴΔV
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
else
VᴴΔV = zero(G)
end

if !iszerotangent(ΔDmat)
ΔDvec = diagview(ΔDmat)
p == length(ΔDvec) || throw(DimensionMismatch())
diagview(VᴴΔV) .+= ΔDvec
end
Z = V' \ VᴴΔV
ΔDmat, ΔV = ΔDV
VᴴΔAV, ΔV₊ = check_and_prepare_eig_cotangents(
D, V, ViG, ΔDmat, ΔV; degeneracy_atol, gauge_atol
)
Z = ViG * VᴴΔAV

# add contribution from orthogonal complement
PA = A - (A * V) / V
Y = mul!(ΔVperp, PA', Z, 1, 1)
X = _sylvester(PA', -Dmat', Y)
Z .+= X

AP = mul!(complex.(A), V * Dmat, ViG', -1, 1)
X₀ = iszerotangent(ΔV₊) ? AP' * Z : mul!(ΔV₊, AP', Z, 1, 1)
X₀ ./= D'
dabsmax = maximum(abs, D)
AP ./= dabsmax
D̄⁻¹ = dabsmax ./ conj.(D)
X₁ = rmul!(AP' * X₀, Diagonal(D̄⁻¹))
X₁ .+= X₀
Xₖ, Xₖ₊₁ = X₁, X₀
APₖ, APₖ₊₁ = AP * AP, AP
D̄⁻¹ₖ, D̄⁻¹ₖ₊₁ = D̄⁻¹ .^ 2, D̄⁻¹
for k in 1:maxiter
Xₖ₊₁ = rmul!(mul!(Xₖ₊₁, APₖ', Xₖ), Diagonal(D̄⁻¹ₖ))
if norm(Xₖ₊₁, Inf) < degeneracy_atol
break
end
Xₖ₊₁ .+= Xₖ
if k == maxiter
@warn "Sylvester iteration did not converge after $k iterations, final norm of X: $(norm(Xₖ₊₁, Inf)))"
break
end
D̄⁻¹ₖ₊₁ .= D̄⁻¹ₖ .^ 2
APₖ₊₁ = mul!(APₖ₊₁, APₖ, APₖ)
Xₖ, Xₖ₊₁ = Xₖ₊₁, Xₖ
APₖ, APₖ₊₁ = APₖ₊₁, APₖ
D̄⁻¹ₖ, D̄⁻¹ₖ₊₁ = D̄⁻¹ₖ₊₁, D̄⁻¹ₖ
end
Z .+= Xₖ
if eltype(ΔA) <: Real
ΔAc = Z * V'
ΔAc = mul!(AP, Z, V') # recycle AP
ΔA .+= real.(ΔAc)
else
ΔA = mul!(ΔA, Z, V', 1, 1)
Expand Down Expand Up @@ -211,15 +230,13 @@ across eigenvectors associated with degenerate eigenvalues), so the correspondin
`ΔV` are projected out.
"""
function remove_eig_gauge_dependence!(
ΔV, D, V, ind = axes(ΔV, 2);
ΔV, D, V;
degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D)
)
length(ind) == size(ΔV, 2) || throw(DimensionMismatch())
indV = axes(V, 2)[ind]
Vp = view(V, :, indV)
Ddiag = view(diagview(D), indV)
gaugepart = Vp' * ΔV
Ddiag = diagview(D)
gaugepart = V' * ΔV
gaugepart[abs.(transpose(Ddiag) .- Ddiag) .>= degeneracy_atol] .= 0
mul!(ΔV, Vp / (Vp' * Vp), gaugepart, -1, 1)
ViG = V / LinearAlgebra.cholesky!(V' * V)
mul!(ΔV, ViG, gaugepart, -1, 1)
return ΔV
end
Loading
Loading