Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions src/factorizations/pullbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ for pullback! in (:qr_null_pullback!, :lq_null_pullback!)
return Δt
end
end

# Build the "no truncation" index map: every block sector of `t` is associated
# with `Colon()`, i.e. all indices of that block are selected.
function _notrunc_ind(t)
    return SectorDict(sector => Colon() for sector in blocksectors(t))
end

for pullback! in (:svd_pullback!, :eig_pullback!, :eigh_pullback!)
Expand All @@ -51,8 +50,61 @@ for pullback_trunc! in (:svd_trunc_pullback!, :eig_trunc_pullback!, :eigh_trunc_
foreachblock(Δt, t) do c, (Δb, b)
Fc = block.(F, Ref(c))
ΔFc = block.(ΔF, Ref(c))
return MAK.$pullback_trunc!(Δb, b, Fc, ΔFc; kwargs...)
MAK.$pullback_trunc!(Δb, b, Fc, ΔFc; kwargs...)
return nothing
end
return Δt
end
end

# Blockwise `AbstractTensorMap` methods for the QR / LQ gauge-removal functions
# of MatrixAlgebraKit (`MAK`).
#
# For each `f ∈ (:qr, :lq)` a method `MAK.remove_<f>_gauge_dependence!` is
# generated that walks over the blocks of the tangents `ΔF₁`, `ΔF₂`, the input
# `A` and the factors `F₁`, `F₂`, and delegates to the corresponding MAK method
# on every block. The tangents are modified in place and returned as a tuple.
#
# Keyword arguments are forwarded unchanged to the blockwise call, matching the
# behaviour of the `eig`/`eigh` wrappers in this file, so that options given by
# the caller are not silently ignored.
for f in (:qr, :lq)
    remove_f_gauge_dependence! = Symbol(:remove_, f, :_gauge_dependence!)
    remove_f_null_gauge_dependence! = Symbol(:remove_, f, :_null_gauge_dependence!)
    @eval function MAK.$remove_f_gauge_dependence!(
            ΔF₁::AbstractTensorMap, ΔF₂::AbstractTensorMap, A, F₁, F₂;
            kwargs...
        )
        foreachblock(ΔF₁, ΔF₂, A, F₁, F₂) do _, (Δf₁, Δf₂, a, f₁, f₂)
            MAK.$remove_f_gauge_dependence!(Δf₁, Δf₂, a, f₁, f₂; kwargs...)
            # the callback result is not used; return `nothing` explicitly
            return nothing
        end
        return ΔF₁, ΔF₂
    end
    # Already captured by MAK implementation
    # @eval function MAK.$remove_f_null_gauge_dependence!(ΔN::AbstractTensorMap, A, N; kwargs...)
    # foreachblock(ΔN, A, N) do _, (Δn, a, n)
    # $remove_f_gauge_dependence!(Δn, a, n)
    # end
    # return ΔN
    # end
end

for f in (:eig, :eigh)
remove_f_gauge_dependence! = Symbol(:remove_, f, :_gauge_dependence!)
# Blockwise gauge removal for the eigenvector tangent `ΔV`: every block of
# `ΔV`, `D` and `V` is handed to the matching MAK method together with the
# caller's keyword arguments. `ΔV` is updated in place and returned.
@eval function MAK.$remove_f_gauge_dependence!(
        ΔV::AbstractTensorMap, D, V; kwargs...
    )
    foreachblock(ΔV, D, V) do _, (Δblock, dblock, vblock)
        MAK.$remove_f_gauge_dependence!(Δblock, dblock, vblock; kwargs...)
        return nothing
    end
    return ΔV
end
# Variant taking an additional `inds` collection keyed by sector: for every
# block of `ΔV`, `D` and `V`, the entry `inds[c]` is passed on to the blockwise
# MAK method together with the keyword arguments. `ΔV` is modified in place and
# returned.
# NOTE(review): `inds` presumably holds the indices kept by a truncation, one
# entry per sector — confirm against the callers.
@eval function MAK.$remove_f_gauge_dependence!(ΔV::AbstractTensorMap, D, V, inds; kwargs...)
foreachblock(ΔV, D, V) do c, (Δv, d, v)
# Sectors without an entry in `inds` are skipped; their block is left untouched.
haskey(inds, c) || return nothing
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume that if !haskey(inds, c), then indeed ΔV doesn't have such a block either and Δv is just an empty matrix? But due to the fact that a c block can exist in V and D and the union semantics of foreachblock, such c would still be generated by the iterator.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, although actually in the current implementation this never happens since all sectors are always present in the inds that come from the truncation implementations, possibly with inds[c] empty.

ind = inds[c]
MAK.$remove_f_gauge_dependence!(Δv, d, v, ind; kwargs...)
return nothing
end
return ΔV
end
end
# Blockwise `AbstractTensorMap` method for `MAK.remove_svd_gauge_dependence!`.
#
# Walks over the blocks of the tangents `ΔU`, `ΔVᴴ` and the factors `U`, `S`,
# `Vᴴ`, and applies the MAK method to every block. The tangents are modified in
# place and returned as the tuple `(ΔU, ΔVᴴ)`.
#
# Keyword arguments (e.g. `degeneracy_atol`, which callers pass to this
# function) are forwarded to the blockwise call; without forwarding they would
# be accepted here but silently have no effect.
function MAK.remove_svd_gauge_dependence!(
        ΔU::AbstractTensorMap, ΔVᴴ::AbstractTensorMap, U, S, Vᴴ;
        kwargs...
    )
    foreachblock(ΔU, ΔVᴴ, U, S, Vᴴ) do _, (Δu, Δvᴴ, u, s, vᴴ)
        MAK.remove_svd_gauge_dependence!(Δu, Δvᴴ, u, s, vᴴ; kwargs...)
        # the callback result is not used; return `nothing` explicitly
        return nothing
    end
    return ΔU, ΔVᴴ
end
15 changes: 8 additions & 7 deletions test/chainrules/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ using LinearAlgebra
using Zygote
using MatrixAlgebraKit
using MatrixAlgebraKit: diagview

using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_lq_gauge_dependence!,
remove_eigh_gauge_dependence!, remove_eig_gauge_dependence!, remove_svd_gauge_dependence!

# Tests
# -----
Expand Down Expand Up @@ -52,7 +53,7 @@ for V in spacelist
@test_logs (:warn, r"^`qr") match_mode = :any full_pb((ΔQ, ΔR))
end

remove_qrgauge_dependence!(ΔQ, t, Q)
remove_qr_gauge_dependence!(ΔQ, ΔR, t, Q, R)

test_ad_rrule(qr_full, t; fkwargs, atol, rtol, output_tangent = (ΔQ, ΔR))
test_ad_rrule(
Expand Down Expand Up @@ -90,7 +91,7 @@ for V in spacelist
# @test_logs (:warn, r"^`lq") match_mode = :any full_pb((ΔL, ΔQ))
end

remove_lqgauge_dependence!(ΔQ, t, Q)
remove_lq_gauge_dependence!(ΔL, ΔQ, t, L, Q)

test_ad_rrule(lq_full, t; fkwargs, atol, rtol, output_tangent = (ΔL, ΔQ))
test_ad_rrule(
Expand All @@ -114,7 +115,7 @@ for V in spacelist
Δv = rand_tangent(v)
Δd = rand_tangent(d)
Δd2 = randn!(similar(d, space(d)))
remove_eiggauge_dependence!(Δv, d, v)
remove_eig_gauge_dependence!(Δv, d, v)

test_ad_rrule(eig_full, t; output_tangent = (Δd, Δv), atol, rtol)
test_ad_rrule(first ∘ eig_full, t; output_tangent = Δd, atol, rtol)
Expand All @@ -126,7 +127,7 @@ for V in spacelist
Δv = rand_tangent(v)
Δd = rand_tangent(d)
Δd2 = randn!(similar(d, space(d)))
remove_eighgauge_dependence!(Δv, d, v)
remove_eigh_gauge_dependence!(Δv, d, v)

# necessary for FiniteDifferences to not complain
eigh_full′ = eigh_full ∘ project_hermitian
Expand Down Expand Up @@ -155,7 +156,7 @@ for V in spacelist
USVᴴ = svd_compact(t)
ΔU, ΔS, ΔVᴴ = rand_tangent.(USVᴴ)
ΔS2 = randn!(similar(ΔS, space(ΔS)))
ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, USVᴴ...; degeneracy_atol)
ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, USVᴴ...; degeneracy_atol)

# test_ad_rrule(svd_full, t; output_tangent = (ΔU, ΔS, ΔVᴴ), atol, rtol)
# test_ad_rrule(svd_full, t; output_tangent = (ΔU, ΔS2, ΔVᴴ), atol, rtol)
Expand All @@ -170,7 +171,7 @@ for V in spacelist
trunc = truncspace(V_trunc)
USVᴴ_trunc = svd_trunc(t; trunc)
ΔUSVᴴ_trunc = (rand_tangent.(Base.front(USVᴴ_trunc))..., zero(last(USVᴴ_trunc)))
remove_svdgauge_dependence!(
remove_svd_gauge_dependence!(
ΔUSVᴴ_trunc[1], ΔUSVᴴ_trunc[3], Base.front(USVᴴ_trunc)...; degeneracy_atol
)
test_ad_rrule(
Expand Down
26 changes: 14 additions & 12 deletions test/mooncake/factorizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ using TensorKit
using TensorOperations
using VectorInterface: Zero, One
using MatrixAlgebraKit
using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_lq_gauge_dependence!,
remove_eigh_gauge_dependence!, remove_eig_gauge_dependence!, remove_svd_gauge_dependence!
using Mooncake
using Random

Expand All @@ -25,7 +27,7 @@ eltypes = (Float64, ComplexF64)
# qr_full/qr_null requires being careful with gauges
QR = qr_full(A)
ΔQR = Mooncake.randn_tangent(rng, QR)
remove_qrgauge_dependence!(ΔQR[1], A, QR[1])
remove_qr_gauge_dependence!(ΔQR..., A, QR...)
Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
Expand All @@ -37,7 +39,7 @@ eltypes = (Float64, ComplexF64)
# qr_full/qr_null requires being careful with gauges
QR = qr_full(A)
ΔQR = Mooncake.randn_tangent(rng, QR)
remove_qrgauge_dependence!(ΔQR[1], A, QR[1])
remove_qr_gauge_dependence!(ΔQR..., A, QR...)
Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
Expand All @@ -51,7 +53,7 @@ eltypes = (Float64, ComplexF64)
# qr_full/qr_null requires being careful with gauges
LQ = lq_full(A)
ΔLQ = Mooncake.randn_tangent(rng, LQ)
remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2])
remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
Expand All @@ -63,7 +65,7 @@ eltypes = (Float64, ComplexF64)
# qr_full/qr_null requires being careful with gauges
LQ = lq_full(A)
ΔLQ = Mooncake.randn_tangent(rng, LQ)
remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2])
remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false)
# TODO:
# Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
Expand All @@ -73,13 +75,13 @@ eltypes = (Float64, ComplexF64)
for t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]))
DV = eig_full(t)
ΔDV = Mooncake.randn_tangent(rng, DV)
remove_eiggauge_dependence!(ΔDV[2], DV...)
remove_eig_gauge_dependence!(ΔDV[2], DV...)
Mooncake.TestUtils.test_rule(rng, eig_full, t; output_tangent = ΔDV, atol, rtol, mode, is_primitive = false)

th = project_hermitian(t)
DV = eigh_full(th)
ΔDV = Mooncake.randn_tangent(rng, DV)
remove_eighgauge_dependence!(ΔDV[2], DV...)
remove_eigh_gauge_dependence!(ΔDV[2], DV...)
Mooncake.TestUtils.test_rule(rng, eigh_full ∘ project_hermitian, th; output_tangent = ΔDV, atol, rtol, mode, is_primitive = false)
end
end
Expand All @@ -88,20 +90,20 @@ eltypes = (Float64, ComplexF64)
for t in (randn(T, V[1] ← V[1]), randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])'))
USVᴴ = svd_compact(t)
ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ)
remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
Mooncake.TestUtils.test_rule(rng, svd_compact, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false)

# USVᴴ = svd_full(t)
# ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ)
# remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
# Mooncake.TestUtils.test_rule(rng, svd_full, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false)
USVᴴ = svd_full(t)
ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ)
remove_svd_gauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...)
Mooncake.TestUtils.test_rule(rng, svd_full, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false)

V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t))
trunc = truncspace(V_trunc)
alg = MatrixAlgebraKit.select_algorithm(svd_trunc, t, nothing; trunc)
USVᴴtrunc = svd_trunc(t, alg)
ΔUSVᴴtrunc = (Mooncake.randn_tangent(rng, Base.front(USVᴴtrunc))..., zero(last(USVᴴtrunc)))
remove_svdgauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...)
remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...)
Mooncake.TestUtils.test_rule(rng, svd_trunc, t, alg; output_tangent = ΔUSVᴴtrunc, atol, rtol, mode)
end
end
Expand Down
74 changes: 0 additions & 74 deletions test/setup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ export random_fusion
export sectorlist, fast_sectorlist
# export dim_isapprox
export default_spacelist, factorization_spacelist, ad_spacelist
export remove_qrgauge_dependence!, remove_lqgauge_dependence!
export remove_eiggauge_dependence!, remove_eighgauge_dependence!, remove_svdgauge_dependence!
export test_ad_rrule
export _isunitary, _isone

Expand Down Expand Up @@ -398,78 +396,6 @@ function ad_spacelist(fast_tests::Bool)
return fast_tests ? (Vtr, VRepU₁, VfHubbard, VRepA4Twistedℤ₄) : (Vtr, VRepℤ₂, VRepCU₁, VfHubbard, VRepA4Twistedℤ₄, VIBMRepA4)
end

# Gauge-fixing tangents for AD factorization tests
# -------------------------------------------------
# For every block of `ΔQ`, replace the trailing columns `(min(m, n) + 1):m`
# by their projection `Q₁ * (Q₁' * ΔQ₂)` onto the leading `min(m, n)` columns
# `Q₁` of the matching block of `Q`, where `(m, n)` is the size of the
# corresponding block of `t`. `ΔQ` is modified in place and returned.
function remove_qrgauge_dependence!(ΔQ, t, Q)
    for (sector, Δq) in blocks(ΔQ)
        nrows, ncols = size(block(t, sector))
        k = min(nrows, ncols)
        q_lead = view(block(Q, sector), 1:nrows, 1:k)
        Δq_tail = view(Δq, :, (k + 1):nrows)
        # the overlap is materialised first, so writing into `Δq_tail` is safe
        overlap = q_lead' * Δq_tail
        mul!(Δq_tail, q_lead, overlap)
    end
    return ΔQ
end
# For every block of `ΔQ`, replace the trailing rows `(min(m, n) + 1):n` by
# their projection `(ΔQ₂ * Q₁') * Q₁` onto the leading `min(m, n)` rows `Q₁`
# of the matching block of `Q`, where `(m, n)` is the size of the corresponding
# block of `t`. `ΔQ` is modified in place and returned.
function remove_lqgauge_dependence!(ΔQ, t, Q)
    for (sector, Δq) in blocks(ΔQ)
        nrows, ncols = size(block(t, sector))
        k = min(nrows, ncols)
        q_lead = view(block(Q, sector), 1:k, 1:ncols)
        Δq_tail = view(Δq, (k + 1):ncols, :)
        # the overlap is materialised first, so writing into `Δq_tail` is safe
        overlap = Δq_tail * q_lead'
        mul!(Δq_tail, overlap, q_lead)
    end
    return ΔQ
end
# Subtract `V / (V' * V) * G` from `ΔV` in place, where `G = V' * ΔV` with
# every entry `(i, j)` zeroed whenever the diagonal entries `i` and `j` of the
# matching block of `D` differ by at least `degeneracy_atol`. Returns `ΔV`.
function remove_eiggauge_dependence!(
        ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D)
    )
    gauge = V' * ΔV
    for (sector, g) in blocks(gauge)
        λ = diagview(block(D, sector))
        # An explicit loop is used rather than a broadcasted mask: the broadcast
        # form failed in the test suite and could not be reproduced interactively.
        for col in axes(g, 2), row in axes(g, 1)
            if abs(λ[row] - λ[col]) >= degeneracy_atol
                g[row, col] = 0
            end
        end
    end
    mul!(ΔV, V / (V' * V), gauge, -1, 1)
    return ΔV
end
# Subtract `V * G` from `ΔV` in place, where `G` is the result of
# `project_antihermitian!(V' * ΔV)` with every entry `(i, j)` zeroed whenever
# the diagonal entries `i` and `j` of the matching block of `D` differ by at
# least `degeneracy_atol`. Returns `ΔV`.
function remove_eighgauge_dependence!(
        ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D)
    )
    overlap = V' * ΔV
    gauge = project_antihermitian!(overlap)
    for (sector, g) in blocks(gauge)
        λ = diagview(block(D, sector))
        # An explicit loop is used rather than a broadcasted mask: the broadcast
        # form failed in the test suite and could not be reproduced interactively.
        for col in axes(g, 2), row in axes(g, 1)
            if abs(λ[row] - λ[col]) >= degeneracy_atol
                g[row, col] = 0
            end
        end
    end
    mul!(ΔV, V, gauge, -1, 1)
    return ΔV
end
# Subtract `U * G` from `ΔU` in place, where `G` is the result of
# `project_antihermitian!(U' * ΔU + Vᴴ * ΔVᴴ')` with every entry `(i, j)` zeroed
# whenever the diagonal entries `i` and `j` of the matching block of `S` differ
# by at least `degeneracy_atol`. `ΔVᴴ` is not modified by this function.
# Returns the tuple `(ΔU, ΔVᴴ)`.
function remove_svdgauge_dependence!(
        ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(S)
    )
    overlap = U' * ΔU + Vᴴ * ΔVᴴ'
    gauge = project_antihermitian!(overlap)
    for (sector, g) in blocks(gauge)
        σ = diagview(block(S, sector))
        # An explicit loop is used rather than a broadcasted mask: the broadcast
        # form failed in the test suite and could not be reproduced interactively.
        for col in axes(g, 2), row in axes(g, 1)
            if abs(σ[row] - σ[col]) >= degeneracy_atol
                g[row, col] = 0
            end
        end
    end
    mul!(ΔU, U, gauge, -1, 1)
    return ΔU, ΔVᴴ
end

# ChainRules test utilities
# -------------------------
Expand Down
Loading