From 58b15268da648d248fada97ff7c676608fc60617 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 17 Jan 2026 09:03:23 +0100 Subject: [PATCH 01/54] add `arrayify` for adjoint tensor --- ext/TensorKitMooncakeExt/tangent.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index 761e626f0..9fa6e401a 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -5,3 +5,11 @@ function Mooncake.arrayify(A_dA::CoDual{<:TensorMap}) dA = typeof(A)(data, A.space) return A, dA end + +function Mooncake.arrayify(Aᴴ_ΔAᴴ::CoDual{<:TensorKit.AdjointTensorMap}) + Aᴴ = Mooncake.primal(Aᴴ_ΔAᴴ) + ΔAᴴ = Mooncake.tangent(Aᴴ_ΔAᴴ) + A_ΔA = CoDual(Aᴴ', ΔAᴴ.data.parent) + A, ΔA = arrayify(A_ΔA) + return A', ΔA' +end From 1e54e1eefe336f0db524011445ab8999be577a1f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 17 Jan 2026 09:03:29 +0100 Subject: [PATCH 02/54] add vectorinterface rules --- .../TensorKitMooncakeExt.jl | 6 +- ext/TensorKitMooncakeExt/vectorinterface.jl | 93 +++++++++++++++++++ test/autodiff/mooncake.jl | 25 +++++ 3 files changed, 121 insertions(+), 3 deletions(-) create mode 100644 ext/TensorKitMooncakeExt/vectorinterface.jl diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index b35c73f4c..2cc64f49b 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -1,17 +1,17 @@ module TensorKitMooncakeExt using Mooncake -using Mooncake: @zero_derivative, DefaultCtx, ReverseMode, NoRData, CoDual, arrayify, primal +using Mooncake: @zero_derivative, DefaultCtx, ReverseMode, NoFData, NoRData, CoDual, arrayify, primal using TensorKit +using VectorInterface using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize import TensorOperations as TO -using VectorInterface: One, Zero using TupleTools - include("utility.jl") include("tangent.jl") include("linalg.jl") +include("vectorinterface.jl") include("tensoroperations.jl") end diff --git a/ext/TensorKitMooncakeExt/vectorinterface.jl b/ext/TensorKitMooncakeExt/vectorinterface.jl new file mode 100644 index 000000000..2c1bfe984 --- /dev/null +++ b/ext/TensorKitMooncakeExt/vectorinterface.jl @@ -0,0 +1,93 @@ +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, Number} + +function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + α = primal(α_Δα) + + # primal call + C_cache = copy(C) + scale!(C, α) + + function scale_pullback(::NoRData) + copy!(C, C_cache) + scale!(ΔC, conj(α)) + TΔα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Δαr = TΔα === NoRData ? NoRData() : inner(C, ΔC) + return NoRData(), NoRData(), Δαr + end + + return C_ΔC, scale_pullback +end + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number} + +function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + α = primal(α_Δα) + + # primal call + C_cache = copy(C) + scale!(C, A, α) + + function scale_pullback(::NoRData) + copy!(C, C_cache) + zerovector!(ΔC) + scale!(ΔA, conj(α)) + TΔα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Δαr = TΔα === NoRData ? NoRData() : inner(C, ΔC) + return NoRData(), NoRData(), NoRData(), Δαr + end + + return C_ΔC, scale_pullback +end + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number} + +function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + α = primal(α_Δα) + β = primal(β_Δβ) + + # primal call + C_cache = copy(C) + add!(C, A, α, β) + + function add_pullback(::NoRData) + copy!(C, C_cache) + scale!(ΔC, conj(β)) + scale!(ΔA, conj(α)) + + TΔα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Δαr = TΔα === NoRData ? NoRData() : inner(A, ΔC) + TΔβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) + Δβr = TΔβ === NoRData ? NoRData() : inner(C, ΔC) + + return NoRData(), NoRData(), NoRData(), Δαr, Δβr + end + + return C_ΔC, add_pullback +end + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap} + +function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractTensorMap}, B_ΔB::CoDual{<:AbstractTensorMap}) + # prepare arguments + A, ΔA = arrayify(A_ΔA) + B, ΔB = arrayify(B_ΔB) + + # primal call + s = inner(A, B) + + function inner_pullback(Δs) + scale!(ΔA, B, conj(Δs)) + scale!(ΔB, A, Δs) + return NoRData(), NoRData(), NoRData() + end + + return CoDual(s, NoFData()), inner_pullback +end diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index 1cd74fa27..37eb932b5 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -68,6 +68,31 @@ for V in spacelist println("Mooncake with symmetry: $Istr") println("---------------------------------------") eltypes = (Float64,) # no complex support yet + + @timedtestset "VectorInterface with scalartype $T" for T in eltypes + atol = precision(T) + rtol = precision(T) + + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol, mode) + + Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, mode, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, mode, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol, mode) + + Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, mode) + end + symmetricbraiding && @timedtestset "TensorOperations with scalartype $T" for T in eltypes atol = precision(T) rtol = precision(T) From 28e6c1ddcaa62c098512e62c31ff6186982613a3 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 17 Jan 2026 12:37:19 +0100 Subject: [PATCH 03/54] add tensoroperations rules --- ext/TensorKitMooncakeExt/tensoroperations.jl | 178 +++++++++++++++++++ test/autodiff/mooncake.jl | 44 +++++ 2 files changed, 222 insertions(+) diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index d663a3281..7b9a674f8 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -1,3 +1,91 @@ +# tensoradd! +# ---------- +Mooncake.@is_primitive( + DefaultCtx, + ReverseMode, + Tuple{ + typeof(TO.tensoradd!), + AbstractTensorMap, + AbstractTensorMap, Index2Tuple, Bool, + Number, Number, Vararg{Any}, + } +) + +function Mooncake.rrule!!( + ::CoDual{typeof(TO.tensoradd!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, + A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, conjA_ΔconjA::CoDual{Bool}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, + ba_Δba::CoDual... + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + pA = primal(pA_ΔpA) + conjA = primal(conjA_ΔconjA) + α, β = primal.((α_Δα, β_Δβ)) + ba = primal.(ba_Δba) + + # primal call + C_cache = copy(C) + TO.tensoradd!(C, A, pA, conjA, α, β, ba...) + + function tensoradd_pullback(::NoRData) + copy!(C, C_cache) + + ΔCr = tensoradd_pullback_ΔC!(ΔC, β) + ΔAr = tensoradd_pullback_ΔA!(ΔA, ΔC, A, pA, conjA, α, ba...) + Δαr = tensoradd_pullback_Δα(ΔC, A, pA, conjA, α, ba...) + Δβr = tensoradd_pullback_Δβ(ΔC, C, β) + + return NoRData(), + ΔCr, + ΔAr, NoRData(), NoRData(), + Δαr, Δβr, + map(Returns(NoRData()), ba)... + end + + return C_ΔC, tensoradd_pullback +end + +tensoradd_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) + +function tensoradd_pullback_ΔA!( + ΔA, ΔC, A, pA, conjA, α, ba... + ) + ipA = invperm(linearize(pA)) + pΔA = _repartition(ipA, A) + TO.tensoradd!(ΔA, ΔC, pΔA, conjA, conjA ? α : conj(α), Zero(), ba...) + return NoRData() +end + +function tensoradd_pullback_Δα( + ΔC, A, pA, conjA, α, ba... + ) + Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Tdα === NoRData && return NoRData() + + tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)); copy = false) + Δα = TO.tensorscalar( + TO.tensorcontract( + A, ((), linearize(pA)), !conjA, + tΔC, (trivtuple(TO.numind(pA)), ()), false, + ((), ()), One(), ba... + ) + ) + return Mooncake._rdata(Δα) +end + +function tensoradd_pullback_Δβ(ΔC, C, β) + Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) + Tdβ === NoRData && return NoRData() + + Δβ = inner(C, ΔC) + return Mooncake._rdata(Δβ) +end + +# tensorcontract! +# --------------- Mooncake.@is_primitive( DefaultCtx, ReverseMode, @@ -135,3 +223,93 @@ function tensorcontract_pullback_Δβ(ΔC, C, β) Δβ = inner(C, ΔC) return Mooncake._rdata(Δβ) end + +# tensortrace! +# ------------ +Mooncake.@is_primitive( + DefaultCtx, + ReverseMode, + Tuple{ + typeof(TO.tensortrace!), + AbstractTensorMap, + AbstractTensorMap, Index2Tuple, Index2Tuple, Bool, + Number, Number, + Vararg{Any}, + } +) + +function Mooncake.rrule!!( + ::CoDual{typeof(TO.tensortrace!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, + A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, q_Δq::CoDual{<:Index2Tuple}, conjA_ΔconjA::CoDual{Bool}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, + ba_Δba::CoDual... + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + p = primal(p_Δp) + q = primal(q_Δq) + conjA = primal(conjA_ΔconjA) + α, β = primal.((α_Δα, β_Δβ)) + ba = primal.(ba_Δba) + + # primal call + C_cache = copy(C) + TO.tensortrace!(C, A, p, q, conjA, α, β, ba...) + + function tensortrace_pullback(::NoRData) + copy!(C, C_cache) + + ΔCr = tensortrace_pullback_ΔC!(ΔC, β) + ΔAr = tensortrace_pullback_ΔA!(ΔA, ΔC, A, p, q, conjA, α, ba...) + Δαr = tensortrace_pullback_Δα(ΔC, A, p, q, conjA, α, ba...) + Δβr = tensortrace_pullback_Δβ(ΔC, C, β) + + return NoRData(), + ΔCr, + ΔAr, NoRData(), NoRData(), NoRData(), + Δαr, Δβr, + map(Returns(NoRData()), ba)... + end + + return C_ΔC, tensortrace_pullback +end + +tensortrace_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) + +function tensortrace_pullback_ΔA!( + ΔA, ΔC, A, p, q, conjA, α, ba... + ) + ip = invperm((linearize(p)..., q[1]..., q[2]...)) + pdA = _repartition(ip, A) + E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA)) + twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + pE = ((), trivtuple(TO.numind(q))) + pΔC = (trivtuple(TO.numind(p)), ()) + TO.tensorproduct!( + ΔA, ΔC, pΔC, conjA, E, pE, conjA, pdA, conjA ? α : conj(α), Zero(), ba... + ) + return NoRData() +end + +function tensortrace_pullback_Δα( + ΔC, A, p, q, conjA, α, ba... + ) + Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Tdα === NoRData && return NoRData() + + # TODO: this result might be easier to compute as: + # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α + At = TO.tensortrace(A, p, q, conjA) + Δα = inner(At, ΔC) + return Mooncake._rdata(Δα) +end + +function tensortrace_pullback_Δβ(ΔC, C, β) + Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) + Tdβ === NoRData && return NoRData() + + Δβ = inner(C, ΔC) + return Mooncake._rdata(Δβ) +end diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index 37eb932b5..38fa23c15 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -97,6 +97,25 @@ for V in spacelist atol = precision(T) rtol = precision(T) + @timedtestset "tensoradd!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randindextuple(numind(A)) + + C1 = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + Mooncake.TestUtils.test_rule(rng, tensoradd!, C1, A, p, false, α, β; atol, rtol, mode) + + C2 = randn!(TensorOperations.tensoralloc_add(T, A, p, true, Val(false))) + Mooncake.TestUtils.test_rule(rng, tensoradd!, C2, A, p, true, α, β; atol, rtol, mode) + + A = rand(Bool) ? C1 : C2 + end + end + @timedtestset "tensorcontract!" begin for _ in 1:5 d = 0 @@ -138,5 +157,30 @@ for V in spacelist end end end + + @timedtestset "tensortrace!" begin + for _ in 1:5 + k1 = rand(0:2) + k2 = rand(1:2) + V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) + + (_p, _q) = randindextuple(k1 + 2 * k2, k1) + p = _repartition(_p, rand(0:k1)) + q = _repartition(_q, k2) + ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2))) + A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) + + α = randn(T) + β = randn(T) + for conjA in (false, true) + C = randn!(TensorOperations.tensoralloc_add(T, A, p, conjA, Val(false))) + Mooncake.TestUtils.test_rule( + rng, tensortrace!, C, A, p, q, conjA, α, β; + atol, rtol, mode, is_primitive = false + ) + end + end + end end end From 94112b88c8afc0b43ff08c4f995d78d7556d3ee2 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 18 Jan 2026 08:49:45 -0500 Subject: [PATCH 04/54] add indexmanipulations rules --- .../TensorKitMooncakeExt.jl | 2 + .../indexmanipulations.jl | 153 ++++++++++++++++ ext/TensorKitMooncakeExt/tensoroperations.jl | 166 +++++++++--------- test/autodiff/mooncake.jl | 61 ++++++- 4 files changed, 297 insertions(+), 85 deletions(-) create mode 100644 ext/TensorKitMooncakeExt/indexmanipulations.jl diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index 2cc64f49b..15e0c4c9f 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -3,6 +3,7 @@ module TensorKitMooncakeExt using Mooncake using Mooncake: @zero_derivative, DefaultCtx, ReverseMode, NoFData, NoRData, CoDual, arrayify, primal using TensorKit +import TensorKit as TK using VectorInterface using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize import TensorOperations as TO @@ -11,6 +12,7 @@ using TupleTools include("utility.jl") include("tangent.jl") include("linalg.jl") +include("indexmanipulations.jl") include("vectorinterface.jl") include("tensoroperations.jl") diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl new file mode 100644 index 000000000..a0b73dde2 --- /dev/null +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -0,0 +1,153 @@ +for transform in (:permute, :transpose) + add_transform! = Symbol(:add_, transform, :!) + add_transform_pullback = Symbol(add_transform!, :_pullback) + @eval Mooncake.@is_primitive( + DefaultCtx, + ReverseMode, + Tuple{ + typeof(TK.$add_transform!), + AbstractTensorMap, + AbstractTensorMap, Index2Tuple, + Number, Number, Vararg{Any}, + } + ) + + @eval function Mooncake.rrule!!( + ::CoDual{typeof(TK.$add_transform!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, + A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, + ba_Δba::CoDual... + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + p = primal(p_Δp) + α, β = primal.((α_Δα, β_Δβ)) + ba = primal.(ba_Δba) + + C_cache = copy(C) + + # if we need to compute Δa, it is faster to allocate an intermediate permuted A + # and store that instead of repeating the permutation in the pullback each time. + # effectively, we replace `add_permute` by `add ∘ permute`. + Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Ap = if Tdα === NoRData + TK.$add_transform!(C, A, p, α, β, ba...) + nothing + else + Ap = $transform(A, p) + add!(C, Ap, α, β) + Ap + end + + function $add_transform_pullback(::NoRData) + copy!(C, C_cache) + + scale!(ΔC, conj(β)) + ΔCr = NoRData() + + # ΔA + ip = invperm(linearize(p)) + pΔA = _repartition(ip, A) + TK.$add_transform!(ΔA, ΔC, pΔA, conj(α), One(), ba...) + ΔAr = NoRData() + + # Δα + Δαr = if isnothing(Ap) + NoRData() + else + Mooncake._rdata(inner(Ap, ΔC)) + end + + # Δβ + Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) + Δβr = if Tdβ === NoRData + NoRData() + else + Mooncake._rdata(inner(C, ΔC)) + end + + + return NoRData(), ΔCr, ΔAr, NoRData(), Δαr, Δβr, map(Returns(NoRData()), ba)... + end + + return C_ΔC, $add_transform_pullback + end +end + +Mooncake.@is_primitive( + DefaultCtx, + ReverseMode, + Tuple{ + typeof(TK.add_braid!), + AbstractTensorMap, + AbstractTensorMap, Index2Tuple, IndexTuple, + Number, Number, Vararg{Any}, + } +) + +function Mooncake.rrule!!( + ::CoDual{typeof(TK.add_braid!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, + A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, levels_Δlevels::CoDual{<:IndexTuple}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, + ba_Δba::CoDual... + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + p = primal(p_Δp) + levels = primal(levels_Δlevels) + α, β = primal.((α_Δα, β_Δβ)) + ba = primal.(ba_Δba) + + C_cache = copy(C) + + # if we need to compute Δa, it is faster to allocate an intermediate braided A + # and store that instead of repeating the permutation in the pullback each time. + # effectively, we replace `add_permute` by `add ∘ permute`. + Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Ap = if Tdα === NoRData + TK.add_braid!(C, A, p, levels, α, β, ba...) + nothing + else + Ap = braid(A, p, levels) + add!(C, Ap, α, β) + Ap + end + + function add_braid!_pullback(::NoRData) + copy!(C, C_cache) + + scale!(ΔC, conj(β)) + ΔCr = NoRData() + + # ΔA + ip = invperm(linearize(p)) + pΔA = _repartition(ip, A) + ilevels = TupleTools.permute(levels, linearize(p)) + TK.add_braid!(ΔA, ΔC, pΔA, ilevels, conj(α), One(), ba...) + ΔAr = NoRData() + + # Δα + Δαr = if isnothing(Ap) + NoRData() + else + Mooncake._rdata(inner(Ap, ΔC)) + end + + # Δβ + Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) + Δβr = if Tdβ === NoRData + NoRData() + else + Mooncake._rdata(inner(C, ΔC)) + end + + + return NoRData(), ΔCr, ΔAr, NoRData(), NoRData(), Δαr, Δβr, map(Returns(NoRData()), ba)... + end + + return C_ΔC, add_braid!_pullback +end diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 7b9a674f8..915a10356 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -1,88 +1,88 @@ # tensoradd! # ---------- -Mooncake.@is_primitive( - DefaultCtx, - ReverseMode, - Tuple{ - typeof(TO.tensoradd!), - AbstractTensorMap, - AbstractTensorMap, Index2Tuple, Bool, - Number, Number, Vararg{Any}, - } -) - -function Mooncake.rrule!!( - ::CoDual{typeof(TO.tensoradd!)}, - C_ΔC::CoDual{<:AbstractTensorMap}, - A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, conjA_ΔconjA::CoDual{Bool}, - α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, - ba_Δba::CoDual... - ) - # prepare arguments - C, ΔC = arrayify(C_ΔC) - A, ΔA = arrayify(A_ΔA) - pA = primal(pA_ΔpA) - conjA = primal(conjA_ΔconjA) - α, β = primal.((α_Δα, β_Δβ)) - ba = primal.(ba_Δba) - - # primal call - C_cache = copy(C) - TO.tensoradd!(C, A, pA, conjA, α, β, ba...) - - function tensoradd_pullback(::NoRData) - copy!(C, C_cache) - - ΔCr = tensoradd_pullback_ΔC!(ΔC, β) - ΔAr = tensoradd_pullback_ΔA!(ΔA, ΔC, A, pA, conjA, α, ba...) - Δαr = tensoradd_pullback_Δα(ΔC, A, pA, conjA, α, ba...) - Δβr = tensoradd_pullback_Δβ(ΔC, C, β) - - return NoRData(), - ΔCr, - ΔAr, NoRData(), NoRData(), - Δαr, Δβr, - map(Returns(NoRData()), ba)... - end - - return C_ΔC, tensoradd_pullback -end - -tensoradd_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) - -function tensoradd_pullback_ΔA!( - ΔA, ΔC, A, pA, conjA, α, ba... - ) - ipA = invperm(linearize(pA)) - pΔA = _repartition(ipA, A) - TO.tensoradd!(ΔA, ΔC, pΔA, conjA, conjA ? α : conj(α), Zero(), ba...) - return NoRData() -end - -function tensoradd_pullback_Δα( - ΔC, A, pA, conjA, α, ba... - ) - Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) - Tdα === NoRData && return NoRData() - - tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)); copy = false) - Δα = TO.tensorscalar( - TO.tensorcontract( - A, ((), linearize(pA)), !conjA, - tΔC, (trivtuple(TO.numind(pA)), ()), false, - ((), ()), One(), ba... - ) - ) - return Mooncake._rdata(Δα) -end - -function tensoradd_pullback_Δβ(ΔC, C, β) - Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) - Tdβ === NoRData && return NoRData() - - Δβ = inner(C, ΔC) - return Mooncake._rdata(Δβ) -end +# Mooncake.@is_primitive( +# DefaultCtx, +# ReverseMode, +# Tuple{ +# typeof(TO.tensoradd!), +# AbstractTensorMap, +# AbstractTensorMap, Index2Tuple, Bool, +# Number, Number, Vararg{Any}, +# } +# ) +# +# function Mooncake.rrule!!( +# ::CoDual{typeof(TO.tensoradd!)}, +# C_ΔC::CoDual{<:AbstractTensorMap}, +# A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, conjA_ΔconjA::CoDual{Bool}, +# α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, +# ba_Δba::CoDual... +# ) +# # prepare arguments +# C, ΔC = arrayify(C_ΔC) +# A, ΔA = arrayify(A_ΔA) +# pA = primal(pA_ΔpA) +# conjA = primal(conjA_ΔconjA) +# α, β = primal.((α_Δα, β_Δβ)) +# ba = primal.(ba_Δba) +# +# # primal call +# C_cache = copy(C) +# TO.tensoradd!(C, A, pA, conjA, α, β, ba...) +# +# function tensoradd_pullback(::NoRData) +# copy!(C, C_cache) +# +# ΔCr = tensoradd_pullback_ΔC!(ΔC, β) +# ΔAr = tensoradd_pullback_ΔA!(ΔA, ΔC, A, pA, conjA, α, ba...) +# Δαr = tensoradd_pullback_Δα(ΔC, A, pA, conjA, α, ba...) +# Δβr = tensoradd_pullback_Δβ(ΔC, C, β) +# +# return NoRData(), +# ΔCr, +# ΔAr, NoRData(), NoRData(), +# Δαr, Δβr, +# map(Returns(NoRData()), ba)... +# end +# +# return C_ΔC, tensoradd_pullback +# end +# +# tensoradd_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) +# +# function tensoradd_pullback_ΔA!( +# ΔA, ΔC, A, pA, conjA, α, ba... +# ) +# ipA = invperm(linearize(pA)) +# pΔA = _repartition(ipA, A) +# TO.tensoradd!(ΔA, ΔC, pΔA, conjA, conjA ? α : conj(α), Zero(), ba...) +# return NoRData() +# end +# +# function tensoradd_pullback_Δα( +# ΔC, A, pA, conjA, α, ba... +# ) +# Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) +# Tdα === NoRData && return NoRData() +# +# tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)); copy = false) +# Δα = TO.tensorscalar( +# TO.tensorcontract( +# A, ((), linearize(pA)), !conjA, +# tΔC, (trivtuple(TO.numind(pA)), ()), false, +# ((), ()), One(), ba... +# ) +# ) +# return Mooncake._rdata(Δα) +# end +# +# function tensoradd_pullback_Δβ(ΔC, C, β) +# Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) +# Tdβ === NoRData && return NoRData() +# +# Δβ = inner(C, ΔC) +# return Mooncake._rdata(Δβ) +# end # tensorcontract! # --------------- diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index 38fa23c15..2ca21654e 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -3,6 +3,7 @@ using TensorKit using TensorOperations using Mooncake using Random +using TupleTools mode = Mooncake.ReverseMode rng = Random.default_rng() @@ -13,6 +14,14 @@ function randindextuple(N::Int, k::Int = rand(0:N)) _p = randperm(N) return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...)) end +function randcircshift(N₁::Int, N₂::Int, k::Int = rand(0:(N₁ + N₂))) + N = N₁ + N₂ + @assert 0 ≤ k ≤ N + p = TupleTools.vcat(ntuple(identity, N₁), reverse(ntuple(identity, N₂) .+ N₁)) + n = rand(0:N) + _p = TupleTools.circshift(p, n) + return (tuple(_p[1:k]...), reverse(tuple(_p[(k + 1):end]...))) +end const _repartition = @static if isdefined(Base, :get_extension) Base.get_extension(TensorKit, :TensorKitMooncakeExt)._repartition @@ -93,6 +102,54 @@ for V in spacelist Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, mode) end + @timedtestset "Index manipulations with scalartype $T" for T in eltypes + atol = precision(T) + rtol = precision(T) + + symmetricbraiding && @timedtestset "add_permute!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randindextuple(numind(A)) + C = randn!(permute(A, p)) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_permute!, C, A, p, α, β; atol, rtol, mode) + A = C + end + end + + @timedtestset "add_transpose!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randcircshift(numout(A), numin(A)) + C = randn!(transpose(A, p)) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, α, β; atol, rtol, mode) + A = C + end + end + + @timedtestset "add_braid!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randcircshift(numout(A), numin(A)) + levels = tuple(randperm(numind(A))) + C = randn!(transpose(A, p)) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, α, β; atol, rtol, mode) + A = C + end + end + end + symmetricbraiding && @timedtestset "TensorOperations with scalartype $T" for T in eltypes atol = precision(T) rtol = precision(T) @@ -107,10 +164,10 @@ for V in spacelist p = randindextuple(numind(A)) C1 = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C1, A, p, false, α, β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, tensoradd!, C1, A, p, false, α, β; atol, rtol, mode, is_primitive = false) C2 = randn!(TensorOperations.tensoralloc_add(T, A, p, true, Val(false))) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C2, A, p, true, α, β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, tensoradd!, C2, A, p, true, α, β; atol, rtol, mode, is_primitive = false) A = rand(Bool) ? C1 : C2 end From c31996a65f2bdade06130d8f494560b8dd495500 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 20 Jan 2026 12:05:46 -0500 Subject: [PATCH 05/54] add mul rules --- ext/TensorKitMooncakeExt/linalg.jl | 39 ++++++++++++++++++++++++++++++ test/autodiff/mooncake.jl | 18 ++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index 56533d227..d0d73d951 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -1,3 +1,42 @@ +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(mul!), AbstractTensorMap, AbstractTensorMap, AbstractTensorMap, Number, Number} + +function Mooncake.rrule!!( + ::CoDual{typeof(mul!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, B_ΔB::CoDual{<:AbstractTensorMap}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number} + ) + (C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB)) + α, β = primal.((α_Δα, β_Δβ)) + + # primal call + C_cache = copy(C) + AB = if _needs_tangent(α) + AB = A * B + add!(C, AB, α, β) + AB + else + mul!(C, A, B, α, β) + nothing + end + + function mul_pullback(::NoRData) + copy!(C, C_cache) + + scale!(ΔC, conj(β)) + mul!(ΔA, ΔC, B', conj(α), One()) + mul!(ΔB, A', ΔC, conj(α), One()) + ΔCr = NoRData() + ΔAr = NoRData() + ΔBr = NoRData() + Δαr = isnothing(AB) ? NoRData() : Mooncake._rdata(inner(AB, ΔC)) + Δβr = _needs_tangent(β) ? Mooncake._rdata(inner(C, ΔC)) : NoRData() + + return NoRData(), ΔCr, ΔAr, ΔBr, Δαr, Δβr + end + + return C_ΔC, mul_pullback +end + Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(norm), AbstractTensorMap, Real} function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorMap}, pdp::CoDual{<:Real}) diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index 2ca21654e..4df18a331 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -102,6 +102,24 @@ for V in spacelist Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, mode) end + @timedtestset "LinearAlgebra with scalartype $T" for T in eltypes + atol = precision(T) + rtol = precision(T) + + C = randn(T, V[1] ⊗ V[2] ← V[5]) + A = randn(T, codomain(C) ← V[3] ⊗ V[4]) + B = randn(T, domain(A) ← domain(C)) + α = randn(T) + β = randn(T) + + Mooncake.TestUtils.test_rule(rng, mul!, C, A, B, α, β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, mul!, C, A, B; atol, rtol, mode, is_primitive = false) + + Mooncake.TestUtils.test_rule(rng, norm, C, 2; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, norm, C', 2; atol, rtol, mode) + end + + @timedtestset "Index manipulations with scalartype $T" for T in eltypes atol = precision(T) rtol = precision(T) From ec0fe094119a00b73d24bc655d88e0e3c931c077 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 20 Jan 2026 14:56:06 -0500 Subject: [PATCH 06/54] temporarily disable Fibonacci (complex) spaces --- test/autodiff/mooncake.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index 4df18a331..3ca512f56 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -59,13 +59,13 @@ spacelist = ( Vect[SU2Irrep](1 // 2 => 2), Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', ), - ( - Vect[FibonacciAnyon](:I => 2, :τ => 1), - Vect[FibonacciAnyon](:I => 1, :τ => 2)', - Vect[FibonacciAnyon](:I => 2, :τ => 2)', - Vect[FibonacciAnyon](:I => 2, :τ => 3), - Vect[FibonacciAnyon](:I => 2, :τ => 2), - ), + # ( + # Vect[FibonacciAnyon](:I => 2, :τ => 1), + # Vect[FibonacciAnyon](:I => 1, :τ => 2)', + # Vect[FibonacciAnyon](:I => 2, :τ => 2)', + # Vect[FibonacciAnyon](:I => 2, :τ => 3), + # Vect[FibonacciAnyon](:I => 2, :τ => 2), + # ), ) for V in spacelist From 4642ff799547addf0d1dff14c6833088d537f247 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 20 Jan 2026 17:10:35 -0500 Subject: [PATCH 07/54] bump TupleTools compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 22a4915d7..8b4980ecc 100644 --- a/Project.toml +++ b/Project.toml @@ -56,7 +56,7 @@ TensorKitSectors = "0.3.5" TensorOperations = "5.1" Test = "1" TestExtras = "0.2,0.3" -TupleTools = "1.1" +TupleTools = "1.5" VectorInterface = "0.4.8, 0.5" Zygote = "0.7" cuTENSOR = "2" From 2f98e0d3f1e1601735f9dd97090aad7e66f3e8b0 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 21 Jan 2026 14:19:11 -0500 Subject: [PATCH 08/54] add twist! rule --- .../indexmanipulations.jl | 45 +++++++++++++++++++ test/autodiff/mooncake.jl | 8 ++++ 2 files changed, 53 insertions(+) diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index a0b73dde2..000ae5d83 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -151,3 +151,48 @@ function Mooncake.rrule!!( return C_ΔC, add_braid!_pullback end + +# both are needed for correctly capturing every dispatch +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(twist!), AbstractTensorMap, Any} +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), @NamedTuple{inv::Bool}, typeof(twist!), AbstractTensorMap, Any} + +function Mooncake.rrule!!(::CoDual{typeof(twist!)}, t_Δt::CoDual{<:AbstractTensorMap}, inds_Δinds::CoDual) + # prepare arguments + t, Δt = arrayify(t_Δt) + inv = false + inds = primal(inds_Δinds) + + # primal call + t_cache = copy(t) + twist!(t, inds; inv) + + function twist_pullback(::NoRData) + copy!(t, t_cache) + twist!(Δt, inds; inv = !inv) + return ntuple(Returns(NoRData()), 3) + end + + return t_Δt, twist_pullback + +end +function Mooncake.rrule!!( + ::CoDual{typeof(Core.kwcall)}, kwargs_Δkwargs::CoDual{@NamedTuple{inv::Bool}}, ::CoDual{typeof(twist!)}, + t_Δt::CoDual{<:AbstractTensorMap}, inds_Δinds::CoDual + ) + # prepare arguments + t, Δt = arrayify(t_Δt) + inv = primal(kwargs_Δkwargs).inv + inds = primal(inds_Δinds) + + # primal call + t_cache = copy(t) + twist!(t, inds; inv) + + function twist_pullback(::NoRData) + copy!(t, t_cache) + twist!(Δt, inds; inv = !inv) + return ntuple(Returns(NoRData()), 5) + end + + return t_Δt, twist_pullback +end diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index 3ca512f56..85b251885 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -166,6 +166,14 @@ for V in spacelist A = C end end + + @timedtestset "twist!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = false), twist!, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = true), twist!, A, [1, 3]; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, twist!, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, twist!, A, [1, 3]; atol, rtol, mode) + end end symmetricbraiding && @timedtestset "TensorOperations with scalartype $T" for T in eltypes From 10a50cb45f4e9152f3d1908fc20a2e4d85236922 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 21 Jan 2026 14:50:48 -0500 Subject: [PATCH 09/54] add flip rule --- .../indexmanipulations.jl | 44 +++++++++++++++++++ test/autodiff/mooncake.jl | 7 ++- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index 000ae5d83..9e98023e2 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -196,3 +196,47 @@ function Mooncake.rrule!!( return t_Δt, twist_pullback end + +# both are needed for correctly capturing every dispatch +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(flip), AbstractTensorMap, Any} +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), @NamedTuple{inv::Bool}, typeof(flip), AbstractTensorMap, Any} + +function Mooncake.rrule!!(::CoDual{typeof(flip)}, t_Δt::CoDual{<:AbstractTensorMap}, inds_Δinds::CoDual) + # prepare arguments + t, Δt = arrayify(t_Δt) + inv = false + inds = primal(inds_Δinds) + + # primal call + t_flipped = flip(t, inds; inv) + t_flipped_Δt_flipped = Mooncake.zero_fcodual(t_flipped) + _, Δt_flipped = arrayify(t_flipped_Δt_flipped) + + function twist_pullback(::NoRData) + copy!(Δt, flip(Δt_flipped, inds; inv = !inv)) + return ntuple(Returns(NoRData()), 3) + end + + return t_flipped_Δt_flipped, twist_pullback +end +function Mooncake.rrule!!( + ::CoDual{typeof(Core.kwcall)}, kwargs_Δkwargs::CoDual{@NamedTuple{inv::Bool}}, ::CoDual{typeof(flip)}, + t_Δt::CoDual{<:AbstractTensorMap}, inds_Δinds::CoDual + ) + # prepare arguments + t, Δt = arrayify(t_Δt) + inv = primal(kwargs_Δkwargs).inv + inds = primal(inds_Δinds) + + # primal call + t_flipped = flip(t, inds; inv) + t_flipped_Δt_flipped = Mooncake.zero_fcodual(t_flipped) + _, Δt_flipped = arrayify(t_flipped_Δt_flipped) + + function twist_pullback(::NoRData) + copy!(Δt, flip(Δt_flipped, inds; inv = !inv)) + return ntuple(Returns(NoRData()), 5) + end + + return t_flipped_Δt_flipped, twist_pullback +end diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index 85b251885..ace67dae7 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -167,12 +167,17 @@ for V in spacelist end end - @timedtestset "twist!" begin + @timedtestset "flip_n_twist!" begin A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = false), twist!, A, 1; atol, rtol, mode) Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = true), twist!, A, [1, 3]; atol, rtol, mode) Mooncake.TestUtils.test_rule(rng, twist!, A, 1; atol, rtol, mode) Mooncake.TestUtils.test_rule(rng, twist!, A, [1, 3]; atol, rtol, mode) + + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = false), flip, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = true), flip, A, [1, 3]; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, flip, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, flip, A, [1, 3]; atol, rtol, mode) end end From d7c050a1fe615fb468d2506db8424d5ee948565e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 21 Jan 2026 15:37:21 -0500 Subject: [PATCH 10/54] vector spaces arent vector spaces! --- ext/TensorKitMooncakeExt/utility.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index ca2c79b54..f45aaf3bc 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -25,4 +25,8 @@ end # Ignore derivatives # ------------------ + +# A VectorSpace has no meaningful notion of a vector space (tangent space) +Mooncake.tangent_type(::Type{<:VectorSpace}) = Mooncake.NoTangent + @zero_derivative DefaultCtx Tuple{typeof(TensorKit.fusionblockstructure), Any} From c1a1e8b8a4cb2a1b979ea8de83d368634efd7e37 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 21 Jan 2026 15:50:17 -0500 Subject: [PATCH 11/54] insert and remove units --- .../indexmanipulations.jl | 167 ++++++++++++++++++ test/autodiff/mooncake.jl | 21 +++ 2 files changed, 188 insertions(+) diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index 9e98023e2..464c18392 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -240,3 +240,170 @@ function Mooncake.rrule!!( return t_flipped_Δt_flipped, twist_pullback end + +for insertunit in (:insertleftunit, :insertrightunit) + insertunit_pullback = Symbol(insertunit, :_pullback) + @eval begin + # both are needed for correctly capturing every dispatch + Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof($insertunit), AbstractTensorMap, Val} + Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), NamedTuple, typeof($insertunit), AbstractTensorMap, Val} + + function Mooncake.rrule!!(::CoDual{typeof($insertunit)}, tsrc_Δtsrc::CoDual{<:AbstractTensorMap}, ival_Δival::CoDual{<:Val}) + # prepare arguments + tsrc, Δtsrc = arrayify(tsrc_Δtsrc) + ival = primal(ival_Δival) + + # tdst shares data with tsrc if <:TensorMap, in this case we have to deal with correctly + # sharing address spaces + if tsrc isa TensorMap + tsrc_cache = copy(tsrc) + tdst = $insertunit(tsrc, ival) + # note: this is somewhat of a hack that makes use of the fact that the tangent is + # encoded without any information about the space, which allows us to simply reuse + # the tangent exactly without having to modify the space information + tdst_Δtdst = CoDual(tdst, Mooncake.tangent(tsrc_Δtsrc)) + else + tsrc_cache = nothing + tdst = $insertunit(tsrc, ival) + tdst_Δtdst = Mooncake.zero_fcodual(tdst) + end + + _, Δtdst = arrayify(tdst_Δtdst) + + function $insertunit_pullback(::NoRData) + if isnothing(tsrc_cache) + for (c, b) in blocks(Δtdst) + copy!(block(Δtsrc, c), b) + end + else + copy!(tsrc, tsrc_cache) + # note: since data is already shared, don't have to do anything here! + end + return ntuple(Returns(NoRData()), 3) + end + + return tdst_Δtdst, $insertunit_pullback + end + function Mooncake.rrule!!( + ::CoDual{typeof(Core.kwcall)}, kwargs_Δkwargs::CoDual{<:NamedTuple}, + ::CoDual{typeof($insertunit)}, tsrc_Δtsrc::CoDual{<:AbstractTensorMap}, ival_Δival::CoDual{<:Val} + ) + # prepare arguments + tsrc, Δtsrc = arrayify(tsrc_Δtsrc) + ival = primal(ival_Δival) + kwargs = primal(kwargs_Δkwargs) + + # tdst shares data with tsrc if <:TensorMap & copy=false, in this case we have to deal with correctly + # sharing address spaces + if tsrc isa TensorMap && !get(kwargs, :copy, false) + tsrc_cache = copy(tsrc) + tdst = $insertunit(tsrc, ival; kwargs...) + # note: this is somewhat of a hack that makes use of the fact that the tangent is + # encoded without any information about the space, which allows us to simply reuse + # the tangent exactly without having to modify the space information + tdst_Δtdst = CoDual(tdst, Mooncake.tangent(tsrc_Δtsrc)) + else + tsrc_cache = nothing + tdst = $insertunit(tsrc, ival; kwargs...) + tdst_Δtdst = Mooncake.zero_fcodual(tdst) + end + + _, Δtdst = arrayify(tdst_Δtdst) + + function $insertunit_pullback(::NoRData) + if isnothing(tsrc_cache) + for (c, b) in blocks(Δtdst) + copy!(block(Δtsrc, c), b) + end + else + copy!(tsrc, tsrc_cache) + # note: since data is already shared, don't have to do anything here! + end + return ntuple(Returns(NoRData()), 5) + end + + return tdst_Δtdst, $insertunit_pullback + end + end +end + + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(removeunit), AbstractTensorMap, Val} +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), NamedTuple, typeof(removeunit), AbstractTensorMap, Val} + +function Mooncake.rrule!!(::CoDual{typeof(removeunit)}, tsrc_Δtsrc::CoDual{<:AbstractTensorMap}, ival_Δival::CoDual{Val{i}}) where {i} + # prepare arguments + tsrc, Δtsrc = arrayify(tsrc_Δtsrc) + ival = primal(ival_Δival) + + # tdst shares data with tsrc if <:TensorMap, in this case we have to deal with correctly + # sharing address spaces + if tsrc isa TensorMap + tsrc_cache = copy(tsrc) + tdst = removeunit(tsrc, ival) + # note: this is somewhat of a hack that makes use of the fact that the tangent is + # encoded without any information about the space, which allows us to simply reuse + # the tangent exactly without having to modify the space information + tdst_Δtdst = CoDual(tdst, Mooncake.tangent(tsrc_Δtsrc)) + else + tsrc_cache = nothing + tdst = removeunit(tsrc, ival) + tdst_Δtdst = Mooncake.zero_fcodual(tdst) + end + + _, Δtdst = arrayify(tdst_Δtdst) + + function removeunit_pullback(::NoRData) + if isnothing(tsrc_cache) + for (c, b) in blocks(Δtdst) + copy!(block(Δtsrc, c), b) + end + else + copy!(tsrc, tsrc_cache) + # note: since data is already shared, don't have to do anything here! + end + return ntuple(Returns(NoRData()), 3) + end + + return tdst_Δtdst, removeunit_pullback +end +function Mooncake.rrule!!( + ::CoDual{typeof(Core.kwcall)}, kwargs_Δkwargs::CoDual{<:NamedTuple}, + ::CoDual{typeof(removeunit)}, tsrc_Δtsrc::CoDual{<:AbstractTensorMap}, ival_Δival::CoDual{<:Val} + ) + # prepare arguments + tsrc, Δtsrc = arrayify(tsrc_Δtsrc) + ival = primal(ival_Δival) + kwargs = primal(kwargs_Δkwargs) + + # tdst shares data with tsrc if <:TensorMap & copy=false, in this case we have to deal with correctly + # sharing address spaces + if tsrc isa TensorMap && !get(kwargs, :copy, false) + tsrc_cache = copy(tsrc) + tdst = removeunit(tsrc, ival; kwargs...) + # note: this is somewhat of a hack that makes use of the fact that the tangent is + # encoded without any information about the space, which allows us to simply reuse + # the tangent exactly without having to modify the space information + tdst_Δtdst = CoDual(tdst, Mooncake.tangent(tsrc_Δtsrc)) + else + tsrc_cache = nothing + tdst = removeunit(tsrc, ival; kwargs...) + tdst_Δtdst = Mooncake.zero_fcodual(tdst) + end + + _, Δtdst = arrayify(tdst_Δtdst) + + function removeunit_pullback(::NoRData) + if isnothing(tsrc_cache) + for (c, b) in blocks(Δtdst) + copy!(block(Δtsrc, c), b) + end + else + copy!(tsrc, tsrc_cache) + # note: since data is already shared, don't have to do anything here! + end + return ntuple(Returns(NoRData()), 5) + end + + return tdst_Δtdst, removeunit_pullback +end diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index ace67dae7..a5b08fc90 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -179,6 +179,27 @@ for V in spacelist Mooncake.TestUtils.test_rule(rng, flip, A, 1; atol, rtol, mode) Mooncake.TestUtils.test_rule(rng, flip, A, [1, 3]; atol, rtol, mode) end + + @timedtestset "insert and remove units" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + + for insertunit in (insertleftunit, insertrightunit) + Mooncake.TestUtils.test_rule(rng, insertunit, A, Val(1); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, insertunit, A, Val(4); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, insertunit, A', Val(2); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false), insertunit, A, Val(1); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = true), insertunit, A, Val(2); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false, dual = true, conj = true), insertunit, A, Val(3); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false, dual = true, conj = true), insertunit, A', Val(3); atol, rtol, mode) + end + + for i in 1:4 + B = insertleftunit(A, i; dual = rand(Bool)) + Mooncake.TestUtils.test_rule(rng, removeunit, B, Val(i); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false), removeunit, B, Val(i); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = true), removeunit, B, Val(i); atol, rtol, mode) + end + end end symmetricbraiding && @timedtestset "TensorOperations with scalartype $T" for T in eltypes From fac47edf2f9e01e69822e8a02f1d3f4bd29c47ba Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 21 Jan 2026 17:08:30 -0500 Subject: [PATCH 12/54] mark a bunch of things as non-differentiable --- ext/TensorKitMooncakeExt/utility.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index f45aaf3bc..e93de22be 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -30,3 +30,10 @@ end Mooncake.tangent_type(::Type{<:VectorSpace}) = Mooncake.NoTangent @zero_derivative DefaultCtx Tuple{typeof(TensorKit.fusionblockstructure), Any} + +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.select), HomSpace, Index2Tuple} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.flip), HomSpace, Any} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.permute), HomSpace, Index2Tuple} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.braid), HomSpace, Index2Tuple, IndexTuple} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.compose), HomSpace, HomSpace} +@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorcontract), HomSpace, Index2Tuple, Bool, HomSpace, Index2Tuple, Bool, Index2Tuple} From 9b90eb6ff89beb9295042b94e69b628964cdc608 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 21 Jan 2026 17:08:53 -0500 Subject: [PATCH 13/54] rewrite rule for `tensortrace!` in terms of `trace_permute!` --- ext/TensorKitMooncakeExt/tensoroperations.jl | 53 +++++++++----------- test/autodiff/mooncake.jl | 16 +++--- 2 files changed, 32 insertions(+), 37 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 915a10356..989ee2830 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -230,83 +230,80 @@ Mooncake.@is_primitive( DefaultCtx, ReverseMode, Tuple{ - typeof(TO.tensortrace!), + typeof(TensorKit.trace_permute!), AbstractTensorMap, - AbstractTensorMap, Index2Tuple, Index2Tuple, Bool, + AbstractTensorMap, Index2Tuple, Index2Tuple, Number, Number, - Vararg{Any}, + Any, } ) function Mooncake.rrule!!( - ::CoDual{typeof(TO.tensortrace!)}, + ::CoDual{typeof(TensorKit.trace_permute!)}, C_ΔC::CoDual{<:AbstractTensorMap}, - A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, q_Δq::CoDual{<:Index2Tuple}, conjA_ΔconjA::CoDual{Bool}, + A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, q_Δq::CoDual{<:Index2Tuple}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, - ba_Δba::CoDual... + backend_Δbackend::CoDual ) # prepare arguments C, ΔC = arrayify(C_ΔC) A, ΔA = arrayify(A_ΔA) p = primal(p_Δp) q = primal(q_Δq) - conjA = primal(conjA_ΔconjA) α, β = primal.((α_Δα, β_Δβ)) - ba = primal.(ba_Δba) + backend = primal(backend_Δbackend) # primal call C_cache = copy(C) - TO.tensortrace!(C, A, p, q, conjA, α, β, ba...) + TensorKit.trace_permute!(C, A, p, q, α, β, backend) - function tensortrace_pullback(::NoRData) + function trace_permute_pullback(::NoRData) copy!(C, C_cache) - ΔCr = tensortrace_pullback_ΔC!(ΔC, β) - ΔAr = tensortrace_pullback_ΔA!(ΔA, ΔC, A, p, q, conjA, α, ba...) - Δαr = tensortrace_pullback_Δα(ΔC, A, p, q, conjA, α, ba...) - Δβr = tensortrace_pullback_Δβ(ΔC, C, β) + ΔAr = trace_permute_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend) + Δαr = trace_permute_pullback_Δα(ΔC, A, p, q, α, backend) + Δβr = trace_permute_pullback_Δβ(ΔC, C, β) + ΔCr = trace_permute_pullback_ΔC!(ΔC, β) return NoRData(), - ΔCr, - ΔAr, NoRData(), NoRData(), NoRData(), - Δαr, Δβr, - map(Returns(NoRData()), ba)... + ΔCr, ΔAr, NoRData(), NoRData(), + Δαr, Δβr, NoRData() end - return C_ΔC, tensortrace_pullback + return C_ΔC, trace_permute_pullback end -tensortrace_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) +trace_permute_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) -function tensortrace_pullback_ΔA!( - ΔA, ΔC, A, p, q, conjA, α, ba... +function trace_permute_pullback_ΔA!( + ΔA, ΔC, A, p, q, α, backend ) ip = invperm((linearize(p)..., q[1]..., q[2]...)) pdA = _repartition(ip, A) - E = one!(TO.tensoralloc_add(scalartype(A), A, q, conjA)) + E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) pE = ((), trivtuple(TO.numind(q))) pΔC = (trivtuple(TO.numind(p)), ()) TO.tensorproduct!( - ΔA, ΔC, pΔC, conjA, E, pE, conjA, pdA, conjA ? α : conj(α), Zero(), ba... + ΔA, ΔC, pΔC, false, E, pE, false, pdA, conj(α), One(), backend ) return NoRData() end -function tensortrace_pullback_Δα( - ΔC, A, p, q, conjA, α, ba... +function trace_permute_pullback_Δα( + ΔC, A, p, q, α, backend ) Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) Tdα === NoRData && return NoRData() # TODO: this result might be easier to compute as: # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α - At = TO.tensortrace(A, p, q, conjA) + At = TO.tensortrace(A, p, q, false, One(), backend) Δα = inner(At, ΔC) return Mooncake._rdata(Δα) end -function tensortrace_pullback_Δβ(ΔC, C, β) +function trace_permute_pullback_Δβ(ΔC, C, β) Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) Tdβ === NoRData && return NoRData() diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index a5b08fc90..cca2c92d3 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -260,14 +260,14 @@ for V in spacelist ) Mooncake.TestUtils.test_rule( rng, tensorcontract!, C, A, pA, conjA, B, pB, conjB, pAB, α, β; - atol, rtol, mode, is_primitive + atol, rtol, mode ) end end end - @timedtestset "tensortrace!" begin + @timedtestset "trace_permute!" begin for _ in 1:5 k1 = rand(0:2) k2 = rand(1:2) @@ -282,13 +282,11 @@ for V in spacelist α = randn(T) β = randn(T) - for conjA in (false, true) - C = randn!(TensorOperations.tensoralloc_add(T, A, p, conjA, Val(false))) - Mooncake.TestUtils.test_rule( - rng, tensortrace!, C, A, p, q, conjA, α, β; - atol, rtol, mode, is_primitive = false - ) - end + C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + Mooncake.TestUtils.test_rule( + rng, TensorKit.trace_permute!, C, A, p, q, α, β, TensorOperations.DefaultBackend(); + atol, rtol, mode + ) end end end From bb287fe85caf02d57c33fb69e6bcb1cea24b581f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 21 Jan 2026 17:10:38 -0500 Subject: [PATCH 14/54] dont need rules for `tensoradd!` --- ext/TensorKitMooncakeExt/tensoroperations.jl | 86 -------------------- test/autodiff/mooncake.jl | 19 ----- 2 files changed, 105 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 989ee2830..7b979d4cf 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -1,89 +1,3 @@ -# tensoradd! -# ---------- -# Mooncake.@is_primitive( -# DefaultCtx, -# ReverseMode, -# Tuple{ -# typeof(TO.tensoradd!), -# AbstractTensorMap, -# AbstractTensorMap, Index2Tuple, Bool, -# Number, Number, Vararg{Any}, -# } -# ) -# -# function Mooncake.rrule!!( -# ::CoDual{typeof(TO.tensoradd!)}, -# C_ΔC::CoDual{<:AbstractTensorMap}, -# A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, conjA_ΔconjA::CoDual{Bool}, -# α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, -# ba_Δba::CoDual... -# ) -# # prepare arguments -# C, ΔC = arrayify(C_ΔC) -# A, ΔA = arrayify(A_ΔA) -# pA = primal(pA_ΔpA) -# conjA = primal(conjA_ΔconjA) -# α, β = primal.((α_Δα, β_Δβ)) -# ba = primal.(ba_Δba) -# -# # primal call -# C_cache = copy(C) -# TO.tensoradd!(C, A, pA, conjA, α, β, ba...) -# -# function tensoradd_pullback(::NoRData) -# copy!(C, C_cache) -# -# ΔCr = tensoradd_pullback_ΔC!(ΔC, β) -# ΔAr = tensoradd_pullback_ΔA!(ΔA, ΔC, A, pA, conjA, α, ba...) -# Δαr = tensoradd_pullback_Δα(ΔC, A, pA, conjA, α, ba...) -# Δβr = tensoradd_pullback_Δβ(ΔC, C, β) -# -# return NoRData(), -# ΔCr, -# ΔAr, NoRData(), NoRData(), -# Δαr, Δβr, -# map(Returns(NoRData()), ba)... -# end -# -# return C_ΔC, tensoradd_pullback -# end -# -# tensoradd_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) -# -# function tensoradd_pullback_ΔA!( -# ΔA, ΔC, A, pA, conjA, α, ba... -# ) -# ipA = invperm(linearize(pA)) -# pΔA = _repartition(ipA, A) -# TO.tensoradd!(ΔA, ΔC, pΔA, conjA, conjA ? α : conj(α), Zero(), ba...) -# return NoRData() -# end -# -# function tensoradd_pullback_Δα( -# ΔC, A, pA, conjA, α, ba... -# ) -# Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) -# Tdα === NoRData && return NoRData() -# -# tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)); copy = false) -# Δα = TO.tensorscalar( -# TO.tensorcontract( -# A, ((), linearize(pA)), !conjA, -# tΔC, (trivtuple(TO.numind(pA)), ()), false, -# ((), ()), One(), ba... -# ) -# ) -# return Mooncake._rdata(Δα) -# end -# -# function tensoradd_pullback_Δβ(ΔC, C, β) -# Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) -# Tdβ === NoRData && return NoRData() -# -# Δβ = inner(C, ΔC) -# return Mooncake._rdata(Δβ) -# end - # tensorcontract! # --------------- Mooncake.@is_primitive( diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index cca2c92d3..066a3585f 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -206,25 +206,6 @@ for V in spacelist atol = precision(T) rtol = precision(T) - @timedtestset "tensoradd!" begin - A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) - α = randn(T) - β = randn(T) - - # repeat a couple times to get some distribution of arrows - for _ in 1:5 - p = randindextuple(numind(A)) - - C1 = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C1, A, p, false, α, β; atol, rtol, mode, is_primitive = false) - - C2 = randn!(TensorOperations.tensoralloc_add(T, A, p, true, Val(false))) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C2, A, p, true, α, β; atol, rtol, mode, is_primitive = false) - - A = rand(Bool) ? C1 : C2 - end - end - @timedtestset "tensorcontract!" begin for _ in 1:5 d = 0 From a8f8a20573b8520bc2debcd41d577b0b7e8b40cb Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 21 Jan 2026 19:52:14 -0500 Subject: [PATCH 15/54] add planaroperations --- .../TensorKitMooncakeExt.jl | 1 + ext/TensorKitMooncakeExt/planaroperations.jl | 88 +++++++++++++++++++ test/autodiff/mooncake.jl | 73 +++++++++++++++ 3 files changed, 162 insertions(+) create mode 100644 ext/TensorKitMooncakeExt/planaroperations.jl diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index 15e0c4c9f..4c692adb9 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -15,5 +15,6 @@ include("linalg.jl") include("indexmanipulations.jl") include("vectorinterface.jl") include("tensoroperations.jl") +include("planaroperations.jl") end diff --git a/ext/TensorKitMooncakeExt/planaroperations.jl b/ext/TensorKitMooncakeExt/planaroperations.jl new file mode 100644 index 000000000..a480293af --- /dev/null +++ b/ext/TensorKitMooncakeExt/planaroperations.jl @@ -0,0 +1,88 @@ +# planartrace! +# ------------ +Mooncake.@is_primitive( + DefaultCtx, + ReverseMode, + Tuple{ + typeof(TensorKit.planartrace!), + AbstractTensorMap, + AbstractTensorMap, Index2Tuple, Index2Tuple, + Number, Number, + Any, Any, + } +) + +function Mooncake.rrule!!( + ::CoDual{typeof(TensorKit.planartrace!)}, + C_ΔC::CoDual{<:AbstractTensorMap}, + A_ΔA::CoDual{<:AbstractTensorMap}, p_Δp::CoDual{<:Index2Tuple}, q_Δq::CoDual{<:Index2Tuple}, + α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, + backend_Δbackend::CoDual, allocator_Δallocator::CoDual + ) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + p = primal(p_Δp) + q = primal(q_Δq) + α, β = primal.((α_Δα, β_Δβ)) + backend, allocator = primal.((backend_Δbackend, allocator_Δallocator)) + + # primal call + C_cache = copy(C) + TensorKit.planartrace!(C, A, p, q, α, β, backend, allocator) + + function planartrace_pullback(::NoRData) + copy!(C, C_cache) + + ΔAr = planartrace_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend, allocator) + Δαr = planartrace_pullback_Δα(ΔC, A, p, q, α, backend, allocator) + Δβr = planartrace_pullback_Δβ(ΔC, C, β) + ΔCr = planartrace_pullback_ΔC!(ΔC, β) + + return NoRData(), + ΔCr, ΔAr, NoRData(), NoRData(), + Δαr, Δβr, NoRData(), NoRData() + end + + return C_ΔC, planartrace_pullback +end + +planartrace_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) + +function planartrace_pullback_ΔA!( + ΔA, ΔC, A, p, q, α, backend, allocator + ) + ip = invperm((linearize(p)..., q[1]..., q[2]...)) + pdA = _repartition(ip, A) + E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) + twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + pE = ((), trivtuple(TO.numind(q))) + pΔC = (trivtuple(TO.numind(p)), ()) + TensorKit.planarcontract!( + ΔA, ΔC, pΔC, E, pE, pdA, conj(α), One(), backend, allocator + ) + return NoRData() +end + +function planartrace_pullback_Δα( + ΔC, A, p, q, α, backend, allocator + ) + Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) + Tdα === NoRData && return NoRData() + + # TODO: this result might be easier to compute as: + # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α + At = TO.tensoralloc_add(scalartype(A), A, p, false, Val(true), allocator) + TensorKit.planartrace!(At, A, p, q, false, One(), backend, allocator) + Δα = inner(At, ΔC) + TO.tensorfree!(At, allocator) + return Mooncake._rdata(Δα) +end + +function planartrace_pullback_Δβ(ΔC, C, β) + Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) + Tdβ === NoRData && return NoRData() + + Δβ = inner(C, ΔC) + return Mooncake._rdata(Δβ) +end diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index 066a3585f..14aaea251 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -271,4 +271,77 @@ for V in spacelist end end end + + @timedtestset "PlanarOperations with scalartype $T" for T in eltypes + atol = precision(T) + rtol = precision(T) + + @timedtestset "planarcontract!" begin + for _ in 1:5 + d = 0 + local V1, V2, V3, k1, k2, k3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) + V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) + d > 1 && break + end + k′ = rand(0:(k1 + k2)) + pA = randcircshift(k′, k1 + k2 - k′, k1) + ipA = _repartition(invperm(linearize(pA)), k′) + k′ = rand(0:(k2 + k3)) + pB = randcircshift(k′, k2 + k3 - k′, k2) + ipB = _repartition(invperm(linearize(pB)), k′) + # TODO: primal value already is broken for this? + # pAB = randcircshift(k1, k3) + pAB = _repartition(tuple((1:(k1 + k3))...), k1) + + α = randn(T) + β = randn(T) + + A = randn(T, permute(V1 ← V2, ipA)) + B = randn(T, permute(V2 ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, false, B, pB, false, pAB, Val(false) + ) + ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.planarcontract!, C, A, pA, B, pB, pAB, α, β; + atol, rtol, mode, is_primitive = false + ) + end + end + + @timedtestset "planartrace!" begin + for _ in 1:5 + k1 = rand(0:2) + k2 = rand(1:2) + V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) + + k′ = rand(0:(k1 + 2k2)) + (_p, _q) = randcircshift(k′, k1 + 2 * k2 - k′, k1) + p = _repartition(_p, rand(0:k1)) + q = _repartition(_q, k2) + ip = _repartition(invperm(linearize((_p, _q))), k′) + A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) + + α = randn(T) + β = randn(T) + C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + Mooncake.TestUtils.test_rule( + rng, TensorKit.planartrace!, + C, A, p, q, α, β, + TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + atol, rtol, mode + ) + end + end + end end From 8e1993eeb1119c87bf69a8256a1ced860078f25d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 21 Jan 2026 19:52:39 -0500 Subject: [PATCH 16/54] rewrite rule `tensorcontract` in terms of `blas_contract!` --- ext/TensorKitMooncakeExt/tensoroperations.jl | 83 +++++++++----------- test/autodiff/mooncake.jl | 25 +++--- 2 files changed, 51 insertions(+), 57 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 7b979d4cf..59a398e27 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -4,72 +4,69 @@ Mooncake.@is_primitive( DefaultCtx, ReverseMode, Tuple{ - typeof(TO.tensorcontract!), + typeof(TensorKit.blas_contract!), AbstractTensorMap, - AbstractTensorMap, Index2Tuple, Bool, - AbstractTensorMap, Index2Tuple, Bool, + AbstractTensorMap, Index2Tuple, + AbstractTensorMap, Index2Tuple, Index2Tuple, Number, Number, - Vararg{Any}, + Any, Any, } ) function Mooncake.rrule!!( - ::CoDual{typeof(TO.tensorcontract!)}, + ::CoDual{typeof(TensorKit.blas_contract!)}, C_ΔC::CoDual{<:AbstractTensorMap}, - A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, conjA_ΔconjA::CoDual{Bool}, - B_ΔB::CoDual{<:AbstractTensorMap}, pB_ΔpB::CoDual{<:Index2Tuple}, conjB_ΔconjB::CoDual{Bool}, + A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, + B_ΔB::CoDual{<:AbstractTensorMap}, pB_ΔpB::CoDual{<:Index2Tuple}, pAB_ΔpAB::CoDual{<:Index2Tuple}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}, - ba_Δba::CoDual..., + backend_Δbackend::CoDual, allocator_Δallocator::CoDual ) # prepare arguments (C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB)) pA, pB, pAB = primal.((pA_ΔpA, pB_ΔpB, pAB_ΔpAB)) - conjA, conjB = primal.((conjA_ΔconjA, conjB_ΔconjB)) α, β = primal.((α_Δα, β_Δβ)) - ba = primal.(ba_Δba) + backend, allocator = primal.((backend_Δbackend, allocator_Δallocator)) # primal call C_cache = copy(C) - TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) + TensorKit.blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator) - function tensorcontract_pullback(::NoRData) + function blas_contract_pullback(::NoRData) copy!(C, C_cache) - ΔCr = tensorcontract_pullback_ΔC!(ΔC, β) - ΔAr = tensorcontract_pullback_ΔA!( - ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... + ΔAr = blas_contract_pullback_ΔA!( + ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) - ΔBr = tensorcontract_pullback_ΔB!( - ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... + ΔBr = blas_contract_pullback_ΔB!( + ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) - Δαr = tensorcontract_pullback_Δα( - ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... + Δαr = blas_contract_pullback_Δα( + ΔC, A, pA, B, pB, pAB, α, backend, allocator ) - Δβr = tensorcontract_pullback_Δβ(ΔC, C, β) + Δβr = blas_contract_pullback_Δβ(ΔC, C, β) + ΔCr = blas_contract_pullback_ΔC!(ΔC, β) return NoRData(), ΔCr, - ΔAr, NoRData(), NoRData(), - ΔBr, NoRData(), NoRData(), + ΔAr, NoRData(), + ΔBr, NoRData(), NoRData(), Δαr, Δβr, - map(ba_ -> NoRData(), ba)... + NoRData(), NoRData() end - return C_ΔC, tensorcontract_pullback + return C_ΔC, blas_contract_pullback end -tensorcontract_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) +blas_contract_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) -function tensorcontract_pullback_ΔA!( - ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... +function blas_contract_pullback_ΔA!( + ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) ipAB = invperm(linearize(pAB)) pΔC = _repartition(ipAB, TO.numout(pA)) ipA = _repartition(invperm(linearize(pA)), A) - conjΔC = conjA - conjB′ = conjA ? conjB : !conjB tB = twist( B, @@ -81,24 +78,22 @@ function tensorcontract_pullback_ΔA!( TO.tensorcontract!( ΔA, - ΔC, pΔC, conjΔC, - tB, reverse(pB), conjB′, + ΔC, pΔC, false, + tB, reverse(pB), true, ipA, - conjA ? α : conj(α), Zero(), - ba... + conj(α), Zero(), + backend, allocator ) return NoRData() end -function tensorcontract_pullback_ΔB!( - ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... +function blas_contract_pullback_ΔB!( + ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) ipAB = invperm(linearize(pAB)) pΔC = _repartition(ipAB, TO.numout(pA)) ipB = _repartition(invperm(linearize(pB)), B) - conjΔC = conjB - conjA′ = conjB ? conjA : !conjA tA = twist( A, @@ -110,27 +105,27 @@ function tensorcontract_pullback_ΔB!( TO.tensorcontract!( ΔB, - tA, reverse(pA), conjA′, - ΔC, pΔC, conjΔC, + tA, reverse(pA), true, + ΔC, pΔC, false, ipB, - conjB ? α : conj(α), Zero(), ba... + conj(α), Zero(), backend, allocator ) return NoRData() end -function tensorcontract_pullback_Δα( - ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba... +function blas_contract_pullback_Δα( + ΔC, A, pA, B, pB, pAB, α, backend, allocator ) Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) Tdα === NoRData && return NoRData() - AB = TO.tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) + AB = TO.tensorcontract(A, pA, false, B, pB, false, pAB, One(), backend, allocator) Δα = inner(AB, ΔC) return Mooncake._rdata(Δα) end -function tensorcontract_pullback_Δβ(ΔC, C, β) +function blas_contract_pullback_Δβ(ΔC, C, β) Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) Tdβ === NoRData && return NoRData() diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index 14aaea251..a1bce9906 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -231,20 +231,19 @@ for V in spacelist β = randn(T) V2_conj = prod(conj, V2; init = one(V[1])) - for conjA in (false, true), conjB in (false, true) - A = randn(T, permute(V1 ← (conjA ? V2_conj : V2), ipA)) - B = randn(T, permute((conjB ? V2_conj : V2) ← V3, ipB)) - C = randn!( - TensorOperations.tensoralloc_contract( - T, A, pA, conjA, B, pB, conjB, pAB, Val(false) - ) - ) - Mooncake.TestUtils.test_rule( - rng, tensorcontract!, C, A, pA, conjA, B, pB, conjB, pAB, α, β; - atol, rtol, mode + A = randn(T, permute(V1 ← V2, ipA)) + B = randn(T, permute(V2 ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, false, B, pB, false, pAB, Val(false) ) - - end + ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.blas_contract!, + C, A, pA, B, pB, pAB, α, β, + TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + atol, rtol, mode + ) end end From 70cdc5570bbd5336e4a6e1c7e0ccbe98b39c3959 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 21 Jan 2026 20:50:06 -0500 Subject: [PATCH 17/54] add rule `tr` --- ext/TensorKitMooncakeExt/linalg.jl | 16 ++++++++++++++++ test/autodiff/mooncake.jl | 8 ++++++++ 2 files changed, 24 insertions(+) diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index d0d73d951..092ddf369 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -51,3 +51,19 @@ function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorM end return CoDual(n, Mooncake.NoFData()), norm_pullback end + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tr), AbstractTensorMap} + +function Mooncake.rrule!!(::CoDual{typeof(tr)}, A_ΔA::CoDual{<:AbstractTensorMap}) + A, ΔA = arrayify(A_ΔA) + trace = tr(A) + + function tr_pullback(Δtrace) + for (_, b) in blocks(ΔA) + TensorKit.diagview(b) .+= Δtrace + end + return NoRData(), NoRData() + end + + return CoDual(trace, Mooncake.NoFData()), tr_pullback +end diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index a1bce9906..e9f7d01d7 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -117,6 +117,14 @@ for V in spacelist Mooncake.TestUtils.test_rule(rng, norm, C, 2; atol, rtol, mode) Mooncake.TestUtils.test_rule(rng, norm, C', 2; atol, rtol, mode) + + D1 = randn(T, V[1] ← V[1]) + D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) + + Mooncake.TestUtils.test_rule(rng, tr, D1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, tr, D2; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, tr, D3; atol, rtol, mode) end From f85c946f7454852fc28a97cba54fb91d4cf3ee24 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 22 Jan 2026 08:02:00 -0500 Subject: [PATCH 18/54] give up on planartrace for now --- ext/TensorKitMooncakeExt/planaroperations.jl | 34 +++++++++---- src/fusiontrees/manipulations.jl | 2 +- test/autodiff/mooncake.jl | 53 +++++++++++--------- 3 files changed, 52 insertions(+), 37 deletions(-) diff --git a/ext/TensorKitMooncakeExt/planaroperations.jl b/ext/TensorKitMooncakeExt/planaroperations.jl index a480293af..3d1742a3a 100644 --- a/ext/TensorKitMooncakeExt/planaroperations.jl +++ b/ext/TensorKitMooncakeExt/planaroperations.jl @@ -49,19 +49,31 @@ end planartrace_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) +# This implementation is slightly more involved than its non-planar counterpart +# this is because we lack a general `pAB` argument in `planarcontract`, and need +# to keep things planar along the way. +# In particular, we can't simply tensor product with multiple identities in one go +# if they aren't "contiguous", e.g. p = ((1, 4, 5), ()), q = ((2, 6), (3, 7)) function planartrace_pullback_ΔA!( ΔA, ΔC, A, p, q, α, backend, allocator ) - ip = invperm((linearize(p)..., q[1]..., q[2]...)) - pdA = _repartition(ip, A) - E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) - twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) - pE = ((), trivtuple(TO.numind(q))) - pΔC = (trivtuple(TO.numind(p)), ()) - TensorKit.planarcontract!( - ΔA, ΔC, pΔC, E, pE, pdA, conj(α), One(), backend, allocator - ) - return NoRData() + if length(q[1]) == 0 + ip = invperm(linearize(p)) + pΔA = _repartition(ip, A) + TK.add_transpose!(ΔA, ΔC, pΔA, conj(α), One(), backend, allocator) + return NoRData() + end + # if length(q[1]) == 1 + # ip = invperm((p[1]..., q[2]..., p[2]..., q[1]...)) + # pdA = _repartition(ip, A) + # E = one!(TO.tensoralloc_add(scalartype(A), A, q, false)) + # twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + # # pE = ((), trivtuple(TO.numind(q))) + # # pΔC = (trivtuple(TO.numind(p)), ()) + # TensorKit.planaradd!(ΔA, ΔC ⊗ E, pdA, conj(α), One(), backend, allocator) + # return NoRData() + # end + error("The reverse rule for `planartrace` is not yet implemented") end function planartrace_pullback_Δα( @@ -73,7 +85,7 @@ function planartrace_pullback_Δα( # TODO: this result might be easier to compute as: # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α At = TO.tensoralloc_add(scalartype(A), A, p, false, Val(true), allocator) - TensorKit.planartrace!(At, A, p, q, false, One(), backend, allocator) + TensorKit.planartrace!(At, A, p, q, One(), Zero(), backend, allocator) Δα = inner(At, ΔC) TO.tensorfree!(At, allocator) return Mooncake._rdata(Δα) diff --git a/src/fusiontrees/manipulations.jl b/src/fusiontrees/manipulations.jl index 1564b1b67..3cc6a16b6 100644 --- a/src/fusiontrees/manipulations.jl +++ b/src/fusiontrees/manipulations.jl @@ -692,7 +692,7 @@ function planar_trace( k += 1 end end - k > N₃ && throw(ArgumentError("Not a planar trace")) + k > N₃ && throw(ArgumentError(lazy"not a planar trace: ($q1, $q2)")) q1′ = let i = i, j = j map(l -> (l - (l > i) - (l > j)), TupleTools.deleteat(q1, k)) diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index e9f7d01d7..db7e0c078 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -325,30 +325,33 @@ for V in spacelist end end - @timedtestset "planartrace!" begin - for _ in 1:5 - k1 = rand(0:2) - k2 = rand(1:2) - V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) - V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) - - k′ = rand(0:(k1 + 2k2)) - (_p, _q) = randcircshift(k′, k1 + 2 * k2 - k′, k1) - p = _repartition(_p, rand(0:k1)) - q = _repartition(_q, k2) - ip = _repartition(invperm(linearize((_p, _q))), k′) - A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) - - α = randn(T) - β = randn(T) - C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) - Mooncake.TestUtils.test_rule( - rng, TensorKit.planartrace!, - C, A, p, q, α, β, - TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); - atol, rtol, mode - ) - end - end + # TODO: currently broken + # @timedtestset "planartrace!" begin + # for _ in 1:5 + # k1 = rand(0:2) + # k2 = rand(0:1) + # V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + # V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) + # V3 = prod(x -> x ⊗ x', V2[1:k2]; init = one(V[1])) + # V4 = prod(x -> x ⊗ x', V2[(k2 + 1):end]; init = one(V[1])) + # + # k′ = rand(0:(k1 + 2k2)) + # (_p, _q) = randcircshift(k′, k1 + 2k2 - k′, k1) + # p = _repartition(_p, rand(0:k1)) + # q = (tuple(_q[1:2:end]...), tuple(_q[2:2:end]...)) + # ip = _repartition(invperm(linearize((_p, _q))), k′) + # A = randn(T, permute(prod(V1) ⊗ V3 ← V4, ip)) + # + # α = randn(T) + # β = randn(T) + # C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + # Mooncake.TestUtils.test_rule( + # rng, TensorKit.planartrace!, + # C, A, p, q, α, β, + # TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + # atol, rtol, mode + # ) + # end + # end end end From 8117c76de8a50710f9c4b72a1c207ed4a06cc72e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 22 Jan 2026 08:08:47 -0500 Subject: [PATCH 19/54] add rule `inv` --- ext/TensorKitMooncakeExt/linalg.jl | 15 +++++++++++++++ test/autodiff/mooncake.jl | 4 ++++ 2 files changed, 19 insertions(+) diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index 092ddf369..a35c1cea4 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -67,3 +67,18 @@ function Mooncake.rrule!!(::CoDual{typeof(tr)}, A_ΔA::CoDual{<:AbstractTensorMa return CoDual(trace, Mooncake.NoFData()), tr_pullback end + +Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(inv), AbstractTensorMap} + +function Mooncake.rrule!!(::CoDual{typeof(inv)}, A_ΔA::CoDual{<:AbstractTensorMap}) + A, ΔA = arrayify(A_ΔA) + Ainv_ΔAinv = Mooncake.zero_fcodual(inv(A)) + Ainv, ΔAinv = arrayify(Ainv_ΔAinv) + + function inv_pullback(::NoRData) + mul!(ΔA, Ainv' * ΔAinv, Ainv', -1, One()) + return NoRData(), NoRData() + end + + return Ainv_ΔAinv, inv_pullback +end diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl index db7e0c078..0ae368235 100644 --- a/test/autodiff/mooncake.jl +++ b/test/autodiff/mooncake.jl @@ -125,6 +125,10 @@ for V in spacelist Mooncake.TestUtils.test_rule(rng, tr, D1; atol, rtol, mode) Mooncake.TestUtils.test_rule(rng, tr, D2; atol, rtol, mode) Mooncake.TestUtils.test_rule(rng, tr, D3; atol, rtol, mode) + + Mooncake.TestUtils.test_rule(rng, inv, D1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, inv, D2; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, inv, D3; atol, rtol, mode) end From f550cfd56679c7ae9c53b8120bc5d8751ba6efcf Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 22 Jan 2026 08:33:46 -0500 Subject: [PATCH 20/54] is_primitive in namespace --- .../TensorKitMooncakeExt.jl | 2 +- .../indexmanipulations.jl | 20 +++++++++---------- ext/TensorKitMooncakeExt/linalg.jl | 8 ++++---- ext/TensorKitMooncakeExt/planaroperations.jl | 2 +- ext/TensorKitMooncakeExt/tensoroperations.jl | 4 ++-- ext/TensorKitMooncakeExt/vectorinterface.jl | 8 ++++---- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index 4c692adb9..d3894c874 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -1,7 +1,7 @@ module TensorKitMooncakeExt using Mooncake -using Mooncake: @zero_derivative, DefaultCtx, ReverseMode, NoFData, NoRData, CoDual, arrayify, primal +using Mooncake: @zero_derivative, @is_primitive, DefaultCtx, ReverseMode, NoFData, NoRData, CoDual, arrayify, primal using TensorKit import TensorKit as TK using VectorInterface diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index 464c18392..39f7dd4fd 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -1,7 +1,7 @@ for transform in (:permute, :transpose) add_transform! = Symbol(:add_, transform, :!) add_transform_pullback = Symbol(add_transform!, :_pullback) - @eval Mooncake.@is_primitive( + @eval @is_primitive( DefaultCtx, ReverseMode, Tuple{ @@ -76,7 +76,7 @@ for transform in (:permute, :transpose) end end -Mooncake.@is_primitive( +@is_primitive( DefaultCtx, ReverseMode, Tuple{ @@ -153,8 +153,8 @@ function Mooncake.rrule!!( end # both are needed for correctly capturing every dispatch -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(twist!), AbstractTensorMap, Any} -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), @NamedTuple{inv::Bool}, typeof(twist!), AbstractTensorMap, Any} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(twist!), AbstractTensorMap, Any} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), @NamedTuple{inv::Bool}, typeof(twist!), AbstractTensorMap, Any} function Mooncake.rrule!!(::CoDual{typeof(twist!)}, t_Δt::CoDual{<:AbstractTensorMap}, inds_Δinds::CoDual) # prepare arguments @@ -198,8 +198,8 @@ function Mooncake.rrule!!( end # both are needed for correctly capturing every dispatch -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(flip), AbstractTensorMap, Any} -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), @NamedTuple{inv::Bool}, typeof(flip), AbstractTensorMap, Any} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(flip), AbstractTensorMap, Any} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), @NamedTuple{inv::Bool}, typeof(flip), AbstractTensorMap, Any} function Mooncake.rrule!!(::CoDual{typeof(flip)}, t_Δt::CoDual{<:AbstractTensorMap}, inds_Δinds::CoDual) # prepare arguments @@ -245,8 +245,8 @@ for insertunit in (:insertleftunit, :insertrightunit) insertunit_pullback = Symbol(insertunit, :_pullback) @eval begin # both are needed for correctly capturing every dispatch - Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof($insertunit), AbstractTensorMap, Val} - Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), NamedTuple, typeof($insertunit), AbstractTensorMap, Val} + @is_primitive DefaultCtx ReverseMode Tuple{typeof($insertunit), AbstractTensorMap, Val} + @is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), NamedTuple, typeof($insertunit), AbstractTensorMap, Val} function Mooncake.rrule!!(::CoDual{typeof($insertunit)}, tsrc_Δtsrc::CoDual{<:AbstractTensorMap}, ival_Δival::CoDual{<:Val}) # prepare arguments @@ -328,8 +328,8 @@ for insertunit in (:insertleftunit, :insertrightunit) end -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(removeunit), AbstractTensorMap, Val} -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), NamedTuple, typeof(removeunit), AbstractTensorMap, Val} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(removeunit), AbstractTensorMap, Val} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(Core.kwcall), NamedTuple, typeof(removeunit), AbstractTensorMap, Val} function Mooncake.rrule!!(::CoDual{typeof(removeunit)}, tsrc_Δtsrc::CoDual{<:AbstractTensorMap}, ival_Δival::CoDual{Val{i}}) where {i} # prepare arguments diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index a35c1cea4..a75e77922 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -1,4 +1,4 @@ -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(mul!), AbstractTensorMap, AbstractTensorMap, AbstractTensorMap, Number, Number} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(mul!), AbstractTensorMap, AbstractTensorMap, AbstractTensorMap, Number, Number} function Mooncake.rrule!!( ::CoDual{typeof(mul!)}, @@ -37,7 +37,7 @@ function Mooncake.rrule!!( return C_ΔC, mul_pullback end -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(norm), AbstractTensorMap, Real} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(norm), AbstractTensorMap, Real} function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorMap}, pdp::CoDual{<:Real}) t, Δt = arrayify(tΔt) @@ -52,7 +52,7 @@ function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorM return CoDual(n, Mooncake.NoFData()), norm_pullback end -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tr), AbstractTensorMap} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(tr), AbstractTensorMap} function Mooncake.rrule!!(::CoDual{typeof(tr)}, A_ΔA::CoDual{<:AbstractTensorMap}) A, ΔA = arrayify(A_ΔA) @@ -68,7 +68,7 @@ function Mooncake.rrule!!(::CoDual{typeof(tr)}, A_ΔA::CoDual{<:AbstractTensorMa return CoDual(trace, Mooncake.NoFData()), tr_pullback end -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(inv), AbstractTensorMap} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(inv), AbstractTensorMap} function Mooncake.rrule!!(::CoDual{typeof(inv)}, A_ΔA::CoDual{<:AbstractTensorMap}) A, ΔA = arrayify(A_ΔA) diff --git a/ext/TensorKitMooncakeExt/planaroperations.jl b/ext/TensorKitMooncakeExt/planaroperations.jl index 3d1742a3a..df75d60fe 100644 --- a/ext/TensorKitMooncakeExt/planaroperations.jl +++ b/ext/TensorKitMooncakeExt/planaroperations.jl @@ -1,6 +1,6 @@ # planartrace! # ------------ -Mooncake.@is_primitive( +@is_primitive( DefaultCtx, ReverseMode, Tuple{ diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 59a398e27..e38271200 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -1,6 +1,6 @@ # tensorcontract! # --------------- -Mooncake.@is_primitive( +@is_primitive( DefaultCtx, ReverseMode, Tuple{ @@ -135,7 +135,7 @@ end # tensortrace! # ------------ -Mooncake.@is_primitive( +@is_primitive( DefaultCtx, ReverseMode, Tuple{ diff --git a/ext/TensorKitMooncakeExt/vectorinterface.jl b/ext/TensorKitMooncakeExt/vectorinterface.jl index 2c1bfe984..625aadd61 100644 --- a/ext/TensorKitMooncakeExt/vectorinterface.jl +++ b/ext/TensorKitMooncakeExt/vectorinterface.jl @@ -1,4 +1,4 @@ -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, Number} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, Number} function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) # prepare arguments @@ -20,7 +20,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTens return C_ΔC, scale_pullback end -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number} function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) # prepare arguments @@ -44,7 +44,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTens return C_ΔC, scale_pullback end -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number} function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}) # prepare arguments @@ -73,7 +73,7 @@ function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractTensor return C_ΔC, add_pullback end -Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap} +@is_primitive DefaultCtx ReverseMode Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap} function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractTensorMap}, B_ΔB::CoDual{<:AbstractTensorMap}) # prepare arguments From a82174551742437ba98a588033d01f1f74f5d9e6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 22 Jan 2026 08:49:50 -0500 Subject: [PATCH 21/54] share more code --- .../indexmanipulations.jl | 44 +++++-------------- ext/TensorKitMooncakeExt/linalg.jl | 10 +++-- ext/TensorKitMooncakeExt/tensoroperations.jl | 28 ++---------- ext/TensorKitMooncakeExt/utility.jl | 6 +-- 4 files changed, 26 insertions(+), 62 deletions(-) diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index 39f7dd4fd..8a97ac81c 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -31,22 +31,18 @@ for transform in (:permute, :transpose) # if we need to compute Δa, it is faster to allocate an intermediate permuted A # and store that instead of repeating the permutation in the pullback each time. # effectively, we replace `add_permute` by `add ∘ permute`. - Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) - Ap = if Tdα === NoRData - TK.$add_transform!(C, A, p, α, β, ba...) - nothing - else + Ap = if _needs_tangent(α) Ap = $transform(A, p) add!(C, Ap, α, β) Ap + else + TK.$add_transform!(C, A, p, α, β, ba...) + nothing end function $add_transform_pullback(::NoRData) copy!(C, C_cache) - scale!(ΔC, conj(β)) - ΔCr = NoRData() - # ΔA ip = invperm(linearize(p)) pΔA = _repartition(ip, A) @@ -60,14 +56,8 @@ for transform in (:permute, :transpose) Mooncake._rdata(inner(Ap, ΔC)) end - # Δβ - Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) - Δβr = if Tdβ === NoRData - NoRData() - else - Mooncake._rdata(inner(C, ΔC)) - end - + Δβr = pullback_dβ(C, ΔC, β) + ΔCr = pullback_dC!(ΔC, β) return NoRData(), ΔCr, ΔAr, NoRData(), Δαr, Δβr, map(Returns(NoRData()), ba)... end @@ -107,22 +97,18 @@ function Mooncake.rrule!!( # if we need to compute Δa, it is faster to allocate an intermediate braided A # and store that instead of repeating the permutation in the pullback each time. # effectively, we replace `add_permute` by `add ∘ permute`. - Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) - Ap = if Tdα === NoRData - TK.add_braid!(C, A, p, levels, α, β, ba...) - nothing - else + Ap = if _needs_tangent(α) Ap = braid(A, p, levels) add!(C, Ap, α, β) Ap + else + TK.add_braid!(C, A, p, levels, α, β, ba...) + nothing end function add_braid!_pullback(::NoRData) copy!(C, C_cache) - scale!(ΔC, conj(β)) - ΔCr = NoRData() - # ΔA ip = invperm(linearize(p)) pΔA = _repartition(ip, A) @@ -137,14 +123,8 @@ function Mooncake.rrule!!( Mooncake._rdata(inner(Ap, ΔC)) end - # Δβ - Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) - Δβr = if Tdβ === NoRData - NoRData() - else - Mooncake._rdata(inner(C, ΔC)) - end - + Δβr = pullback_dβ(C, ΔC, β) + ΔCr = pullback_dC!(ΔC, β) return NoRData(), ΔCr, ΔAr, NoRData(), NoRData(), Δαr, Δβr, map(Returns(NoRData()), ba)... end diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index a75e77922..2a77792c9 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -1,3 +1,8 @@ +# Shared +# ------ +pullback_dC!(ΔC, β) = (scale!(ΔC, conj(β)); return NoRData()) +pullback_dβ(C, ΔC, β) = _needs_tangent(β) ? inner(C, ΔC) : NoRData() + @is_primitive DefaultCtx ReverseMode Tuple{typeof(mul!), AbstractTensorMap, AbstractTensorMap, AbstractTensorMap, Number, Number} function Mooncake.rrule!!( @@ -22,14 +27,13 @@ function Mooncake.rrule!!( function mul_pullback(::NoRData) copy!(C, C_cache) - scale!(ΔC, conj(β)) mul!(ΔA, ΔC, B', conj(α), One()) mul!(ΔB, A', ΔC, conj(α), One()) - ΔCr = NoRData() ΔAr = NoRData() ΔBr = NoRData() Δαr = isnothing(AB) ? NoRData() : Mooncake._rdata(inner(AB, ΔC)) - Δβr = _needs_tangent(β) ? Mooncake._rdata(inner(C, ΔC)) : NoRData() + Δβr = pullback_dβ(C, ΔC, β) + ΔCr = pullback_dC!(ΔC, β) return NoRData(), ΔCr, ΔAr, ΔBr, Δαr, Δβr end diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index e38271200..66c3f257a 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -45,8 +45,8 @@ function Mooncake.rrule!!( Δαr = blas_contract_pullback_Δα( ΔC, A, pA, B, pB, pAB, α, backend, allocator ) - Δβr = blas_contract_pullback_Δβ(ΔC, C, β) - ΔCr = blas_contract_pullback_ΔC!(ΔC, β) + Δβr = pullback_dβ(ΔC, C, β) + ΔCr = pullback_dC!(ΔC, β) return NoRData(), ΔCr, ΔAr, NoRData(), @@ -59,8 +59,6 @@ function Mooncake.rrule!!( return C_ΔC, blas_contract_pullback end -blas_contract_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) - function blas_contract_pullback_ΔA!( ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator ) @@ -125,14 +123,6 @@ function blas_contract_pullback_Δα( return Mooncake._rdata(Δα) end -function blas_contract_pullback_Δβ(ΔC, C, β) - Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) - Tdβ === NoRData && return NoRData() - - Δβ = inner(C, ΔC) - return Mooncake._rdata(Δβ) -end - # tensortrace! # ------------ @is_primitive( @@ -171,8 +161,8 @@ function Mooncake.rrule!!( ΔAr = trace_permute_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend) Δαr = trace_permute_pullback_Δα(ΔC, A, p, q, α, backend) - Δβr = trace_permute_pullback_Δβ(ΔC, C, β) - ΔCr = trace_permute_pullback_ΔC!(ΔC, β) + Δβr = pullback_dβ(ΔC, C, β) + ΔCr = pullback_dC!(ΔC, β) return NoRData(), ΔCr, ΔAr, NoRData(), NoRData(), @@ -182,8 +172,6 @@ function Mooncake.rrule!!( return C_ΔC, trace_permute_pullback end -trace_permute_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) - function trace_permute_pullback_ΔA!( ΔA, ΔC, A, p, q, α, backend ) @@ -211,11 +199,3 @@ function trace_permute_pullback_Δα( Δα = inner(At, ΔC) return Mooncake._rdata(Δα) end - -function trace_permute_pullback_Δβ(ΔC, C, β) - Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) - Tdβ === NoRData && return NoRData() - - Δβ = inner(C, ΔC) - return Mooncake._rdata(Δβ) -end diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index e93de22be..261c1dcc2 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -1,7 +1,7 @@ _needs_tangent(x) = _needs_tangent(typeof(x)) -_needs_tangent(::Type{<:Number}) = true -_needs_tangent(::Type{<:Integer}) = false -_needs_tangent(::Type{<:Union{One, Zero}}) = false +function _needs_tangent(::Type{T}) where {T <: Number} + return Mooncake.rdata_type(Mooncake.tangent_type(T)) !== NoRData() +end # IndexTuple utility # ------------------ From a17a55d65bcf423b75e733bc97ec0317c7ac90f2 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 22 Jan 2026 11:56:00 -0500 Subject: [PATCH 22/54] split AD tests to reduce CI pressure properly setup setup --- .github/workflows/CI.yml | 6 +- test/autodiff/mooncake.jl | 361 -------------------- test/{autodiff => chainrules}/chainrules.jl | 0 test/mooncake/indexmanipulations.jl | 134 ++++++++ test/mooncake/linalg.jl | 80 +++++ test/mooncake/planaroperations.jl | 128 +++++++ test/mooncake/tensoroperations.jl | 121 +++++++ test/mooncake/vectorinterface.jl | 75 ++++ test/runtests.jl | 2 +- test/setup.jl | 38 +++ 10 files changed, 581 insertions(+), 364 deletions(-) delete mode 100644 test/autodiff/mooncake.jl rename test/{autodiff => chainrules}/chainrules.jl (100%) create mode 100644 test/mooncake/indexmanipulations.jl create mode 100644 test/mooncake/linalg.jl create mode 100644 test/mooncake/planaroperations.jl create mode 100644 test/mooncake/tensoroperations.jl create mode 100644 test/mooncake/vectorinterface.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 434f33ed4..8880dfcf1 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -30,7 +30,8 @@ jobs: - symmetries - tensors - other - - autodiff + - mooncake + - chainrules os: - ubuntu-latest - macOS-latest @@ -55,7 +56,8 @@ jobs: - symmetries - tensors - other - - autodiff + - mooncake + - chainrules os: - ubuntu-latest - macOS-latest diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl deleted file mode 100644 index 0ae368235..000000000 --- a/test/autodiff/mooncake.jl +++ /dev/null @@ -1,361 +0,0 @@ -using Test, TestExtras -using TensorKit -using TensorOperations -using Mooncake -using Random -using TupleTools - -mode = Mooncake.ReverseMode -rng = Random.default_rng() -is_primitive = false - -function randindextuple(N::Int, k::Int = rand(0:N)) - @assert 0 ≤ k ≤ N - _p = randperm(N) - return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...)) -end -function randcircshift(N₁::Int, N₂::Int, k::Int = rand(0:(N₁ + N₂))) - N = N₁ + N₂ - @assert 0 ≤ k ≤ N - p = TupleTools.vcat(ntuple(identity, N₁), reverse(ntuple(identity, N₂) .+ N₁)) - n = rand(0:N) - _p = TupleTools.circshift(p, n) - return (tuple(_p[1:k]...), reverse(tuple(_p[(k + 1):end]...))) -end - -const _repartition = @static if isdefined(Base, :get_extension) - Base.get_extension(TensorKit, :TensorKitMooncakeExt)._repartition -else - TensorKit.TensorKitMooncakeExt._repartition -end - -spacelist = ( - (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), - ( - Vect[Z2Irrep](0 => 1, 1 => 1), - Vect[Z2Irrep](0 => 1, 1 => 2)', - Vect[Z2Irrep](0 => 2, 1 => 2)', - Vect[Z2Irrep](0 => 2, 1 => 3), - Vect[Z2Irrep](0 => 2, 1 => 2), - ), - ( - Vect[FermionParity](0 => 1, 1 => 1), - Vect[FermionParity](0 => 1, 1 => 2)', - Vect[FermionParity](0 => 2, 1 => 1)', - Vect[FermionParity](0 => 2, 1 => 3), - Vect[FermionParity](0 => 2, 1 => 2), - ), - ( - Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), - Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), - Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', - Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), - Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', - ), - ( - Vect[SU2Irrep](0 => 2, 1 // 2 => 1), - Vect[SU2Irrep](0 => 1, 1 => 1), - Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', - Vect[SU2Irrep](1 // 2 => 2), - Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', - ), - # ( - # Vect[FibonacciAnyon](:I => 2, :τ => 1), - # Vect[FibonacciAnyon](:I => 1, :τ => 2)', - # Vect[FibonacciAnyon](:I => 2, :τ => 2)', - # Vect[FibonacciAnyon](:I => 2, :τ => 3), - # Vect[FibonacciAnyon](:I => 2, :τ => 2), - # ), -) - -for V in spacelist - I = sectortype(eltype(V)) - Istr = TensorKit.type_repr(I) - - symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding - println("---------------------------------------") - println("Mooncake with symmetry: $Istr") - println("---------------------------------------") - eltypes = (Float64,) # no complex support yet - - @timedtestset "VectorInterface with scalartype $T" for T in eltypes - atol = precision(T) - rtol = precision(T) - - C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) - A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) - α = randn(T) - β = randn(T) - - Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol, mode) - - Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, mode, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, mode, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol, mode) - - Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, mode) - end - - @timedtestset "LinearAlgebra with scalartype $T" for T in eltypes - atol = precision(T) - rtol = precision(T) - - C = randn(T, V[1] ⊗ V[2] ← V[5]) - A = randn(T, codomain(C) ← V[3] ⊗ V[4]) - B = randn(T, domain(A) ← domain(C)) - α = randn(T) - β = randn(T) - - Mooncake.TestUtils.test_rule(rng, mul!, C, A, B, α, β; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, mul!, C, A, B; atol, rtol, mode, is_primitive = false) - - Mooncake.TestUtils.test_rule(rng, norm, C, 2; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, norm, C', 2; atol, rtol, mode) - - D1 = randn(T, V[1] ← V[1]) - D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) - D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) - - Mooncake.TestUtils.test_rule(rng, tr, D1; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, tr, D2; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, tr, D3; atol, rtol, mode) - - Mooncake.TestUtils.test_rule(rng, inv, D1; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, inv, D2; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, inv, D3; atol, rtol, mode) - end - - - @timedtestset "Index manipulations with scalartype $T" for T in eltypes - atol = precision(T) - rtol = precision(T) - - symmetricbraiding && @timedtestset "add_permute!" begin - A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) - α = randn(T) - β = randn(T) - - # repeat a couple times to get some distribution of arrows - for _ in 1:5 - p = randindextuple(numind(A)) - C = randn!(permute(A, p)) - Mooncake.TestUtils.test_rule(rng, TensorKit.add_permute!, C, A, p, α, β; atol, rtol, mode) - A = C - end - end - - @timedtestset "add_transpose!" begin - A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) - α = randn(T) - β = randn(T) - - # repeat a couple times to get some distribution of arrows - for _ in 1:5 - p = randcircshift(numout(A), numin(A)) - C = randn!(transpose(A, p)) - Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, α, β; atol, rtol, mode) - A = C - end - end - - @timedtestset "add_braid!" begin - A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) - α = randn(T) - β = randn(T) - - # repeat a couple times to get some distribution of arrows - for _ in 1:5 - p = randcircshift(numout(A), numin(A)) - levels = tuple(randperm(numind(A))) - C = randn!(transpose(A, p)) - Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, α, β; atol, rtol, mode) - A = C - end - end - - @timedtestset "flip_n_twist!" begin - A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) - Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = false), twist!, A, 1; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = true), twist!, A, [1, 3]; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, twist!, A, 1; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, twist!, A, [1, 3]; atol, rtol, mode) - - Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = false), flip, A, 1; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = true), flip, A, [1, 3]; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, flip, A, 1; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, flip, A, [1, 3]; atol, rtol, mode) - end - - @timedtestset "insert and remove units" begin - A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) - - for insertunit in (insertleftunit, insertrightunit) - Mooncake.TestUtils.test_rule(rng, insertunit, A, Val(1); atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, insertunit, A, Val(4); atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, insertunit, A', Val(2); atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false), insertunit, A, Val(1); atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = true), insertunit, A, Val(2); atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false, dual = true, conj = true), insertunit, A, Val(3); atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false, dual = true, conj = true), insertunit, A', Val(3); atol, rtol, mode) - end - - for i in 1:4 - B = insertleftunit(A, i; dual = rand(Bool)) - Mooncake.TestUtils.test_rule(rng, removeunit, B, Val(i); atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false), removeunit, B, Val(i); atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = true), removeunit, B, Val(i); atol, rtol, mode) - end - end - end - - symmetricbraiding && @timedtestset "TensorOperations with scalartype $T" for T in eltypes - atol = precision(T) - rtol = precision(T) - - @timedtestset "tensorcontract!" begin - for _ in 1:5 - d = 0 - local V1, V2, V3 - # retry a couple times to make sure there are at least some nonzero elements - for _ in 1:10 - k1 = rand(0:3) - k2 = rand(0:2) - k3 = rand(0:2) - V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) - V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) - V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1])) - d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) - d > 0 && break - end - ipA = randindextuple(length(V1) + length(V2)) - pA = _repartition(invperm(linearize(ipA)), length(V1)) - ipB = randindextuple(length(V2) + length(V3)) - pB = _repartition(invperm(linearize(ipB)), length(V2)) - pAB = randindextuple(length(V1) + length(V3)) - - α = randn(T) - β = randn(T) - V2_conj = prod(conj, V2; init = one(V[1])) - - A = randn(T, permute(V1 ← V2, ipA)) - B = randn(T, permute(V2 ← V3, ipB)) - C = randn!( - TensorOperations.tensoralloc_contract( - T, A, pA, false, B, pB, false, pAB, Val(false) - ) - ) - Mooncake.TestUtils.test_rule( - rng, TensorKit.blas_contract!, - C, A, pA, B, pB, pAB, α, β, - TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); - atol, rtol, mode - ) - end - end - - @timedtestset "trace_permute!" begin - for _ in 1:5 - k1 = rand(0:2) - k2 = rand(1:2) - V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) - V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) - - (_p, _q) = randindextuple(k1 + 2 * k2, k1) - p = _repartition(_p, rand(0:k1)) - q = _repartition(_q, k2) - ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2))) - A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) - - α = randn(T) - β = randn(T) - C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) - Mooncake.TestUtils.test_rule( - rng, TensorKit.trace_permute!, C, A, p, q, α, β, TensorOperations.DefaultBackend(); - atol, rtol, mode - ) - end - end - end - - @timedtestset "PlanarOperations with scalartype $T" for T in eltypes - atol = precision(T) - rtol = precision(T) - - @timedtestset "planarcontract!" begin - for _ in 1:5 - d = 0 - local V1, V2, V3, k1, k2, k3 - # retry a couple times to make sure there are at least some nonzero elements - for _ in 1:10 - k1 = rand(0:3) - k2 = rand(0:2) - k3 = rand(0:2) - V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) - V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) - V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1])) - d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) - d > 1 && break - end - k′ = rand(0:(k1 + k2)) - pA = randcircshift(k′, k1 + k2 - k′, k1) - ipA = _repartition(invperm(linearize(pA)), k′) - k′ = rand(0:(k2 + k3)) - pB = randcircshift(k′, k2 + k3 - k′, k2) - ipB = _repartition(invperm(linearize(pB)), k′) - # TODO: primal value already is broken for this? - # pAB = randcircshift(k1, k3) - pAB = _repartition(tuple((1:(k1 + k3))...), k1) - - α = randn(T) - β = randn(T) - - A = randn(T, permute(V1 ← V2, ipA)) - B = randn(T, permute(V2 ← V3, ipB)) - C = randn!( - TensorOperations.tensoralloc_contract( - T, A, pA, false, B, pB, false, pAB, Val(false) - ) - ) - Mooncake.TestUtils.test_rule( - rng, TensorKit.planarcontract!, C, A, pA, B, pB, pAB, α, β; - atol, rtol, mode, is_primitive = false - ) - end - end - - # TODO: currently broken - # @timedtestset "planartrace!" begin - # for _ in 1:5 - # k1 = rand(0:2) - # k2 = rand(0:1) - # V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) - # V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) - # V3 = prod(x -> x ⊗ x', V2[1:k2]; init = one(V[1])) - # V4 = prod(x -> x ⊗ x', V2[(k2 + 1):end]; init = one(V[1])) - # - # k′ = rand(0:(k1 + 2k2)) - # (_p, _q) = randcircshift(k′, k1 + 2k2 - k′, k1) - # p = _repartition(_p, rand(0:k1)) - # q = (tuple(_q[1:2:end]...), tuple(_q[2:2:end]...)) - # ip = _repartition(invperm(linearize((_p, _q))), k′) - # A = randn(T, permute(prod(V1) ⊗ V3 ← V4, ip)) - # - # α = randn(T) - # β = randn(T) - # C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) - # Mooncake.TestUtils.test_rule( - # rng, TensorKit.planartrace!, - # C, A, p, q, α, β, - # TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); - # atol, rtol, mode - # ) - # end - # end - end -end diff --git a/test/autodiff/chainrules.jl b/test/chainrules/chainrules.jl similarity index 100% rename from test/autodiff/chainrules.jl rename to test/chainrules/chainrules.jl diff --git a/test/mooncake/indexmanipulations.jl b/test/mooncake/indexmanipulations.jl new file mode 100644 index 000000000..a2909c38f --- /dev/null +++ b/test/mooncake/indexmanipulations.jl @@ -0,0 +1,134 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Mooncake +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +mode = Mooncake.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + Vect[Z2Irrep](0 => 2, 1 => 2), + ), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + # ( + # Vect[FibonacciAnyon](:I => 2, :τ => 1), + # Vect[FibonacciAnyon](:I => 1, :τ => 2)', + # Vect[FibonacciAnyon](:I => 2, :τ => 2)', + # Vect[FibonacciAnyon](:I => 2, :τ => 3), + # Vect[FibonacciAnyon](:I => 2, :τ => 2), + # ), +) +eltypes = (Float64,) # no complex support yet + +@timedtestset "Mooncake - Index Manipulations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = precision(T) + rtol = precision(T) + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + + symmetricbraiding && @timedtestset "add_permute!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randindextuple(numind(A)) + C = randn!(permute(A, p)) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_permute!, C, A, p, α, β; atol, rtol, mode) + A = C + end + end + + @timedtestset "add_transpose!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randcircshift(numout(A), numin(A)) + C = randn!(transpose(A, p)) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, α, β; atol, rtol, mode) + A = C + end + end + + @timedtestset "add_braid!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + # repeat a couple times to get some distribution of arrows + for _ in 1:5 + p = randcircshift(numout(A), numin(A)) + levels = tuple(randperm(numind(A))) + C = randn!(transpose(A, p)) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, α, β; atol, rtol, mode) + A = C + end + end + + @timedtestset "flip_n_twist!" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = false), twist!, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = true), twist!, A, [1, 3]; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, twist!, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, twist!, A, [1, 3]; atol, rtol, mode) + + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = false), flip, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = true), flip, A, [1, 3]; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, flip, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, flip, A, [1, 3]; atol, rtol, mode) + end + + @timedtestset "insert and remove units" begin + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + + for insertunit in (insertleftunit, insertrightunit) + Mooncake.TestUtils.test_rule(rng, insertunit, A, Val(1); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, insertunit, A, Val(4); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, insertunit, A', Val(2); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false), insertunit, A, Val(1); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = true), insertunit, A, Val(2); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false, dual = true, conj = true), insertunit, A, Val(3); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false, dual = true, conj = true), insertunit, A', Val(3); atol, rtol, mode) + end + + for i in 1:4 + B = insertleftunit(A, i; dual = rand(Bool)) + Mooncake.TestUtils.test_rule(rng, removeunit, B, Val(i); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = false), removeunit, B, Val(i); atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; copy = true), removeunit, B, Val(i); atol, rtol, mode) + end + end +end diff --git a/test/mooncake/linalg.jl b/test/mooncake/linalg.jl new file mode 100644 index 000000000..426619549 --- /dev/null +++ b/test/mooncake/linalg.jl @@ -0,0 +1,80 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Mooncake +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +mode = Mooncake.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + Vect[Z2Irrep](0 => 2, 1 => 2), + ), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + # ( + # Vect[FibonacciAnyon](:I => 2, :τ => 1), + # Vect[FibonacciAnyon](:I => 1, :τ => 2)', + # Vect[FibonacciAnyon](:I => 2, :τ => 2)', + # Vect[FibonacciAnyon](:I => 2, :τ => 3), + # Vect[FibonacciAnyon](:I => 2, :τ => 2), + # ), +) +eltypes = (Float64,) # no complex support yet + +@timedtestset "Mooncake - LinearAlgebra: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = precision(T) + rtol = precision(T) + + C = randn(T, V[1] ⊗ V[2] ← V[5]) + A = randn(T, codomain(C) ← V[3] ⊗ V[4]) + B = randn(T, domain(A) ← domain(C)) + α = randn(T) + β = randn(T) + + Mooncake.TestUtils.test_rule(rng, mul!, C, A, B, α, β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, mul!, C, A, B; atol, rtol, mode, is_primitive = false) + + Mooncake.TestUtils.test_rule(rng, norm, C, 2; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, norm, C', 2; atol, rtol, mode) + + D1 = randn(T, V[1] ← V[1]) + D2 = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + D3 = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← V[1] ⊗ V[2] ⊗ V[3]) + + Mooncake.TestUtils.test_rule(rng, tr, D1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, tr, D2; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, tr, D3; atol, rtol, mode) + + Mooncake.TestUtils.test_rule(rng, inv, D1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, inv, D2; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, inv, D3; atol, rtol, mode) +end diff --git a/test/mooncake/planaroperations.jl b/test/mooncake/planaroperations.jl new file mode 100644 index 000000000..cbdc7ec76 --- /dev/null +++ b/test/mooncake/planaroperations.jl @@ -0,0 +1,128 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Mooncake +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup +using .TestSetup: _repartition + +mode = Mooncake.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + Vect[Z2Irrep](0 => 2, 1 => 2), + ), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + # ( + # Vect[FibonacciAnyon](:I => 2, :τ => 1), + # Vect[FibonacciAnyon](:I => 1, :τ => 2)', + # Vect[FibonacciAnyon](:I => 2, :τ => 2)', + # Vect[FibonacciAnyon](:I => 2, :τ => 3), + # Vect[FibonacciAnyon](:I => 2, :τ => 2), + # ), +) +eltypes = (Float64,) # no complex support yet + +@timedtestset "Mooncake - PlanarOperations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = precision(T) + rtol = precision(T) + + @timedtestset "planarcontract!" begin + for _ in 1:5 + d = 0 + local V1, V2, V3, k1, k2, k3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) + V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) + d > 1 && break + end + k′ = rand(0:(k1 + k2)) + pA = randcircshift(k′, k1 + k2 - k′, k1) + ipA = _repartition(invperm(linearize(pA)), k′) + k′ = rand(0:(k2 + k3)) + pB = randcircshift(k′, k2 + k3 - k′, k2) + ipB = _repartition(invperm(linearize(pB)), k′) + # TODO: primal value already is broken for this? + # pAB = randcircshift(k1, k3) + pAB = _repartition(tuple((1:(k1 + k3))...), k1) + + α = randn(T) + β = randn(T) + + A = randn(T, permute(V1 ← V2, ipA)) + B = randn(T, permute(V2 ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, false, B, pB, false, pAB, Val(false) + ) + ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.planarcontract!, C, A, pA, B, pB, pAB, α, β; + atol, rtol, mode, is_primitive = false + ) + end + end + + # TODO: currently broken + # @timedtestset "planartrace!" begin + # for _ in 1:5 + # k1 = rand(0:2) + # k2 = rand(0:1) + # V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + # V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) + # V3 = prod(x -> x ⊗ x', V2[1:k2]; init = one(V[1])) + # V4 = prod(x -> x ⊗ x', V2[(k2 + 1):end]; init = one(V[1])) + # + # k′ = rand(0:(k1 + 2k2)) + # (_p, _q) = randcircshift(k′, k1 + 2k2 - k′, k1) + # p = _repartition(_p, rand(0:k1)) + # q = (tuple(_q[1:2:end]...), tuple(_q[2:2:end]...)) + # ip = _repartition(invperm(linearize((_p, _q))), k′) + # A = randn(T, permute(prod(V1) ⊗ V3 ← V4, ip)) + # + # α = randn(T) + # β = randn(T) + # C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + # Mooncake.TestUtils.test_rule( + # rng, TensorKit.planartrace!, + # C, A, p, q, α, β, + # TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + # atol, rtol, mode + # ) + # end + # end +end diff --git a/test/mooncake/tensoroperations.jl b/test/mooncake/tensoroperations.jl new file mode 100644 index 000000000..43372a011 --- /dev/null +++ b/test/mooncake/tensoroperations.jl @@ -0,0 +1,121 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Mooncake +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +mode = Mooncake.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + Vect[Z2Irrep](0 => 2, 1 => 2), + ), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + # ( + # Vect[FibonacciAnyon](:I => 2, :τ => 1), + # Vect[FibonacciAnyon](:I => 1, :τ => 2)', + # Vect[FibonacciAnyon](:I => 2, :τ => 2)', + # Vect[FibonacciAnyon](:I => 2, :τ => 3), + # Vect[FibonacciAnyon](:I => 2, :τ => 2), + # ), +) +eltypes = (Float64,) # no complex support yet + +@timedtestset "Mooncake - TensorOperations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = precision(T) + rtol = precision(T) + symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding + + symmetricbraiding && @timedtestset "tensorcontract!" begin + for _ in 1:5 + d = 0 + local V1, V2, V3 + # retry a couple times to make sure there are at least some nonzero elements + for _ in 1:10 + k1 = rand(0:3) + k2 = rand(0:2) + k3 = rand(0:2) + V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1])) + V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1])) + V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1])) + d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3)) + d > 0 && break + end + ipA = randindextuple(length(V1) + length(V2)) + pA = _repartition(invperm(linearize(ipA)), length(V1)) + ipB = randindextuple(length(V2) + length(V3)) + pB = _repartition(invperm(linearize(ipB)), length(V2)) + pAB = randindextuple(length(V1) + length(V3)) + + α = randn(T) + β = randn(T) + V2_conj = prod(conj, V2; init = one(V[1])) + + A = randn(T, permute(V1 ← V2, ipA)) + B = randn(T, permute(V2 ← V3, ipB)) + C = randn!( + TensorOperations.tensoralloc_contract( + T, A, pA, false, B, pB, false, pAB, Val(false) + ) + ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.blas_contract!, + C, A, pA, B, pB, pAB, α, β, + TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + atol, rtol, mode + ) + end + end + + symmetricbraiding && @timedtestset "trace_permute!" begin + for _ in 1:5 + k1 = rand(0:2) + k2 = rand(1:2) + V1 = map(v -> rand(Bool) ? v' : v, rand(V, k1)) + V2 = map(v -> rand(Bool) ? v' : v, rand(V, k2)) + + (_p, _q) = randindextuple(k1 + 2 * k2, k1) + p = _repartition(_p, rand(0:k1)) + q = _repartition(_q, k2) + ip = _repartition(invperm(linearize((_p, _q))), rand(0:(k1 + 2 * k2))) + A = randn(T, permute(prod(V1) ⊗ prod(V2) ← prod(V2), ip)) + + α = randn(T) + β = randn(T) + C = randn!(TensorOperations.tensoralloc_add(T, A, p, false, Val(false))) + Mooncake.TestUtils.test_rule( + rng, TensorKit.trace_permute!, C, A, p, q, α, β, TensorOperations.DefaultBackend(); + atol, rtol, mode + ) + end + end +end diff --git a/test/mooncake/vectorinterface.jl b/test/mooncake/vectorinterface.jl new file mode 100644 index 000000000..131521c44 --- /dev/null +++ b/test/mooncake/vectorinterface.jl @@ -0,0 +1,75 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using Mooncake +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +mode = Mooncake.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + Vect[Z2Irrep](0 => 2, 1 => 2), + ), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + # ( + # Vect[FibonacciAnyon](:I => 2, :τ => 1), + # Vect[FibonacciAnyon](:I => 1, :τ => 2)', + # Vect[FibonacciAnyon](:I => 2, :τ => 2)', + # Vect[FibonacciAnyon](:I => 2, :τ => 3), + # Vect[FibonacciAnyon](:I => 2, :τ => 2), + # ), +) +eltypes = (Float64,) # no complex support yet + +@timedtestset "Mooncake - VectorInterface: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = precision(T) + rtol = precision(T) + + C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) + α = randn(T) + β = randn(T) + + Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol, mode) + + Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, mode, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, mode, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol, mode) + + Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, mode) +end diff --git a/test/runtests.jl b/test/runtests.jl index 3b0bfe8b0..8f58d7dc8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,7 +57,7 @@ istestfile(fn) = endswith(fn, ".jl") && !contains(fn, "setup") # somehow AD tests are unreasonably slow on Apple CI # and ChainRulesTestUtils doesn't like prereleases - if group == "autodiff" + if group == "chainrules" Sys.isapple() && get(ENV, "CI", "false") == "true" && continue isempty(VERSION.prerelease) || continue end diff --git a/test/setup.jl b/test/setup.jl index 5c8516eb9..3f6dde923 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -1,5 +1,6 @@ module TestSetup +export randindextuple, randcircshift, _repartition, trivtuple export smallset, randsector, hasfusiontensor, force_planar export random_fusion export sectorlist @@ -11,9 +12,46 @@ using Test: @test using TensorKit using TensorKit: ℙ, PlanarTrivial using Base.Iterators: take, product +using TupleTools Random.seed!(123456) +# IndexTuple utility +# ------------------ +function randindextuple(N::Int, k::Int = rand(0:N)) + @assert 0 ≤ k ≤ N + _p = randperm(N) + return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...)) +end +function randcircshift(N₁::Int, N₂::Int, k::Int = rand(0:(N₁ + N₂))) + N = N₁ + N₂ + @assert 0 ≤ k ≤ N + p = TupleTools.vcat(ntuple(identity, N₁), reverse(ntuple(identity, N₂) .+ N₁)) + n = rand(0:N) + _p = TupleTools.circshift(p, n) + return (tuple(_p[1:k]...), reverse(tuple(_p[(k + 1):end]...))) +end + +trivtuple(N) = ntuple(identity, N) + +Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || + throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + return TupleTools.getindices(p, trivtuple(N₁)), + TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁) +end +Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int) + return _repartition(linearize(p), N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} + return _repartition(p, N₁) +end +function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap) + return _repartition(p, TensorKit.numout(t)) +end + +# Sector utility +# -------------- smallset(::Type{I}) where {I <: Sector} = take(values(I), 5) function smallset(::Type{ProductSector{Tuple{I1, I2}}}) where {I1, I2} iter = product(smallset(I1), smallset(I2)) From b32a71c552f8da9c19343f910ef1764e1d7ddb61 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 22 Jan 2026 17:48:04 -0500 Subject: [PATCH 23/54] add missing imports --- test/setup.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/setup.jl b/test/setup.jl index 3f6dde923..9b3a51d01 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -11,6 +11,7 @@ using Random using Test: @test using TensorKit using TensorKit: ℙ, PlanarTrivial +using TensorOperations: IndexTuple, Index2Tuple using Base.Iterators: take, product using TupleTools From 3bb332ee92b34474a52254035294c26b24717e4c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 26 Jan 2026 10:29:03 -0500 Subject: [PATCH 24/54] remove the use of the internal `Mooncake._rdata` --- ext/TensorKitMooncakeExt/indexmanipulations.jl | 4 ++-- ext/TensorKitMooncakeExt/linalg.jl | 2 +- ext/TensorKitMooncakeExt/planaroperations.jl | 4 ++-- ext/TensorKitMooncakeExt/tensoroperations.jl | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index 8a97ac81c..450f391e0 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -53,7 +53,7 @@ for transform in (:permute, :transpose) Δαr = if isnothing(Ap) NoRData() else - Mooncake._rdata(inner(Ap, ΔC)) + inner(Ap, ΔC) end Δβr = pullback_dβ(C, ΔC, β) @@ -120,7 +120,7 @@ function Mooncake.rrule!!( Δαr = if isnothing(Ap) NoRData() else - Mooncake._rdata(inner(Ap, ΔC)) + inner(Ap, ΔC) end Δβr = pullback_dβ(C, ΔC, β) diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index 2a77792c9..3d5ac8610 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -31,7 +31,7 @@ function Mooncake.rrule!!( mul!(ΔB, A', ΔC, conj(α), One()) ΔAr = NoRData() ΔBr = NoRData() - Δαr = isnothing(AB) ? NoRData() : Mooncake._rdata(inner(AB, ΔC)) + Δαr = isnothing(AB) ? NoRData() : inner(AB, ΔC) Δβr = pullback_dβ(C, ΔC, β) ΔCr = pullback_dC!(ΔC, β) diff --git a/ext/TensorKitMooncakeExt/planaroperations.jl b/ext/TensorKitMooncakeExt/planaroperations.jl index df75d60fe..58d714d82 100644 --- a/ext/TensorKitMooncakeExt/planaroperations.jl +++ b/ext/TensorKitMooncakeExt/planaroperations.jl @@ -88,7 +88,7 @@ function planartrace_pullback_Δα( TensorKit.planartrace!(At, A, p, q, One(), Zero(), backend, allocator) Δα = inner(At, ΔC) TO.tensorfree!(At, allocator) - return Mooncake._rdata(Δα) + return Δα end function planartrace_pullback_Δβ(ΔC, C, β) @@ -96,5 +96,5 @@ function planartrace_pullback_Δβ(ΔC, C, β) Tdβ === NoRData && return NoRData() Δβ = inner(C, ΔC) - return Mooncake._rdata(Δβ) + return Δβ end diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 66c3f257a..30850bb8c 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -120,7 +120,7 @@ function blas_contract_pullback_Δα( AB = TO.tensorcontract(A, pA, false, B, pB, false, pAB, One(), backend, allocator) Δα = inner(AB, ΔC) - return Mooncake._rdata(Δα) + return Δα end # tensortrace! @@ -197,5 +197,5 @@ function trace_permute_pullback_Δα( # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α At = TO.tensortrace(A, p, q, false, One(), backend) Δα = inner(At, ΔC) - return Mooncake._rdata(Δα) + return Δα end From 079740afb4a4bb062dccb46bc0021385a47082bb Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 26 Jan 2026 10:41:28 -0500 Subject: [PATCH 25/54] add comments about `NoRData()` --- ext/TensorKitMooncakeExt/indexmanipulations.jl | 4 ++-- ext/TensorKitMooncakeExt/planaroperations.jl | 4 ++-- ext/TensorKitMooncakeExt/tensoroperations.jl | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index 450f391e0..fe871a52d 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -57,7 +57,7 @@ for transform in (:permute, :transpose) end Δβr = pullback_dβ(C, ΔC, β) - ΔCr = pullback_dC!(ΔC, β) + ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() return NoRData(), ΔCr, ΔAr, NoRData(), Δαr, Δβr, map(Returns(NoRData()), ba)... end @@ -124,7 +124,7 @@ function Mooncake.rrule!!( end Δβr = pullback_dβ(C, ΔC, β) - ΔCr = pullback_dC!(ΔC, β) + ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() return NoRData(), ΔCr, ΔAr, NoRData(), NoRData(), Δαr, Δβr, map(Returns(NoRData()), ba)... end diff --git a/ext/TensorKitMooncakeExt/planaroperations.jl b/ext/TensorKitMooncakeExt/planaroperations.jl index 58d714d82..5fe762cbb 100644 --- a/ext/TensorKitMooncakeExt/planaroperations.jl +++ b/ext/TensorKitMooncakeExt/planaroperations.jl @@ -34,10 +34,10 @@ function Mooncake.rrule!!( function planartrace_pullback(::NoRData) copy!(C, C_cache) - ΔAr = planartrace_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend, allocator) + ΔAr = planartrace_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend, allocator) # this typically returns NoRData() Δαr = planartrace_pullback_Δα(ΔC, A, p, q, α, backend, allocator) Δβr = planartrace_pullback_Δβ(ΔC, C, β) - ΔCr = planartrace_pullback_ΔC!(ΔC, β) + ΔCr = planartrace_pullback_ΔC!(ΔC, β) # this typically returns NoRData() return NoRData(), ΔCr, ΔAr, NoRData(), NoRData(), diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 30850bb8c..6c3f7442e 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -38,15 +38,15 @@ function Mooncake.rrule!!( ΔAr = blas_contract_pullback_ΔA!( ΔA, ΔC, A, pA, B, pB, pAB, α, backend, allocator - ) + ) # this typically returns NoRData() ΔBr = blas_contract_pullback_ΔB!( ΔB, ΔC, A, pA, B, pB, pAB, α, backend, allocator - ) + ) # this typically returns NoRData() Δαr = blas_contract_pullback_Δα( ΔC, A, pA, B, pB, pAB, α, backend, allocator ) Δβr = pullback_dβ(ΔC, C, β) - ΔCr = pullback_dC!(ΔC, β) + ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() return NoRData(), ΔCr, ΔAr, NoRData(), @@ -159,10 +159,10 @@ function Mooncake.rrule!!( function trace_permute_pullback(::NoRData) copy!(C, C_cache) - ΔAr = trace_permute_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend) + ΔAr = trace_permute_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend) # this typically returns NoRData() Δαr = trace_permute_pullback_Δα(ΔC, A, p, q, α, backend) Δβr = pullback_dβ(ΔC, C, β) - ΔCr = pullback_dC!(ΔC, β) + ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() return NoRData(), ΔCr, ΔAr, NoRData(), NoRData(), From 553ee2ba6df1358e4bb6b184cda6cf336ea7f16a Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 26 Jan 2026 10:42:14 -0500 Subject: [PATCH 26/54] add TODO --- ext/TensorKitMooncakeExt/planaroperations.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/ext/TensorKitMooncakeExt/planaroperations.jl b/ext/TensorKitMooncakeExt/planaroperations.jl index 5fe762cbb..9633dfad6 100644 --- a/ext/TensorKitMooncakeExt/planaroperations.jl +++ b/ext/TensorKitMooncakeExt/planaroperations.jl @@ -49,6 +49,7 @@ end planartrace_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) +# TODO: Fix planartrace pullback # This implementation is slightly more involved than its non-planar counterpart # this is because we lack a general `pAB` argument in `planarcontract`, and need # to keep things planar along the way. From 53c3c342738508cd713c39b68771dcf1c9007925 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 29 Jan 2026 09:45:04 -0500 Subject: [PATCH 27/54] correctly implement `_needs_tangent` --- ext/TensorKitMooncakeExt/utility.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index 261c1dcc2..83c6d8592 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -1,7 +1,6 @@ _needs_tangent(x) = _needs_tangent(typeof(x)) -function _needs_tangent(::Type{T}) where {T <: Number} - return Mooncake.rdata_type(Mooncake.tangent_type(T)) !== NoRData() -end +_needs_tangent(::Type{T}) where {T <: Number} = + Mooncake.rdata_type(Mooncake.tangent_type(T)) !== NoRData # IndexTuple utility # ------------------ From c4c8cb976acd0660057eef32201868fb4fa8a9dd Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 29 Jan 2026 11:27:45 -0500 Subject: [PATCH 28/54] update to Mooncake 0.5 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8b4980ecc..e382ac435 100644 --- a/Project.toml +++ b/Project.toml @@ -45,7 +45,7 @@ GPUArrays = "11.3.1" LRUCache = "1.0.2" LinearAlgebra = "1" MatrixAlgebraKit = "0.6.3" -Mooncake = "0.4.183" +Mooncake = "0.5" OhMyThreads = "0.8.0" Printf = "1" Random = "1" From d3afdfe93b4dc207a40a8cfb4ec8c19b2ae8fc96 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 29 Jan 2026 11:27:55 -0500 Subject: [PATCH 29/54] add TensorMap tangent type --- .../TensorKitMooncakeExt.jl | 5 +- ext/TensorKitMooncakeExt/tangent.jl | 198 ++++++++++++++++++ ext/TensorKitMooncakeExt/utility.jl | 1 + 3 files changed, 203 insertions(+), 1 deletion(-) diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index d3894c874..7067bb280 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -1,13 +1,16 @@ module TensorKitMooncakeExt using Mooncake -using Mooncake: @zero_derivative, @is_primitive, DefaultCtx, ReverseMode, NoFData, NoRData, CoDual, arrayify, primal +using Mooncake: @zero_derivative, @is_primitive, + DefaultCtx, MinimalCtx, ReverseMode, NoFData, NoRData, CoDual, Dual, + arrayify, primal, tangent using TensorKit import TensorKit as TK using VectorInterface using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize import TensorOperations as TO using TupleTools +using Random: AbstractRNG include("utility.jl") include("tangent.jl") diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index 9fa6e401a..12921faec 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -13,3 +13,201 @@ function Mooncake.arrayify(Aᴴ_ΔAᴴ::CoDual{<:TensorKit.AdjointTensorMap}) A, ΔA = arrayify(A_ΔA) return A', ΔA' end + +# Define the tangent type of a TensorMap to be TensorMap itself. +# This has a number of benefits, but also correctly alters the +# inner product when dealing with non-abelian symmetries. +# +# Note: this implementation is technically a little lazy, since we are +# assuming that the tangent type of the underlying storage is also given +# by that same type. This should in principle work out fine, and will only +# fail for types that would be self-referential, which we choose to not support +# for now. + +Mooncake.@foldable Mooncake.tangent_type(::Type{T}) where {T <: TensorMap} = T +Mooncake.@foldable Mooncake.tangent_type(::Type{T}, ::Type{NoRData}) where {T <: TensorMap} = T + +Mooncake.@foldable Mooncake.fdata_type(::Type{T}) where {T <: TensorMap} = T +Mooncake.@foldable Mooncake.rdata_type(::Type{T}) where {T <: TensorMap} = NoRData + +Mooncake.tangent(t::TensorMap, ::NoRData) = t +Mooncake.zero_tangent_internal(t::TensorMap, ::Mooncake.MaybeCache) = zerovector(t) + +Mooncake.randn_tangent_internal(rng::AbstractRNG, p::TensorMap, ::Mooncake.MaybeCache) = + randn!(rng, similar(p)) + +Mooncake.set_to_zero_internal!!(::Mooncake.SetToZeroCache, t::TensorMap) = zerovector!(t) +Mooncake.increment_internal!!(::Mooncake.IncCache, x::T, y::T) where {T <: TensorMap} = add!(x, y) + + +Mooncake._add_to_primal_internal(::Mooncake.MaybeCache, p::T, t::T, unsafe::Bool) where {T <: TensorMap} = add(p, t) +Mooncake.tangent_to_primal_internal!!(p::T, t::T, ::Mooncake.MaybeCache) where {T <: TensorMap} = copy!(p, t) +Mooncake.primal_to_tangent_internal!!(t::T, p::T, ::Mooncake.MaybeCache) where {T <: TensorMap} = copy!(t, p) + +Mooncake._dot_internal(::Mooncake.MaybeCache, t::T, s::T) where {T <: TensorMap} = Float64(real(inner(t, s))) +Mooncake._scale_internal(::Mooncake.MaybeCache, a::Float64, t::T) where {T <: TensorMap} = scale(t, a) + +function Mooncake.TestUtils.populate_address_map_internal( + m::Mooncake.AddressMap, primal::T, tangent::T + ) where {T <: TensorMap} + return Mooncake.TestUtils.populate_address_map_internal(m, primal.data, tangent.data) +end + +function Mooncake.__verify_fdata_value(::IdDict{Any, Nothing}, p::TensorMap, f::TensorMap) + space(p) == space(f) || + throw(Mooncake.InvalidFDataException(lazy"p has space $(space(p)) but f has size $(space(f))")) + return nothing +end + + +@is_primitive MinimalCtx Tuple{typeof(Mooncake.lgetfield), <:TensorMap, Val} + +# TODO: double-check if this has to include quantum dimensinos for non-abelian? +function Mooncake.frule!!( + ::Dual{typeof(Mooncake.lgetfield)}, t::Dual{<:TensorMap}, ::Dual{Val{FieldName}} + ) where {FieldName} + y = getfield(primal(t), FieldName) + + return if FieldName === 1 || FieldName === :data + dval = tangent(t).data + Dual(val, dval) + elseif FieldName === 2 || FieldName === :space + Dual(val, NoFData()), getfield_pullback + else + throw(ArgumentError(lazy"Invalid fieldname `$FieldName`")) + end +end + +function Mooncake.rrule!!( + ::CoDual{typeof(Mooncake.lgetfield)}, t::CoDual{<:TensorMap}, ::CoDual{Val{FieldName}} + ) where {FieldName} + val = getfield(primal(t), FieldName) + getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 3)) + + return if FieldName === 1 || FieldName === :data + dval = Mooncake.tangent(t).data + CoDual(val, dval), getfield_pullback + elseif FieldName === 2 || FieldName === :space + Mooncake.zero_fcodual(val), getfield_pullback + else + throw(ArgumentError(lazy"Invalid fieldname `$FieldName`")) + end +end + +@is_primitive MinimalCtx Tuple{typeof(getfield), <:TensorMap, Any, Vararg{Symbol}} + +Base.@constprop :aggressive function Mooncake.frule!!( + ::Dual{typeof(getfield)}, t::Dual{<:TensorMap}, name::Dual + ) + y = getfield(primal(t), primal(name)) + + return if primal(name) === 1 || primal(name) === :data + dval = tangent(t).data + Dual(val, dval) + elseif primal(name) === 2 || primal(name) === :space + Dual(val, NoFData()) + else + throw(ArgumentError(lazy"Invalid fieldname `$(primal(name))`")) + end +end + +Base.@constprop :aggressive function Mooncake.rrule!!( + ::CoDual{typeof(getfield)}, t::CoDual{<:TensorMap}, name::CoDual + ) + val = getfield(primal(t), primal(name)) + getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 3)) + + return if primal(name) === 1 || primal(name) === :data + dval = Mooncake.tangent(t).data + CoDual(val, dval), getfield_pullback + elseif primal(name) === 2 || primal(name) === :space + Mooncake.zero_fcodual(val), getfield_pullback + else + throw(ArgumentError(lazy"Invalid fieldname `$(primal(name))`")) + end +end + +Base.@constprop :aggressive function Mooncake.frule!!( + ::Dual{typeof(getfield)}, t::Dual{<:TensorMap}, name::Dual, order::Dual + ) + y = getfield(primal(t), primal(name), primal(order)) + + return if primal(name) === 1 || primal(name) === :data + dval = tangent(t).data + Dual(val, dval) + elseif primal(name) === 2 || primal(name) === :space + Dual(val, NoFData()) + else + throw(ArgumentError(lazy"Invalid fieldname `$(primal(name))`")) + end +end + +Base.@constprop :aggressive function Mooncake.rrule!!( + ::CoDual{typeof(getfield)}, t::CoDual{<:TensorMap}, name::CoDual, order::CoDual + ) + val = getfield(primal(t), primal(name), primal(order)) + getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 4)) + + return if primal(name) === 1 || primal(name) === :data + dval = Mooncake.tangent(t).data + CoDual(val, dval), getfield_pullback + elseif primal(name) === 2 || primal(name) === :space + Mooncake.zero_fcodual(val), getfield_pullback + else + throw(ArgumentError(lazy"Invalid fieldname `$(primal(name))`")) + end +end + + +@is_primitive MinimalCtx Tuple{typeof(Mooncake.lgetfield), <:TensorMap, Val, Val} + +# TODO: double-check if this has to include quantum dimensinos for non-abelian? +function Mooncake.frule!!( + ::Dual{typeof(Mooncake.lgetfield)}, t::Dual{<:TensorMap}, ::Dual{Val{FieldName}}, ::Dual{Val{Order}} + ) where {FieldName, Order} + y = getfield(primal(t), FieldName, Order) + + return if FieldName === 1 || FieldName === :data + dval = tangent(t).data + Dual(val, dval) + elseif FieldName === 2 || FieldName === :space + Dual(val, NoFData()) + else + throw(ArgumentError(lazy"Invalid fieldname `$FieldName`")) + end +end + +function Mooncake.rrule!!( + ::CoDual{typeof(Mooncake.lgetfield)}, t::CoDual{<:TensorMap}, ::CoDual{Val{FieldName}}, ::CoDual{Val{Order}} + ) where {FieldName, Order} + val = getfield(primal(t), FieldName, Order) + getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 4)) + + return if FieldName === 1 || FieldName === :data + dval = Mooncake.tangent(t).data + CoDual(val, dval), getfield_pullback + elseif FieldName === 2 || FieldName === :space + Mooncake.zero_fcodual(val), getfield_pullback + else + throw(ArgumentError(lazy"Invalid fieldname `$FieldName`")) + end +end + + +Mooncake.@zero_derivative Mooncake.MinimalCtx Tuple{typeof(Mooncake._new_), Type{TensorMap{T, S, N₁, N₂, A}}, UndefInitializer, TensorMapSpace{S, N₁, N₂}} where {T, S, N₁, N₂, A} +@is_primitive Mooncake.MinimalCtx Tuple{typeof(Mooncake._new_), Type{TensorMap{T, S, N₁, N₂, A}}, A, TensorMapSpace{S, N₁, N₂}} where {T, S, N₁, N₂, A} + +function Mooncake.frule!!( + ::Dual{typeof(Mooncake._new_)}, ::Dual{Type{TensorMap{T, S, N₁, N₂, A}}}, data::Dual{A}, space::Dual{TensorMapSpace{S, N₁, N₂}} + ) where {T, S, N₁, N₂, A} + t = Mooncake._new_(TensorMap{T, S, N₁, N₂, A}, primal(data), primal(space)) + dt = Mooncake._new_(TensorMap{T, S, N₁, N₂, A}, tangent(data), primal(space)) + return Dual(t, dt) +end + +function Mooncake.rrule!!( + ::CoDual{typeof(Mooncake._new_)}, ::CoDual{Type{TensorMap{T, S, N₁, N₂, A}}}, data::CoDual{A}, space::CoDual{TensorMapSpace{S, N₁, N₂}} + ) where {T, S, N₁, N₂, A} + return Mooncake.zero_fcodual(Mooncake._new_(TensorMap{T, S, N₁, N₂, A}, primal(data), primal(space))), + Returns(ntuple(Returns(NoRData()), 4)) +end diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index 83c6d8592..bfbca5264 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -27,6 +27,7 @@ end # A VectorSpace has no meaningful notion of a vector space (tangent space) Mooncake.tangent_type(::Type{<:VectorSpace}) = Mooncake.NoTangent +Mooncake.tangent_type(::Type{<:HomSpace}) = Mooncake.NoTangent @zero_derivative DefaultCtx Tuple{typeof(TensorKit.fusionblockstructure), Any} From 49633930258cfd20a2286ae3fee2cfe84709de86 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 29 Jan 2026 11:30:51 -0500 Subject: [PATCH 30/54] fix stupid tolerance mistake --- test/mooncake/indexmanipulations.jl | 4 ++-- test/mooncake/linalg.jl | 4 ++-- test/mooncake/planaroperations.jl | 4 ++-- test/mooncake/tensoroperations.jl | 4 ++-- test/mooncake/vectorinterface.jl | 4 ++-- test/setup.jl | 5 +++++ 6 files changed, 15 insertions(+), 10 deletions(-) diff --git a/test/mooncake/indexmanipulations.jl b/test/mooncake/indexmanipulations.jl index a2909c38f..5155cd46f 100644 --- a/test/mooncake/indexmanipulations.jl +++ b/test/mooncake/indexmanipulations.jl @@ -51,8 +51,8 @@ spacelist = ( eltypes = (Float64,) # no complex support yet @timedtestset "Mooncake - Index Manipulations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes - atol = precision(T) - rtol = precision(T) + atol = default_tol(T) + rtol = default_tol(T) symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding symmetricbraiding && @timedtestset "add_permute!" begin diff --git a/test/mooncake/linalg.jl b/test/mooncake/linalg.jl index 426619549..67d19a66b 100644 --- a/test/mooncake/linalg.jl +++ b/test/mooncake/linalg.jl @@ -51,8 +51,8 @@ spacelist = ( eltypes = (Float64,) # no complex support yet @timedtestset "Mooncake - LinearAlgebra: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes - atol = precision(T) - rtol = precision(T) + atol = default_tol(T) + rtol = default_tol(T) C = randn(T, V[1] ⊗ V[2] ← V[5]) A = randn(T, codomain(C) ← V[3] ⊗ V[4]) diff --git a/test/mooncake/planaroperations.jl b/test/mooncake/planaroperations.jl index cbdc7ec76..eb1a265d4 100644 --- a/test/mooncake/planaroperations.jl +++ b/test/mooncake/planaroperations.jl @@ -52,8 +52,8 @@ spacelist = ( eltypes = (Float64,) # no complex support yet @timedtestset "Mooncake - PlanarOperations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes - atol = precision(T) - rtol = precision(T) + atol = default_tol(T) + rtol = default_tol(T) @timedtestset "planarcontract!" begin for _ in 1:5 diff --git a/test/mooncake/tensoroperations.jl b/test/mooncake/tensoroperations.jl index 43372a011..e0ab89280 100644 --- a/test/mooncake/tensoroperations.jl +++ b/test/mooncake/tensoroperations.jl @@ -51,8 +51,8 @@ spacelist = ( eltypes = (Float64,) # no complex support yet @timedtestset "Mooncake - TensorOperations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes - atol = precision(T) - rtol = precision(T) + atol = default_tol(T) + rtol = default_tol(T) symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding symmetricbraiding && @timedtestset "tensorcontract!" begin diff --git a/test/mooncake/vectorinterface.jl b/test/mooncake/vectorinterface.jl index 131521c44..3e439b3fc 100644 --- a/test/mooncake/vectorinterface.jl +++ b/test/mooncake/vectorinterface.jl @@ -51,8 +51,8 @@ spacelist = ( eltypes = (Float64,) # no complex support yet @timedtestset "Mooncake - VectorInterface: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes - atol = precision(T) - rtol = precision(T) + atol = default_tol(T) + rtol = default_tol(T) C = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) A = randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) diff --git a/test/setup.jl b/test/setup.jl index 9b3a51d01..c99fbdad1 100644 --- a/test/setup.jl +++ b/test/setup.jl @@ -1,6 +1,7 @@ module TestSetup export randindextuple, randcircshift, _repartition, trivtuple +export default_tol export smallset, randsector, hasfusiontensor, force_planar export random_fusion export sectorlist @@ -51,6 +52,10 @@ function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap) return _repartition(p, TensorKit.numout(t)) end +# Float32 and finite differences don't mix well +default_tol(::Type{<:Union{Float32, Complex{Float32}}}) = 1.0e-2 +default_tol(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-5 + # Sector utility # -------------- smallset(::Type{I}) where {I <: Sector} = take(values(I), 5) From 1691db934db29d3b8ac540f5670d4fcd170ad36b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 29 Jan 2026 11:31:02 -0500 Subject: [PATCH 31/54] enable complex tests --- test/mooncake/indexmanipulations.jl | 16 ++++++++-------- test/mooncake/linalg.jl | 17 ++++++++--------- test/mooncake/planaroperations.jl | 16 ++++++++-------- test/mooncake/tensoroperations.jl | 16 ++++++++-------- test/mooncake/vectorinterface.jl | 16 ++++++++-------- 5 files changed, 40 insertions(+), 41 deletions(-) diff --git a/test/mooncake/indexmanipulations.jl b/test/mooncake/indexmanipulations.jl index 5155cd46f..945f4482b 100644 --- a/test/mooncake/indexmanipulations.jl +++ b/test/mooncake/indexmanipulations.jl @@ -40,15 +40,15 @@ spacelist = ( Vect[SU2Irrep](1 // 2 => 2), Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', ), - # ( - # Vect[FibonacciAnyon](:I => 2, :τ => 1), - # Vect[FibonacciAnyon](:I => 1, :τ => 2)', - # Vect[FibonacciAnyon](:I => 2, :τ => 2)', - # Vect[FibonacciAnyon](:I => 2, :τ => 3), - # Vect[FibonacciAnyon](:I => 2, :τ => 2), - # ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), ) -eltypes = (Float64,) # no complex support yet +eltypes = (Float64, ComplexF64) @timedtestset "Mooncake - Index Manipulations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes atol = default_tol(T) diff --git a/test/mooncake/linalg.jl b/test/mooncake/linalg.jl index 67d19a66b..ead21f7a1 100644 --- a/test/mooncake/linalg.jl +++ b/test/mooncake/linalg.jl @@ -1,6 +1,5 @@ using Test, TestExtras using TensorKit -using TensorOperations using Mooncake using Random @@ -40,15 +39,15 @@ spacelist = ( Vect[SU2Irrep](1 // 2 => 2), Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', ), - # ( - # Vect[FibonacciAnyon](:I => 2, :τ => 1), - # Vect[FibonacciAnyon](:I => 1, :τ => 2)', - # Vect[FibonacciAnyon](:I => 2, :τ => 2)', - # Vect[FibonacciAnyon](:I => 2, :τ => 3), - # Vect[FibonacciAnyon](:I => 2, :τ => 2), - # ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), ) -eltypes = (Float64,) # no complex support yet +eltypes = (Float64, ComplexF64) @timedtestset "Mooncake - LinearAlgebra: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes atol = default_tol(T) diff --git a/test/mooncake/planaroperations.jl b/test/mooncake/planaroperations.jl index eb1a265d4..dcc424b9a 100644 --- a/test/mooncake/planaroperations.jl +++ b/test/mooncake/planaroperations.jl @@ -41,15 +41,15 @@ spacelist = ( Vect[SU2Irrep](1 // 2 => 2), Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', ), - # ( - # Vect[FibonacciAnyon](:I => 2, :τ => 1), - # Vect[FibonacciAnyon](:I => 1, :τ => 2)', - # Vect[FibonacciAnyon](:I => 2, :τ => 2)', - # Vect[FibonacciAnyon](:I => 2, :τ => 3), - # Vect[FibonacciAnyon](:I => 2, :τ => 2), - # ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), ) -eltypes = (Float64,) # no complex support yet +eltypes = (Float64, ComplexF64) @timedtestset "Mooncake - PlanarOperations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes atol = default_tol(T) diff --git a/test/mooncake/tensoroperations.jl b/test/mooncake/tensoroperations.jl index e0ab89280..922ac227a 100644 --- a/test/mooncake/tensoroperations.jl +++ b/test/mooncake/tensoroperations.jl @@ -40,15 +40,15 @@ spacelist = ( Vect[SU2Irrep](1 // 2 => 2), Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', ), - # ( - # Vect[FibonacciAnyon](:I => 2, :τ => 1), - # Vect[FibonacciAnyon](:I => 1, :τ => 2)', - # Vect[FibonacciAnyon](:I => 2, :τ => 2)', - # Vect[FibonacciAnyon](:I => 2, :τ => 3), - # Vect[FibonacciAnyon](:I => 2, :τ => 2), - # ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), ) -eltypes = (Float64,) # no complex support yet +eltypes = (Float64, ComplexF64) @timedtestset "Mooncake - TensorOperations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes atol = default_tol(T) diff --git a/test/mooncake/vectorinterface.jl b/test/mooncake/vectorinterface.jl index 3e439b3fc..d43f7014d 100644 --- a/test/mooncake/vectorinterface.jl +++ b/test/mooncake/vectorinterface.jl @@ -40,15 +40,15 @@ spacelist = ( Vect[SU2Irrep](1 // 2 => 2), Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', ), - # ( - # Vect[FibonacciAnyon](:I => 2, :τ => 1), - # Vect[FibonacciAnyon](:I => 1, :τ => 2)', - # Vect[FibonacciAnyon](:I => 2, :τ => 2)', - # Vect[FibonacciAnyon](:I => 2, :τ => 3), - # Vect[FibonacciAnyon](:I => 2, :τ => 2), - # ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), ) -eltypes = (Float64,) # no complex support yet +eltypes = (Float64, ComplexF64) @timedtestset "Mooncake - VectorInterface: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes atol = default_tol(T) From b648a30de917d00eadefd4f698fe813a956a12ba Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 29 Jan 2026 11:31:10 -0500 Subject: [PATCH 32/54] add tangent type test --- Project.toml | 9 ++++--- test/mooncake/tangent.jl | 58 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 3 deletions(-) create mode 100644 test/mooncake/tangent.jl diff --git a/Project.toml b/Project.toml index e382ac435..7deccc56f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorKit" uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec" -authors = ["Jutho Haegeman, Lukas Devos"] version = "0.16.0" +authors = ["Jutho Haegeman, Lukas Devos"] [deps] LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" @@ -22,8 +22,8 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [extensions] TensorKitAdaptExt = "Adapt" @@ -34,6 +34,7 @@ TensorKitMooncakeExt = "Mooncake" [compat] Adapt = "4" +AllocCheck = "0.2.3" Aqua = "0.6, 0.7, 0.8" ArgParse = "1.2.0" CUDA = "5.9" @@ -64,6 +65,7 @@ julia = "1.10" [extras] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -72,6 +74,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" @@ -82,4 +85,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [targets] -test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"] +test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "JET"] diff --git a/test/mooncake/tangent.jl b/test/mooncake/tangent.jl new file mode 100644 index 000000000..5b001fc51 --- /dev/null +++ b/test/mooncake/tangent.jl @@ -0,0 +1,58 @@ +using Test, TestExtras +using TensorKit +using Mooncake +using Random +using JET, AllocCheck + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup +using .TestSetup: _repartition + +mode = Mooncake.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + Vect[Z2Irrep](0 => 2, 1 => 2), + ), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +# only run on Linux since allocation tests are broken on other versions +Sys.islinux() && @timedtestset "Mooncake - Tangent type: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) + Mooncake.TestUtils.test_data(rng, A) +end From ea8cf9edad46e2b1a7f4c899ef747a3e55b5c6bf Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 29 Jan 2026 14:45:42 -0500 Subject: [PATCH 33/54] correct arrayify --- ext/TensorKitMooncakeExt/tangent.jl | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index 12921faec..93adff7e7 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -1,12 +1,7 @@ -function Mooncake.arrayify(A_dA::CoDual{<:TensorMap}) - A = Mooncake.primal(A_dA) - dA_fw = Mooncake.tangent(A_dA) - data = dA_fw.data.data - dA = typeof(A)(data, A.space) - return A, dA -end +Mooncake.arrayify(A_dA::CoDual{<:TensorMap}) = arrayify(primal(A_dA), tangent(A_dA)) +Mooncake.arrayify(A::TensorMap, dA::TensorMap) = (A, dA) -function Mooncake.arrayify(Aᴴ_ΔAᴴ::CoDual{<:TensorKit.AdjointTensorMap}) +function Mooncake.arrayify(Aᴴ_ΔAᴴ::CoDual{<:TK.AdjointTensorMap}) Aᴴ = Mooncake.primal(Aᴴ_ΔAᴴ) ΔAᴴ = Mooncake.tangent(Aᴴ_ΔAᴴ) A_ΔA = CoDual(Aᴴ', ΔAᴴ.data.parent) From 0f6891b811d9892141685cedd04a1365d83b190e Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 29 Jan 2026 17:13:39 -0500 Subject: [PATCH 34/54] fix indexmanipulations --- .../indexmanipulations.jl | 47 +++++++++---------- ext/TensorKitMooncakeExt/tangent.jl | 47 ++++++++++++------- 2 files changed, 52 insertions(+), 42 deletions(-) diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index fe871a52d..28c142522 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -193,7 +193,7 @@ function Mooncake.rrule!!(::CoDual{typeof(flip)}, t_Δt::CoDual{<:AbstractTensor _, Δt_flipped = arrayify(t_flipped_Δt_flipped) function twist_pullback(::NoRData) - copy!(Δt, flip(Δt_flipped, inds; inv = !inv)) + add!(Δt, flip(Δt_flipped, inds; inv = !inv)) return ntuple(Returns(NoRData()), 3) end @@ -214,7 +214,7 @@ function Mooncake.rrule!!( _, Δt_flipped = arrayify(t_flipped_Δt_flipped) function twist_pullback(::NoRData) - copy!(Δt, flip(Δt_flipped, inds; inv = !inv)) + add!(Δt, flip(Δt_flipped, inds; inv = !inv)) return ntuple(Returns(NoRData()), 5) end @@ -237,11 +237,10 @@ for insertunit in (:insertleftunit, :insertrightunit) # sharing address spaces if tsrc isa TensorMap tsrc_cache = copy(tsrc) - tdst = $insertunit(tsrc, ival) - # note: this is somewhat of a hack that makes use of the fact that the tangent is - # encoded without any information about the space, which allows us to simply reuse - # the tangent exactly without having to modify the space information - tdst_Δtdst = CoDual(tdst, Mooncake.tangent(tsrc_Δtsrc)) + tdst_Δtdst = CoDual( + $insertunit(tsrc, ival), + $insertunit(Mooncake.tangent(tsrc_Δtsrc), ival) + ) else tsrc_cache = nothing tdst = $insertunit(tsrc, ival) @@ -253,7 +252,7 @@ for insertunit in (:insertleftunit, :insertrightunit) function $insertunit_pullback(::NoRData) if isnothing(tsrc_cache) for (c, b) in blocks(Δtdst) - copy!(block(Δtsrc, c), b) + add!(block(Δtsrc, c), b) end else copy!(tsrc, tsrc_cache) @@ -278,10 +277,10 @@ for insertunit in (:insertleftunit, :insertrightunit) if tsrc isa TensorMap && !get(kwargs, :copy, false) tsrc_cache = copy(tsrc) tdst = $insertunit(tsrc, ival; kwargs...) - # note: this is somewhat of a hack that makes use of the fact that the tangent is - # encoded without any information about the space, which allows us to simply reuse - # the tangent exactly without having to modify the space information - tdst_Δtdst = CoDual(tdst, Mooncake.tangent(tsrc_Δtsrc)) + tdst_Δtdst = CoDual( + $insertunit(tsrc, ival; kwargs...), + $insertunit(Mooncake.tangent(tsrc_Δtsrc), ival; kwargs...) + ) else tsrc_cache = nothing tdst = $insertunit(tsrc, ival; kwargs...) @@ -293,7 +292,7 @@ for insertunit in (:insertleftunit, :insertrightunit) function $insertunit_pullback(::NoRData) if isnothing(tsrc_cache) for (c, b) in blocks(Δtdst) - copy!(block(Δtsrc, c), b) + add!(block(Δtsrc, c), b) end else copy!(tsrc, tsrc_cache) @@ -320,11 +319,10 @@ function Mooncake.rrule!!(::CoDual{typeof(removeunit)}, tsrc_Δtsrc::CoDual{<:Ab # sharing address spaces if tsrc isa TensorMap tsrc_cache = copy(tsrc) - tdst = removeunit(tsrc, ival) - # note: this is somewhat of a hack that makes use of the fact that the tangent is - # encoded without any information about the space, which allows us to simply reuse - # the tangent exactly without having to modify the space information - tdst_Δtdst = CoDual(tdst, Mooncake.tangent(tsrc_Δtsrc)) + tdst_Δtdst = CoDual( + removeunit(tsrc, ival), + removeunit(Mooncake.tangent(tsrc_Δtsrc), ival) + ) else tsrc_cache = nothing tdst = removeunit(tsrc, ival) @@ -336,7 +334,7 @@ function Mooncake.rrule!!(::CoDual{typeof(removeunit)}, tsrc_Δtsrc::CoDual{<:Ab function removeunit_pullback(::NoRData) if isnothing(tsrc_cache) for (c, b) in blocks(Δtdst) - copy!(block(Δtsrc, c), b) + add!(block(Δtsrc, c), b) end else copy!(tsrc, tsrc_cache) @@ -360,11 +358,10 @@ function Mooncake.rrule!!( # sharing address spaces if tsrc isa TensorMap && !get(kwargs, :copy, false) tsrc_cache = copy(tsrc) - tdst = removeunit(tsrc, ival; kwargs...) - # note: this is somewhat of a hack that makes use of the fact that the tangent is - # encoded without any information about the space, which allows us to simply reuse - # the tangent exactly without having to modify the space information - tdst_Δtdst = CoDual(tdst, Mooncake.tangent(tsrc_Δtsrc)) + tdst_Δtdst = CoDual( + removeunit(tsrc, ival; kwargs...), + removeunit(Mooncake.tangent(tsrc_Δtsrc), ival) + ) else tsrc_cache = nothing tdst = removeunit(tsrc, ival; kwargs...) @@ -376,7 +373,7 @@ function Mooncake.rrule!!( function removeunit_pullback(::NoRData) if isnothing(tsrc_cache) for (c, b) in blocks(Δtdst) - copy!(block(Δtsrc, c), b) + add!(block(Δtsrc, c), b) end else copy!(tsrc, tsrc_cache) diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index 93adff7e7..65f0cc7b9 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -19,41 +19,54 @@ end # fail for types that would be self-referential, which we choose to not support # for now. -Mooncake.@foldable Mooncake.tangent_type(::Type{T}) where {T <: TensorMap} = T Mooncake.@foldable Mooncake.tangent_type(::Type{T}, ::Type{NoRData}) where {T <: TensorMap} = T +Mooncake.@foldable Mooncake.tangent_type(::Type{TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A} = + TK.tensormaptype(S, N₁, N₂, Mooncake.tangent_type(A)) -Mooncake.@foldable Mooncake.fdata_type(::Type{T}) where {T <: TensorMap} = T +Mooncake.@foldable Mooncake.fdata_type(::Type{T}) where {T <: TensorMap} = Mooncake.tangent_type(T) Mooncake.@foldable Mooncake.rdata_type(::Type{T}) where {T <: TensorMap} = NoRData Mooncake.tangent(t::TensorMap, ::NoRData) = t -Mooncake.zero_tangent_internal(t::TensorMap, ::Mooncake.MaybeCache) = zerovector(t) +Mooncake.zero_tangent_internal(t::TensorMap, c::Mooncake.MaybeCache) = + TensorMap(Mooncake.zero_tangent_internal(t.data, c), space(t)) -Mooncake.randn_tangent_internal(rng::AbstractRNG, p::TensorMap, ::Mooncake.MaybeCache) = - randn!(rng, similar(p)) +Mooncake.randn_tangent_internal(rng::AbstractRNG, p::TensorMap, c::Mooncake.MaybeCache) = + TensorMap(Mooncake.randn_tangent_internal(rng, p.data, c), space(p)) Mooncake.set_to_zero_internal!!(::Mooncake.SetToZeroCache, t::TensorMap) = zerovector!(t) -Mooncake.increment_internal!!(::Mooncake.IncCache, x::T, y::T) where {T <: TensorMap} = add!(x, y) - +function Mooncake.increment!!(x::TensorMap, y::TensorMap) + data = Mooncake.increment!!(x.data, y.data) + return x.data === data ? x : TensorMap(data, space(x)) +end +function Mooncake.increment_internal!!(c::Mooncake.IncCache, x::TensorMap, y::TensorMap) + data = Mooncake.increment_internal!!(c, x.data, y.data) + return x.data === data ? x : TensorMap(data, space(x)) +end -Mooncake._add_to_primal_internal(::Mooncake.MaybeCache, p::T, t::T, unsafe::Bool) where {T <: TensorMap} = add(p, t) -Mooncake.tangent_to_primal_internal!!(p::T, t::T, ::Mooncake.MaybeCache) where {T <: TensorMap} = copy!(p, t) +Mooncake._add_to_primal_internal(c::Mooncake.MaybeCache, p::TensorMap, t::TensorMap, unsafe::Bool) = + TensorMap(Mooncake._add_to_primal_internal(c, p.data, t.data, unsafe), space(p)) +function Mooncake.tangent_to_primal_internal!!(p::TensorMap, t::TensorMap, c::Mooncake.MaybeCache) + data = Mooncake.tangent_to_primal_internal!!(p.data, t.data, c) + data === p.data || copy!(p.data, data) + return p +end Mooncake.primal_to_tangent_internal!!(t::T, p::T, ::Mooncake.MaybeCache) where {T <: TensorMap} = copy!(t, p) -Mooncake._dot_internal(::Mooncake.MaybeCache, t::T, s::T) where {T <: TensorMap} = Float64(real(inner(t, s))) -Mooncake._scale_internal(::Mooncake.MaybeCache, a::Float64, t::T) where {T <: TensorMap} = scale(t, a) +Mooncake._dot_internal(::Mooncake.MaybeCache, t::TensorMap, s::TensorMap) = Float64(real(inner(t, s))) +Mooncake._scale_internal(::Mooncake.MaybeCache, a::Float64, t::TensorMap) = scale(t, a) -function Mooncake.TestUtils.populate_address_map_internal( - m::Mooncake.AddressMap, primal::T, tangent::T - ) where {T <: TensorMap} - return Mooncake.TestUtils.populate_address_map_internal(m, primal.data, tangent.data) -end +Mooncake.TestUtils.populate_address_map_internal(m::Mooncake.TestUtils.AddressMap, primal::TensorMap, tangent::TensorMap) = + Mooncake.populate_address_map_internal(m, primal.data, tangent.data) +@inline Mooncake.TestUtils.__get_data_field(t::TensorMap, n) = getfield(t, n) function Mooncake.__verify_fdata_value(::IdDict{Any, Nothing}, p::TensorMap, f::TensorMap) space(p) == space(f) || throw(Mooncake.InvalidFDataException(lazy"p has space $(space(p)) but f has size $(space(f))")) return nothing end - +function Mooncake.__verify_fdata_value(c::IdDict{Any, Nothing}, p::TensorMap, t::TensorMap) + return Mooncake.__verify_fdata_value(c, p.data, t.data) +end @is_primitive MinimalCtx Tuple{typeof(Mooncake.lgetfield), <:TensorMap, Val} From 96bca518bdf4776199e256697fcf6c7feae277f4 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 29 Jan 2026 17:20:13 -0500 Subject: [PATCH 35/54] bump versions --- Project.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7deccc56f..f73ca4768 100644 --- a/Project.toml +++ b/Project.toml @@ -45,7 +45,7 @@ FiniteDifferences = "0.12" GPUArrays = "11.3.1" LRUCache = "1.0.2" LinearAlgebra = "1" -MatrixAlgebraKit = "0.6.3" +MatrixAlgebraKit = "0.6.4" Mooncake = "0.5" OhMyThreads = "0.8.0" Printf = "1" @@ -86,3 +86,6 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [targets] test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "JET"] + +[sources] +MatrixAlgebraKit = {url = "https://github.com/QuantumKitHub/MatrixAlgebraKit.jl", rev = "v0.6.4"} From aca99b2e9d514dae0b03e601b2fd72a2b58d684a Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 29 Jan 2026 17:43:42 -0500 Subject: [PATCH 36/54] deal with more complex sector shenanigans --- ext/TensorKitMooncakeExt/indexmanipulations.jl | 14 ++++++++------ src/tensors/indexmanipulations.jl | 2 ++ test/mooncake/indexmanipulations.jl | 11 +++++++---- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index 28c142522..b5938bc08 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -192,12 +192,13 @@ function Mooncake.rrule!!(::CoDual{typeof(flip)}, t_Δt::CoDual{<:AbstractTensor t_flipped_Δt_flipped = Mooncake.zero_fcodual(t_flipped) _, Δt_flipped = arrayify(t_flipped_Δt_flipped) - function twist_pullback(::NoRData) - add!(Δt, flip(Δt_flipped, inds; inv = !inv)) + function flip_pullback(::NoRData) + Δt_flipflipped = flip(Δt_flipped, inds; inv = !inv) + add!(Δt, scalartype(Δt) <: Real ? real(Δt_flipflipped) : Δt_flipflipped) return ntuple(Returns(NoRData()), 3) end - return t_flipped_Δt_flipped, twist_pullback + return t_flipped_Δt_flipped, flip_pullback end function Mooncake.rrule!!( ::CoDual{typeof(Core.kwcall)}, kwargs_Δkwargs::CoDual{@NamedTuple{inv::Bool}}, ::CoDual{typeof(flip)}, @@ -213,12 +214,13 @@ function Mooncake.rrule!!( t_flipped_Δt_flipped = Mooncake.zero_fcodual(t_flipped) _, Δt_flipped = arrayify(t_flipped_Δt_flipped) - function twist_pullback(::NoRData) - add!(Δt, flip(Δt_flipped, inds; inv = !inv)) + function flip_pullback(::NoRData) + Δt_flipflipped = flip(Δt_flipped, inds; inv = !inv) + add!(Δt, scalartype(Δt) <: Real ? real(Δt_flipflipped) : Δt_flipflipped) return ntuple(Returns(NoRData()), 5) end - return t_flipped_Δt_flipped, twist_pullback + return t_flipped_Δt_flipped, flip_pullback end for insertunit in (:insertleftunit, :insertrightunit) diff --git a/src/tensors/indexmanipulations.jl b/src/tensors/indexmanipulations.jl index 906ad6379..fa4699988 100644 --- a/src/tensors/indexmanipulations.jl +++ b/src/tensors/indexmanipulations.jl @@ -292,6 +292,8 @@ function twist!(t::AbstractTensorMap, inds; inv::Bool = false) msg = "Can't twist indices $inds of a tensor with only $(numind(t)) indices." throw(ArgumentError(msg)) end + (scalartype(t) <: Real && !(sectorscalartype(sectortype(t)) <: Real)) && + throw(ArgumentError("Can't in-place twist a real tensor with complex sector type")) has_shared_twist(t, inds) && return t (scalartype(t) <: Real && !(sectorscalartype(sectortype(t)) <: Real)) && diff --git a/test/mooncake/indexmanipulations.jl b/test/mooncake/indexmanipulations.jl index 945f4482b..614439b23 100644 --- a/test/mooncake/indexmanipulations.jl +++ b/test/mooncake/indexmanipulations.jl @@ -100,10 +100,13 @@ eltypes = (Float64, ComplexF64) @timedtestset "flip_n_twist!" begin A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) - Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = false), twist!, A, 1; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = true), twist!, A, [1, 3]; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, twist!, A, 1; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, twist!, A, [1, 3]; atol, rtol, mode) + + if !(T <: Real && !(sectorscalartype(sectortype(A)) <: Real)) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = false), twist!, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = true), twist!, A, [1, 3]; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, twist!, A, 1; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, twist!, A, [1, 3]; atol, rtol, mode) + end Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = false), flip, A, 1; atol, rtol, mode) Mooncake.TestUtils.test_rule(rng, Core.kwcall, (; inv = true), flip, A, [1, 3]; atol, rtol, mode) From 8c849537923e5f390e6f7b1de276a4ff7e7dee73 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 29 Jan 2026 17:57:27 -0500 Subject: [PATCH 37/54] properly accumulate --- ext/TensorKitMooncakeExt/tensoroperations.jl | 7 +++---- test/mooncake/tensoroperations.jl | 7 +++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 6c3f7442e..7c2dd1085 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -79,7 +79,7 @@ function blas_contract_pullback_ΔA!( ΔC, pΔC, false, tB, reverse(pB), true, ipA, - conj(α), Zero(), + conj(α), One(), backend, allocator ) @@ -106,7 +106,7 @@ function blas_contract_pullback_ΔB!( tA, reverse(pA), true, ΔC, pΔC, false, ipB, - conj(α), Zero(), backend, allocator + conj(α), One(), backend, allocator ) return NoRData() @@ -115,8 +115,7 @@ end function blas_contract_pullback_Δα( ΔC, A, pA, B, pB, pAB, α, backend, allocator ) - Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) - Tdα === NoRData && return NoRData() + _needs_tangent(α) || return NoRData() AB = TO.tensorcontract(A, pA, false, B, pB, false, pAB, One(), backend, allocator) Δα = inner(AB, ΔC) diff --git a/test/mooncake/tensoroperations.jl b/test/mooncake/tensoroperations.jl index 922ac227a..3a4a1dd4a 100644 --- a/test/mooncake/tensoroperations.jl +++ b/test/mooncake/tensoroperations.jl @@ -1,6 +1,7 @@ using Test, TestExtras using TensorKit using TensorOperations +using VectorInterface: One, Zero using Mooncake using Random @@ -87,6 +88,12 @@ eltypes = (Float64, ComplexF64) T, A, pA, false, B, pB, false, pAB, Val(false) ) ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.blas_contract!, + C, A, pA, B, pB, pAB, One(), Zero(), + TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + atol, rtol, mode + ) Mooncake.TestUtils.test_rule( rng, TensorKit.blas_contract!, C, A, pA, B, pB, pAB, α, β, From 15bf3326174cee21681c89ed39433698f7043f99 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 30 Jan 2026 09:48:58 -0500 Subject: [PATCH 38/54] nicer _needs_tangent --- ext/TensorKitMooncakeExt/tensoroperations.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 7c2dd1085..8dac1e4d8 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -189,8 +189,7 @@ end function trace_permute_pullback_Δα( ΔC, A, p, q, α, backend ) - Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) - Tdα === NoRData && return NoRData() + _needs_tangent(α) || return NoRData() # TODO: this result might be easier to compute as: # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α From 07f74b083a351bf029d2a90241d77e4a16b806a1 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 30 Jan 2026 09:49:15 -0500 Subject: [PATCH 39/54] remove source --- Project.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/Project.toml b/Project.toml index f73ca4768..07a2b8a22 100644 --- a/Project.toml +++ b/Project.toml @@ -86,6 +86,3 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [targets] test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "JET"] - -[sources] -MatrixAlgebraKit = {url = "https://github.com/QuantumKitHub/MatrixAlgebraKit.jl", rev = "v0.6.4"} From 497c8f62ecfe6ba1c55fc597131f9de0a8922939 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 30 Jan 2026 13:16:18 -0500 Subject: [PATCH 40/54] fix TensorOperations --- ext/TensorKitMooncakeExt/linalg.jl | 2 +- ext/TensorKitMooncakeExt/tensoroperations.jl | 2 +- ext/TensorKitMooncakeExt/utility.jl | 9 +++++++++ test/mooncake/tensoroperations.jl | 6 ++++++ 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index 3d5ac8610..ff49e68e3 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -1,7 +1,7 @@ # Shared # ------ pullback_dC!(ΔC, β) = (scale!(ΔC, conj(β)); return NoRData()) -pullback_dβ(C, ΔC, β) = _needs_tangent(β) ? inner(C, ΔC) : NoRData() +pullback_dβ(C, ΔC, β) = _needs_tangent(β) ? project_scalar(β, inner(ΔC, C)) : NoRData() @is_primitive DefaultCtx ReverseMode Tuple{typeof(mul!), AbstractTensorMap, AbstractTensorMap, AbstractTensorMap, Number, Number} diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl index 8dac1e4d8..c4468ef65 100644 --- a/ext/TensorKitMooncakeExt/tensoroperations.jl +++ b/ext/TensorKitMooncakeExt/tensoroperations.jl @@ -119,7 +119,7 @@ function blas_contract_pullback_Δα( AB = TO.tensorcontract(A, pA, false, B, pB, false, pAB, One(), backend, allocator) Δα = inner(AB, ΔC) - return Δα + return project_scalar(α, Δα) end # tensortrace! diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index bfbca5264..3f50bffa0 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -2,6 +2,15 @@ _needs_tangent(x) = _needs_tangent(typeof(x)) _needs_tangent(::Type{T}) where {T <: Number} = Mooncake.rdata_type(Mooncake.tangent_type(T)) !== NoRData +""" + project_scalar(x::Number, dx::Number) + +Project a computed tangent `dx` onto the correct tangent type for `x`. +For example, we might compute a complex `dx` but only require the real part. +""" +project_scalar(x::Number, dx::Number) = oftype(x, dx) +project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) + # IndexTuple utility # ------------------ trivtuple(N) = ntuple(identity, N) diff --git a/test/mooncake/tensoroperations.jl b/test/mooncake/tensoroperations.jl index 3a4a1dd4a..6802e4a73 100644 --- a/test/mooncake/tensoroperations.jl +++ b/test/mooncake/tensoroperations.jl @@ -100,6 +100,12 @@ eltypes = (Float64, ComplexF64) TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); atol, rtol, mode ) + T <: Complex && Mooncake.TestUtils.test_rule( + rng, TensorKit.blas_contract!, + C, A, pA, B, pB, pAB, real(α), real(β), + TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + atol, rtol, mode + ) end end From 185b2e6f103211a682f96f0d02564f386219ba27 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 30 Jan 2026 13:16:22 -0500 Subject: [PATCH 41/54] remove duplicate method --- ext/TensorKitMooncakeExt/tangent.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index 65f0cc7b9..fb859b471 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -59,12 +59,9 @@ Mooncake.TestUtils.populate_address_map_internal(m::Mooncake.TestUtils.AddressMa Mooncake.populate_address_map_internal(m, primal.data, tangent.data) @inline Mooncake.TestUtils.__get_data_field(t::TensorMap, n) = getfield(t, n) -function Mooncake.__verify_fdata_value(::IdDict{Any, Nothing}, p::TensorMap, f::TensorMap) - space(p) == space(f) || - throw(Mooncake.InvalidFDataException(lazy"p has space $(space(p)) but f has size $(space(f))")) - return nothing -end function Mooncake.__verify_fdata_value(c::IdDict{Any, Nothing}, p::TensorMap, t::TensorMap) + space(p) == space(t) || + throw(Mooncake.InvalidFDataException(lazy"p has space $(space(p)) but t has size $(space(t))")) return Mooncake.__verify_fdata_value(c, p.data, t.data) end From b5d7ab844760afb7e42cace421f106f82cb3c6bf Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 30 Jan 2026 14:51:16 -0500 Subject: [PATCH 42/54] fix arg order --- ext/TensorKitMooncakeExt/indexmanipulations.jl | 2 +- ext/TensorKitMooncakeExt/linalg.jl | 4 ++-- test/mooncake/planaroperations.jl | 4 ++++ 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index b5938bc08..76f2c126b 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -56,7 +56,7 @@ for transform in (:permute, :transpose) inner(Ap, ΔC) end - Δβr = pullback_dβ(C, ΔC, β) + Δβr = pullback_dβ(ΔC, C, β) ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() return NoRData(), ΔCr, ΔAr, NoRData(), Δαr, Δβr, map(Returns(NoRData()), ba)... diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index ff49e68e3..8f5306ac4 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -1,7 +1,7 @@ # Shared # ------ pullback_dC!(ΔC, β) = (scale!(ΔC, conj(β)); return NoRData()) -pullback_dβ(C, ΔC, β) = _needs_tangent(β) ? project_scalar(β, inner(ΔC, C)) : NoRData() +pullback_dβ(ΔC, C, β) = _needs_tangent(β) ? project_scalar(β, inner(C, ΔC)) : NoRData() @is_primitive DefaultCtx ReverseMode Tuple{typeof(mul!), AbstractTensorMap, AbstractTensorMap, AbstractTensorMap, Number, Number} @@ -32,7 +32,7 @@ function Mooncake.rrule!!( ΔAr = NoRData() ΔBr = NoRData() Δαr = isnothing(AB) ? NoRData() : inner(AB, ΔC) - Δβr = pullback_dβ(C, ΔC, β) + Δβr = pullback_dβ(ΔC, C, β) ΔCr = pullback_dC!(ΔC, β) return NoRData(), ΔCr, ΔAr, ΔBr, Δαr, Δβr diff --git a/test/mooncake/planaroperations.jl b/test/mooncake/planaroperations.jl index dcc424b9a..98a9afe22 100644 --- a/test/mooncake/planaroperations.jl +++ b/test/mooncake/planaroperations.jl @@ -90,6 +90,10 @@ eltypes = (Float64, ComplexF64) T, A, pA, false, B, pB, false, pAB, Val(false) ) ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.planarcontract!, C, A, pA, B, pB, pAB, One(), Zero(); + atol, rtol, mode, is_primitive = false + ) Mooncake.TestUtils.test_rule( rng, TensorKit.planarcontract!, C, A, pA, B, pB, pAB, α, β; atol, rtol, mode, is_primitive = false From b5c4a6c0c030e3223b59a9f4b0eb2e00fceab589 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Fri, 30 Jan 2026 14:51:24 -0500 Subject: [PATCH 43/54] add missing ChainRules import --- ext/TensorKitMooncakeExt/tangent.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index fb859b471..8178bb64e 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -65,6 +65,8 @@ function Mooncake.__verify_fdata_value(c::IdDict{Any, Nothing}, p::TensorMap, t: return Mooncake.__verify_fdata_value(c, p.data, t.data) end +Mooncake.to_cr_tangent(x::TensorMap) = x + @is_primitive MinimalCtx Tuple{typeof(Mooncake.lgetfield), <:TensorMap, Val} # TODO: double-check if this has to include quantum dimensinos for non-abelian? From d8d51442c0337bd7d013b4fc30cb7f4fb979a994 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 31 Jan 2026 16:36:56 -0500 Subject: [PATCH 44/54] add JET compat --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 07a2b8a22..3881d9316 100644 --- a/Project.toml +++ b/Project.toml @@ -43,6 +43,7 @@ ChainRulesTestUtils = "1" Combinatorics = "1" FiniteDifferences = "0.12" GPUArrays = "11.3.1" +JET = "0.9, 0.10, 0.11" LRUCache = "1.0.2" LinearAlgebra = "1" MatrixAlgebraKit = "0.6.4" From 89df24d757b207eb994325ed138506822490ef9e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 1 Feb 2026 14:41:07 -0500 Subject: [PATCH 45/54] some cleanup --- .../indexmanipulations.jl | 16 ++------------- ext/TensorKitMooncakeExt/planaroperations.jl | 20 +++++-------------- test/mooncake/planaroperations.jl | 1 + 3 files changed, 8 insertions(+), 29 deletions(-) diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index 76f2c126b..d8b737d56 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -49,13 +49,7 @@ for transform in (:permute, :transpose) TK.$add_transform!(ΔA, ΔC, pΔA, conj(α), One(), ba...) ΔAr = NoRData() - # Δα - Δαr = if isnothing(Ap) - NoRData() - else - inner(Ap, ΔC) - end - + Δαr = isnothing(Ap) ? NoRData() : project_scalar(α, inner(Ap, ΔC)) Δβr = pullback_dβ(ΔC, C, β) ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() @@ -116,13 +110,7 @@ function Mooncake.rrule!!( TK.add_braid!(ΔA, ΔC, pΔA, ilevels, conj(α), One(), ba...) ΔAr = NoRData() - # Δα - Δαr = if isnothing(Ap) - NoRData() - else - inner(Ap, ΔC) - end - + Δαr = isnothing(Ap) ? NoRData() : project_scalar(α, inner(Ap, ΔC)) Δβr = pullback_dβ(C, ΔC, β) ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() diff --git a/ext/TensorKitMooncakeExt/planaroperations.jl b/ext/TensorKitMooncakeExt/planaroperations.jl index 9633dfad6..b0c929c65 100644 --- a/ext/TensorKitMooncakeExt/planaroperations.jl +++ b/ext/TensorKitMooncakeExt/planaroperations.jl @@ -36,8 +36,8 @@ function Mooncake.rrule!!( ΔAr = planartrace_pullback_ΔA!(ΔA, ΔC, A, p, q, α, backend, allocator) # this typically returns NoRData() Δαr = planartrace_pullback_Δα(ΔC, A, p, q, α, backend, allocator) - Δβr = planartrace_pullback_Δβ(ΔC, C, β) - ΔCr = planartrace_pullback_ΔC!(ΔC, β) # this typically returns NoRData() + Δβr = pullback_dβ(ΔC, C, β) + ΔCr = pullback_dC!(ΔC, β) # this typically returns NoRData() return NoRData(), ΔCr, ΔAr, NoRData(), NoRData(), @@ -47,15 +47,13 @@ function Mooncake.rrule!!( return C_ΔC, planartrace_pullback end -planartrace_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData()) - # TODO: Fix planartrace pullback # This implementation is slightly more involved than its non-planar counterpart # this is because we lack a general `pAB` argument in `planarcontract`, and need # to keep things planar along the way. # In particular, we can't simply tensor product with multiple identities in one go # if they aren't "contiguous", e.g. p = ((1, 4, 5), ()), q = ((2, 6), (3, 7)) -function planartrace_pullback_ΔA!( +function planartrace_pullback_dA!( ΔA, ΔC, A, p, q, α, backend, allocator ) if length(q[1]) == 0 @@ -77,7 +75,7 @@ function planartrace_pullback_ΔA!( error("The reverse rule for `planartrace` is not yet implemented") end -function planartrace_pullback_Δα( +function planartrace_pullback_dα( ΔC, A, p, q, α, backend, allocator ) Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α))) @@ -87,15 +85,7 @@ function planartrace_pullback_Δα( # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α At = TO.tensoralloc_add(scalartype(A), A, p, false, Val(true), allocator) TensorKit.planartrace!(At, A, p, q, One(), Zero(), backend, allocator) - Δα = inner(At, ΔC) + Δα = project_scalar(α, inner(At, ΔC)) TO.tensorfree!(At, allocator) return Δα end - -function planartrace_pullback_Δβ(ΔC, C, β) - Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β))) - Tdβ === NoRData && return NoRData() - - Δβ = inner(C, ΔC) - return Δβ -end diff --git a/test/mooncake/planaroperations.jl b/test/mooncake/planaroperations.jl index 98a9afe22..dc46c4c31 100644 --- a/test/mooncake/planaroperations.jl +++ b/test/mooncake/planaroperations.jl @@ -1,6 +1,7 @@ using Test, TestExtras using TensorKit using TensorOperations +using VectorInterface: Zero, One using Mooncake using Random From c677596e185b62f37a2262581766e539e525747c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 1 Feb 2026 14:41:18 -0500 Subject: [PATCH 46/54] more handling of scalartypes --- ext/TensorKitMooncakeExt/indexmanipulations.jl | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index d8b737d56..7e7e6049e 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -46,7 +46,14 @@ for transform in (:permute, :transpose) # ΔA ip = invperm(linearize(p)) pΔA = _repartition(ip, A) - TK.$add_transform!(ΔA, ΔC, pΔA, conj(α), One(), ba...) + if scalartype(ΔA) <: Real && (!(scalartype(ΔC) <: Real) || !(scalartype(α) <: Real)) + TC = VectorInterface.promote_scale(ΔC, α) + ΔAc = TO.tensoralloc_add(TC, ΔC, pΔA, false, Val(false)) + TK.$add_transform!(ΔAc, ΔC, pΔA, conj(α), Zero(), ba...) + add!(ΔA, real(ΔAc)) + else + TK.$add_transform!(ΔA, ΔC, pΔA, conj(α), One(), ba...) + end ΔAr = NoRData() Δαr = isnothing(Ap) ? NoRData() : project_scalar(α, inner(Ap, ΔC)) @@ -107,7 +114,14 @@ function Mooncake.rrule!!( ip = invperm(linearize(p)) pΔA = _repartition(ip, A) ilevels = TupleTools.permute(levels, linearize(p)) - TK.add_braid!(ΔA, ΔC, pΔA, ilevels, conj(α), One(), ba...) + if scalartype(ΔA) <: Real && (!(scalartype(ΔC) <: Real) || !(scalartype(α) <: Real)) + TC = VectorInterface.promote_scale(ΔC, α) + ΔAc = TO.tensoralloc_add(TC, ΔC, pΔA, false, Val(false)) + TK.add_braid!(ΔAc, ΔC, pΔA, ilevels, conj(α), Zero(), ba...) + add!(ΔA, real(ΔAc)) + else + TK.add_braid!(ΔA, ΔC, pΔA, ilevels, conj(α), One(), ba...) + end ΔAr = NoRData() Δαr = isnothing(Ap) ? NoRData() : project_scalar(α, inner(Ap, ΔC)) From 7ecf6f6ff8ba57f6832c3d70f27f6d81826176c3 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 1 Feb 2026 14:51:58 -0500 Subject: [PATCH 47/54] more testing --- test/mooncake/indexmanipulations.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/mooncake/indexmanipulations.jl b/test/mooncake/indexmanipulations.jl index 614439b23..c22b00b26 100644 --- a/test/mooncake/indexmanipulations.jl +++ b/test/mooncake/indexmanipulations.jl @@ -1,6 +1,7 @@ using Test, TestExtras using TensorKit using TensorOperations +using VectorInterface: Zero, One using Mooncake using Random @@ -78,7 +79,13 @@ eltypes = (Float64, ComplexF64) for _ in 1:5 p = randcircshift(numout(A), numin(A)) C = randn!(transpose(A, p)) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, One(), Zero(); atol, rtol, mode) Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, α, β; atol, rtol, mode) + if !(T <: Real) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, real(A), p, α, β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, real(α), β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, real(A), p, real(α), β; atol, rtol, mode) + end A = C end end From d5ad4aff3cd84e8fc59a27d7714c5c06414cf671 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 1 Feb 2026 20:31:17 -0500 Subject: [PATCH 48/54] add specialization for `MAK.zero!` --- src/factorizations/matrixalgebrakit.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/factorizations/matrixalgebrakit.jl b/src/factorizations/matrixalgebrakit.jl index 90776892c..b44f20653 100644 --- a/src/factorizations/matrixalgebrakit.jl +++ b/src/factorizations/matrixalgebrakit.jl @@ -60,6 +60,8 @@ for f! in ( end end +MAK.zero!(t::AbstractTensorMap) = zerovector!(t) + # Singular value decomposition # ---------------------------- function MAK.initialize_output(::typeof(svd_full!), t::AbstractTensorMap, ::AbstractAlgorithm) From 3a97abc29c1cc78a6a87df31fa8b1724d686d535 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 1 Feb 2026 20:31:49 -0500 Subject: [PATCH 49/54] add tests on factorizations --- test/mooncake/factorizations.jl | 218 ++++++++++++++++++++++++++++++++ 1 file changed, 218 insertions(+) create mode 100644 test/mooncake/factorizations.jl diff --git a/test/mooncake/factorizations.jl b/test/mooncake/factorizations.jl new file mode 100644 index 000000000..9380a9a1c --- /dev/null +++ b/test/mooncake/factorizations.jl @@ -0,0 +1,218 @@ +using Test, TestExtras +using TensorKit +using TensorOperations +using VectorInterface: Zero, One +using MatrixAlgebraKit +using Mooncake +using Random + +@isdefined(TestSetup) || include("../setup.jl") +using .TestSetup + +mode = Mooncake.ReverseMode +rng = Random.default_rng() + +spacelist = ( + (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), + ( + Vect[Z2Irrep](0 => 1, 1 => 1), + Vect[Z2Irrep](0 => 1, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 2)', + Vect[Z2Irrep](0 => 2, 1 => 3), + Vect[Z2Irrep](0 => 2, 1 => 2), + ), + ( + Vect[FermionParity](0 => 1, 1 => 1), + Vect[FermionParity](0 => 1, 1 => 2)', + Vect[FermionParity](0 => 2, 1 => 1)', + Vect[FermionParity](0 => 2, 1 => 3), + Vect[FermionParity](0 => 2, 1 => 2), + ), + ( + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1), + Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)', + Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2), + Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)', + ), + ( + Vect[SU2Irrep](0 => 2, 1 // 2 => 1), + Vect[SU2Irrep](0 => 1, 1 => 1), + Vect[SU2Irrep](1 // 2 => 1, 1 => 1)', + Vect[SU2Irrep](1 // 2 => 2), + Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)', + ), + ( + Vect[FibonacciAnyon](:I => 2, :τ => 1), + Vect[FibonacciAnyon](:I => 1, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 2)', + Vect[FibonacciAnyon](:I => 2, :τ => 3), + Vect[FibonacciAnyon](:I => 2, :τ => 2), + ), +) +eltypes = (Float64, ComplexF64) + +function remove_qrgauge_dependence!(ΔQ, t, Q) + for (c, b) in blocks(ΔQ) + m, n = size(block(t, c)) + minmn = min(m, n) + Qc = block(Q, c) + Q1 = view(Qc, 1:m, 1:minmn) + ΔQ2 = view(b, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + end + return ΔQ +end +function remove_lqgauge_dependence!(ΔQ, t, Q) + for (c, b) in blocks(ΔQ) + m, n = size(block(t, c)) + minmn = min(m, n) + Qc = block(Q, c) + Q1 = view(Qc, 1:minmn, 1:n) + ΔQ2 = view(b, (minmn + 1):n, :) + mul!(ΔQ2, ΔQ2 * Q1', Q1) + end + return ΔQ +end +function remove_eiggauge_dependence!( + ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D) + ) + gaugepart = V' * ΔV + for (c, b) in blocks(gaugepart) + Dc = diagview(block(D, c)) + # for some reason this fails only on tests, and I cannot reproduce it in an + # interactive session. + # b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0 + for j in axes(b, 2), i in axes(b, 1) + abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0) + end + end + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end +function remove_eighgauge_dependence!( + ΔV, D, V; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(D) + ) + gaugepart = project_antihermitian!(V' * ΔV) + for (c, b) in blocks(gaugepart) + Dc = diagview(block(D, c)) + # for some reason this fails only on tests, and I cannot reproduce it in an + # interactive session. + # b[abs.(transpose(diagview(Dc)) .- diagview(Dc)) .>= degeneracy_atol] .= 0 + for j in axes(b, 2), i in axes(b, 1) + abs(Dc[i] - Dc[j]) >= degeneracy_atol && (b[i, j] = 0) + end + end + mul!(ΔV, V, gaugepart, -1, 1) + return ΔV +end +function remove_svdgauge_dependence!( + ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = MatrixAlgebraKit.default_pullback_degeneracy_atol(S) + ) + gaugepart = project_antihermitian!(U' * ΔU + Vᴴ * ΔVᴴ') + for (c, b) in blocks(gaugepart) + Sd = diagview(block(S, c)) + # for some reason this fails only on tests, and I cannot reproduce it in an + # interactive session. + # b[abs.(transpose(diagview(Sc)) .- diagview(Sc)) .>= degeneracy_atol] .= 0 + for j in axes(b, 2), i in axes(b, 1) + abs(Sd[i] - Sd[j]) >= degeneracy_atol && (b[i, j] = 0) + end + end + mul!(ΔU, U, gaugepart, -1, 1) + return ΔU, ΔVᴴ +end + +@timedtestset "Mooncake - Factorizations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes + atol = default_tol(T) + rtol = default_tol(T) + + @timedtestset "QR" begin + A = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + + Mooncake.TestUtils.test_rule(rng, qr_compact, A; atol, rtol, mode, is_primitive = false) + + # qr_full/qr_null requires being careful with gauges + QR = qr_full(A) + ΔQR = Mooncake.randn_tangent(rng, QR) + remove_qrgauge_dependence!(ΔQR[1], A, QR[1]) + Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false) + # TODO: + # Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false) + + A = randn(T, V[1] ⊗ V[2] ← V[1]) + + Mooncake.TestUtils.test_rule(rng, qr_compact, A; atol, rtol, mode, is_primitive = false) + + # qr_full/qr_null requires being careful with gauges + QR = qr_full(A) + ΔQR = Mooncake.randn_tangent(rng, QR) + remove_qrgauge_dependence!(ΔQR[1], A, QR[1]) + Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false) + # TODO: + # Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false) + end + + @timedtestset "LQ" begin + A = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2]) + + Mooncake.TestUtils.test_rule(rng, lq_compact, A; atol, rtol, mode, is_primitive = false) + + # qr_full/qr_null requires being careful with gauges + LQ = lq_full(A) + ΔLQ = Mooncake.randn_tangent(rng, LQ) + remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2]) + Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false) + # TODO: + # Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false) + + A = randn(T, V[1] ⊗ V[2] ← V[1]) + + Mooncake.TestUtils.test_rule(rng, lq_compact, A; atol, rtol, mode, is_primitive = false) + + # qr_full/qr_null requires being careful with gauges + LQ = lq_full(A) + ΔLQ = Mooncake.randn_tangent(rng, LQ) + remove_lqgauge_dependence!(ΔLQ[2], A, LQ[2]) + Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false) + # TODO: + # Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false) + end + + @timedtestset "Eigenvalue decomposition" begin + for t in (randn(T, V[1] ← V[1]), rand(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])) + DV = eig_full(t) + ΔDV = Mooncake.randn_tangent(rng, DV) + remove_eiggauge_dependence!(ΔDV[2], DV...) + Mooncake.TestUtils.test_rule(rng, eig_full, t; output_tangent = ΔDV, atol, rtol, mode, is_primitive = false) + + th = project_hermitian(t) + DV = eigh_full(th) + ΔDV = Mooncake.randn_tangent(rng, DV) + remove_eighgauge_dependence!(ΔDV[2], DV...) + Mooncake.TestUtils.test_rule(rng, eigh_full ∘ project_hermitian, th; output_tangent = ΔDV, atol, rtol, mode, is_primitive = false) + end + end + + @timedtestset "Singular value decomposition" begin + for t in (randn(T, V[1] ← V[1]), randn(T, V[1] ⊗ V[2] ← V[3] ⊗ V[4])) + USVᴴ = svd_compact(t) + ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ) + remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + Mooncake.TestUtils.test_rule(rng, svd_compact, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false) + + # USVᴴ = svd_full(t) + # ΔUSVᴴ = Mooncake.randn_tangent(rng, USVᴴ) + # remove_svdgauge_dependence!(ΔUSVᴴ[1], ΔUSVᴴ[3], USVᴴ...) + # Mooncake.TestUtils.test_rule(rng, svd_full, t; output_tangent = ΔUSVᴴ, atol, rtol, mode, is_primitive = false) + + V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t)) + trunc = truncspace(V_trunc) + alg = MatrixAlgebraKit.select_algorithm(svd_trunc, t, nothing; trunc) + USVᴴtrunc = svd_trunc(t, alg) + ΔUSVᴴtrunc = (Mooncake.randn_tangent(rng, Base.front(USVᴴtrunc))..., zero(last(USVᴴtrunc))) + remove_svdgauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...) + Mooncake.TestUtils.test_rule(rng, svd_trunc, t, alg; output_tangent = ΔUSVᴴtrunc, atol, rtol, mode) + end + end +end From 3c48a344e4778ab0faa156b04cc4ca1874238a0d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 2 Feb 2026 16:21:23 -0500 Subject: [PATCH 50/54] add DiagonalTensorMap tangent type --- ext/TensorKitMooncakeExt/tangent.jl | 105 ++++++++++++++++++++++------ test/mooncake/tangent.jl | 3 + 2 files changed, 88 insertions(+), 20 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index 8178bb64e..5e3543f66 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -1,6 +1,9 @@ Mooncake.arrayify(A_dA::CoDual{<:TensorMap}) = arrayify(primal(A_dA), tangent(A_dA)) Mooncake.arrayify(A::TensorMap, dA::TensorMap) = (A, dA) +Mooncake.arrayify(A_dA::CoDual{<:DiagonalTensorMap}) = arrayify(primal(A_dA), tangent(A_dA)) +Mooncake.arrayify(A::DiagonalTensorMap, dA::DiagonalTensorMap) = (A, dA) + function Mooncake.arrayify(Aᴴ_ΔAᴴ::CoDual{<:TK.AdjointTensorMap}) Aᴴ = Mooncake.primal(Aᴴ_ΔAᴴ) ΔAᴴ = Mooncake.tangent(Aᴴ_ΔAᴴ) @@ -22,18 +25,27 @@ end Mooncake.@foldable Mooncake.tangent_type(::Type{T}, ::Type{NoRData}) where {T <: TensorMap} = T Mooncake.@foldable Mooncake.tangent_type(::Type{TensorMap{T, S, N₁, N₂, A}}) where {T, S, N₁, N₂, A} = TK.tensormaptype(S, N₁, N₂, Mooncake.tangent_type(A)) +Mooncake.@foldable Mooncake.tangent_type(::Type{T}, ::Type{NoRData}) where {T <: DiagonalTensorMap} = T +Mooncake.@foldable Mooncake.tangent_type(::Type{DiagonalTensorMap{T, S, A}}) where {T, S, A} = + DiagonalTensorMap{T, S, Mooncake.tangent_type(A)} + +const DiagOrTensorMap = Union{TensorMap, DiagonalTensorMap} -Mooncake.@foldable Mooncake.fdata_type(::Type{T}) where {T <: TensorMap} = Mooncake.tangent_type(T) -Mooncake.@foldable Mooncake.rdata_type(::Type{T}) where {T <: TensorMap} = NoRData +Mooncake.@foldable Mooncake.fdata_type(::Type{T}) where {T <: DiagOrTensorMap} = Mooncake.tangent_type(T) +Mooncake.@foldable Mooncake.rdata_type(::Type{T}) where {T <: DiagOrTensorMap} = NoRData -Mooncake.tangent(t::TensorMap, ::NoRData) = t +Mooncake.tangent(t::DiagOrTensorMap, ::NoRData) = t Mooncake.zero_tangent_internal(t::TensorMap, c::Mooncake.MaybeCache) = TensorMap(Mooncake.zero_tangent_internal(t.data, c), space(t)) +Mooncake.zero_tangent_internal(t::DiagonalTensorMap, c::Mooncake.MaybeCache) = + DiagonalTensorMap(Mooncake.zero_tangent_internal(t.data, c), space(t, 1)) Mooncake.randn_tangent_internal(rng::AbstractRNG, p::TensorMap, c::Mooncake.MaybeCache) = TensorMap(Mooncake.randn_tangent_internal(rng, p.data, c), space(p)) +Mooncake.randn_tangent_internal(rng::AbstractRNG, p::DiagonalTensorMap, c::Mooncake.MaybeCache) = + DiagonalTensorMap(Mooncake.randn_tangent_internal(rng, p.data, c), space(p, 1)) -Mooncake.set_to_zero_internal!!(::Mooncake.SetToZeroCache, t::TensorMap) = zerovector!(t) +Mooncake.set_to_zero_internal!!(::Mooncake.SetToZeroCache, t::DiagOrTensorMap) = zerovector!(t) function Mooncake.increment!!(x::TensorMap, y::TensorMap) data = Mooncake.increment!!(x.data, y.data) return x.data === data ? x : TensorMap(data, space(x)) @@ -42,6 +54,14 @@ function Mooncake.increment_internal!!(c::Mooncake.IncCache, x::TensorMap, y::Te data = Mooncake.increment_internal!!(c, x.data, y.data) return x.data === data ? x : TensorMap(data, space(x)) end +function Mooncake.increment!!(x::DiagonalTensorMap, y::DiagonalTensorMap) + data = Mooncake.increment!!(x.data, y.data) + return x.data === data ? x : DiagonalTensorMap(data, space(x, 1)) +end +function Mooncake.increment_internal!!(c::Mooncake.IncCache, x::DiagonalTensorMap, y::DiagonalTensorMap) + data = Mooncake.increment_internal!!(c, x.data, y.data) + return x.data === data ? x : DiagonalTensorMap(data, space(x, 1)) +end Mooncake._add_to_primal_internal(c::Mooncake.MaybeCache, p::TensorMap, t::TensorMap, unsafe::Bool) = TensorMap(Mooncake._add_to_primal_internal(c, p.data, t.data, unsafe), space(p)) @@ -51,29 +71,55 @@ function Mooncake.tangent_to_primal_internal!!(p::TensorMap, t::TensorMap, c::Mo return p end Mooncake.primal_to_tangent_internal!!(t::T, p::T, ::Mooncake.MaybeCache) where {T <: TensorMap} = copy!(t, p) +Mooncake._add_to_primal_internal(c::Mooncake.MaybeCache, p::DiagonalTensorMap, t::DiagonalTensorMap, unsafe::Bool) = + DiagonalTensorMap(Mooncake._add_to_primal_internal(c, p.data, t.data, unsafe), space(p)) +function Mooncake.tangent_to_primal_internal!!(p::DiagonalTensorMap, t::DiagonalTensorMap, c::Mooncake.MaybeCache) + data = Mooncake.tangent_to_primal_internal!!(p.data, t.data, c) + data === p.data || copy!(p.data, data) + return p +end +function Mooncake.primal_to_tangent_internal!!(t::TensorMap, p::TensorMap, c::Mooncake.MaybeCache) + data = Mooncake.primal_to_tangent_internal!!(t.data, p.data, c) + data === t.data || copy!(t.data, data) + return p +end +function Mooncake.primal_to_tangent_internal!!(t::DiagonalTensorMap, p::DiagonalTensorMap, c::Mooncake.MaybeCache) + data = Mooncake.primal_to_tangent_internal!!(t.data, p.data, c) + data === t.data || copy!(t.data, data) + return p +end Mooncake._dot_internal(::Mooncake.MaybeCache, t::TensorMap, s::TensorMap) = Float64(real(inner(t, s))) -Mooncake._scale_internal(::Mooncake.MaybeCache, a::Float64, t::TensorMap) = scale(t, a) +Mooncake._dot_internal(::Mooncake.MaybeCache, t::DiagonalTensorMap, s::DiagonalTensorMap) = Float64(real(inner(t, s))) +Mooncake._scale_internal(::Mooncake.MaybeCache, a::Float64, t::DiagOrTensorMap) = scale(t, a) Mooncake.TestUtils.populate_address_map_internal(m::Mooncake.TestUtils.AddressMap, primal::TensorMap, tangent::TensorMap) = Mooncake.populate_address_map_internal(m, primal.data, tangent.data) -@inline Mooncake.TestUtils.__get_data_field(t::TensorMap, n) = getfield(t, n) +Mooncake.TestUtils.populate_address_map_internal(m::Mooncake.TestUtils.AddressMap, primal::DiagonalTensorMap, tangent::DiagonalTensorMap) = + Mooncake.populate_address_map_internal(m, primal.data, tangent.data) +@inline Mooncake.TestUtils.__get_data_field(t::DiagOrTensorMap, n) = getfield(t, n) function Mooncake.__verify_fdata_value(c::IdDict{Any, Nothing}, p::TensorMap, t::TensorMap) space(p) == space(t) || throw(Mooncake.InvalidFDataException(lazy"p has space $(space(p)) but t has size $(space(t))")) return Mooncake.__verify_fdata_value(c, p.data, t.data) end +function Mooncake.__verify_fdata_value(c::IdDict{Any, Nothing}, p::DiagonalTensorMap, t::DiagonalTensorMap) + space(p) == space(t) || + throw(Mooncake.InvalidFDataException(lazy"p has space $(space(p)) but t has size $(space(t))")) + return Mooncake.__verify_fdata_value(c, p.data, t.data) +end -Mooncake.to_cr_tangent(x::TensorMap) = x +Mooncake.to_cr_tangent(x::DiagOrTensorMap) = x -@is_primitive MinimalCtx Tuple{typeof(Mooncake.lgetfield), <:TensorMap, Val} +@is_primitive MinimalCtx Tuple{typeof(Mooncake.lgetfield), <:DiagOrTensorMap, Val} # TODO: double-check if this has to include quantum dimensinos for non-abelian? function Mooncake.frule!!( - ::Dual{typeof(Mooncake.lgetfield)}, t::Dual{<:TensorMap}, ::Dual{Val{FieldName}} + ::Dual{typeof(Mooncake.lgetfield)}, t::Dual{<:DiagOrTensorMap}, ::Dual{Val{FieldName}} ) where {FieldName} - y = getfield(primal(t), FieldName) + val = getfield(primal(t), FieldName) + getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 3)) return if FieldName === 1 || FieldName === :data dval = tangent(t).data @@ -86,7 +132,7 @@ function Mooncake.frule!!( end function Mooncake.rrule!!( - ::CoDual{typeof(Mooncake.lgetfield)}, t::CoDual{<:TensorMap}, ::CoDual{Val{FieldName}} + ::CoDual{typeof(Mooncake.lgetfield)}, t::CoDual{<:DiagOrTensorMap}, ::CoDual{Val{FieldName}} ) where {FieldName} val = getfield(primal(t), FieldName) getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 3)) @@ -101,12 +147,12 @@ function Mooncake.rrule!!( end end -@is_primitive MinimalCtx Tuple{typeof(getfield), <:TensorMap, Any, Vararg{Symbol}} +@is_primitive MinimalCtx Tuple{typeof(getfield), <:DiagOrTensorMap, Any, Vararg{Symbol}} Base.@constprop :aggressive function Mooncake.frule!!( - ::Dual{typeof(getfield)}, t::Dual{<:TensorMap}, name::Dual + ::Dual{typeof(getfield)}, t::Dual{<:DiagOrTensorMap}, name::Dual ) - y = getfield(primal(t), primal(name)) + val = getfield(primal(t), primal(name)) return if primal(name) === 1 || primal(name) === :data dval = tangent(t).data @@ -119,7 +165,7 @@ Base.@constprop :aggressive function Mooncake.frule!!( end Base.@constprop :aggressive function Mooncake.rrule!!( - ::CoDual{typeof(getfield)}, t::CoDual{<:TensorMap}, name::CoDual + ::CoDual{typeof(getfield)}, t::CoDual{<:DiagOrTensorMap}, name::CoDual ) val = getfield(primal(t), primal(name)) getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 3)) @@ -135,7 +181,7 @@ Base.@constprop :aggressive function Mooncake.rrule!!( end Base.@constprop :aggressive function Mooncake.frule!!( - ::Dual{typeof(getfield)}, t::Dual{<:TensorMap}, name::Dual, order::Dual + ::Dual{typeof(getfield)}, t::Dual{<:DiagOrTensorMap}, name::Dual, order::Dual ) y = getfield(primal(t), primal(name), primal(order)) @@ -150,7 +196,7 @@ Base.@constprop :aggressive function Mooncake.frule!!( end Base.@constprop :aggressive function Mooncake.rrule!!( - ::CoDual{typeof(getfield)}, t::CoDual{<:TensorMap}, name::CoDual, order::CoDual + ::CoDual{typeof(getfield)}, t::CoDual{<:DiagOrTensorMap}, name::CoDual, order::CoDual ) val = getfield(primal(t), primal(name), primal(order)) getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 4)) @@ -166,11 +212,11 @@ Base.@constprop :aggressive function Mooncake.rrule!!( end -@is_primitive MinimalCtx Tuple{typeof(Mooncake.lgetfield), <:TensorMap, Val, Val} +@is_primitive MinimalCtx Tuple{typeof(Mooncake.lgetfield), <:DiagOrTensorMap, Val, Val} # TODO: double-check if this has to include quantum dimensinos for non-abelian? function Mooncake.frule!!( - ::Dual{typeof(Mooncake.lgetfield)}, t::Dual{<:TensorMap}, ::Dual{Val{FieldName}}, ::Dual{Val{Order}} + ::Dual{typeof(Mooncake.lgetfield)}, t::Dual{<:DiagOrTensorMap}, ::Dual{Val{FieldName}}, ::Dual{Val{Order}} ) where {FieldName, Order} y = getfield(primal(t), FieldName, Order) @@ -185,7 +231,7 @@ function Mooncake.frule!!( end function Mooncake.rrule!!( - ::CoDual{typeof(Mooncake.lgetfield)}, t::CoDual{<:TensorMap}, ::CoDual{Val{FieldName}}, ::CoDual{Val{Order}} + ::CoDual{typeof(Mooncake.lgetfield)}, t::CoDual{<:DiagOrTensorMap}, ::CoDual{Val{FieldName}}, ::CoDual{Val{Order}} ) where {FieldName, Order} val = getfield(primal(t), FieldName, Order) getfield_pullback = Mooncake.NoPullback(ntuple(Returns(NoRData()), 4)) @@ -218,3 +264,22 @@ function Mooncake.rrule!!( return Mooncake.zero_fcodual(Mooncake._new_(TensorMap{T, S, N₁, N₂, A}, primal(data), primal(space))), Returns(ntuple(Returns(NoRData()), 4)) end + + +Mooncake.@zero_derivative Mooncake.MinimalCtx Tuple{typeof(Mooncake._new_), Type{DiagonalTensorMap{T, S, A}}, UndefInitializer, S} where {T, S, A} +@is_primitive Mooncake.MinimalCtx Tuple{typeof(Mooncake._new_), Type{DiagonalTensorMap{T, S, A}}, A, S} where {T, S, A} + +function Mooncake.frule!!( + ::Dual{typeof(Mooncake._new_)}, ::Dual{Type{DiagonalTensorMap{T, S, A}}}, data::Dual{A}, space::Dual{S} + ) where {T, S, A} + t = Mooncake._new_(DiagonalTensorMap{T, S, A}, primal(data), primal(space)) + dt = Mooncake._new_(DiagonalTensorMap{T, S, A}, tangent(data), primal(space)) + return Dual(t, dt) +end + +function Mooncake.rrule!!( + ::CoDual{typeof(Mooncake._new_)}, ::CoDual{Type{DiagonalTensorMap{T, S, A}}}, data::CoDual{A}, space::CoDual{S} + ) where {T, S, A} + return Mooncake.zero_fcodual(Mooncake._new_(DiagonalTensorMap{T, S, A}, primal(data), primal(space))), + Returns(ntuple(Returns(NoRData()), 4)) +end diff --git a/test/mooncake/tangent.jl b/test/mooncake/tangent.jl index 5b001fc51..4bd763c08 100644 --- a/test/mooncake/tangent.jl +++ b/test/mooncake/tangent.jl @@ -55,4 +55,7 @@ eltypes = (Float64, ComplexF64) Sys.islinux() && @timedtestset "Mooncake - Tangent type: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes A = randn(T, V[1] ⊗ V[2] ← V[4] ⊗ V[5]) Mooncake.TestUtils.test_data(rng, A) + + D = DiagonalTensorMap{T}(undef, V[1]) + Mooncake.TestUtils.test_data(rng, D) end From 842179ff116ffba3eb5af05d3544437f3b2a8d31 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 2 Feb 2026 16:21:38 -0500 Subject: [PATCH 51/54] specialize SVD pullback implementations --- .../TensorKitMooncakeExt.jl | 2 + ext/TensorKitMooncakeExt/factorizations.jl | 63 +++++++++++++++++++ 2 files changed, 65 insertions(+) create mode 100644 ext/TensorKitMooncakeExt/factorizations.jl diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index 7067bb280..91e65186b 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -9,6 +9,7 @@ import TensorKit as TK using VectorInterface using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize import TensorOperations as TO +using MatrixAlgebraKit using TupleTools using Random: AbstractRNG @@ -19,5 +20,6 @@ include("indexmanipulations.jl") include("vectorinterface.jl") include("tensoroperations.jl") include("planaroperations.jl") +include("factorizations.jl") end diff --git a/ext/TensorKitMooncakeExt/factorizations.jl b/ext/TensorKitMooncakeExt/factorizations.jl new file mode 100644 index 000000000..3bb1b3ae3 --- /dev/null +++ b/ext/TensorKitMooncakeExt/factorizations.jl @@ -0,0 +1,63 @@ +for f in (:svd_compact, :svd_full) + f_pullback = Symbol(f, :_pullback) + @eval begin + @is_primitive DefaultCtx ReverseMode Tuple{typeof($f), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractTensorMap}, alg_dalg::CoDual) + A, dA = arrayify(A_dA) + alg = primal(alg_dalg) + + USVᴴ = $f(A, primal(alg_dalg)) + USVᴴ_dUSVᴴ = Mooncake.zero_fcodual(USVᴴ) + dUSVᴴ = last.(arrayify.(USVᴴ, tangent(USVᴴ_dUSVᴴ))) + + function $f_pullback(::NoRData) + MatrixAlgebraKit.svd_pullback!(dA, A, USVᴴ, dUSVᴴ) + MatrixAlgebraKit.zero!.(dUSVᴴ) + return ntuple(Returns(NoRData()), 3) + end + + return USVᴴ_dUSVᴴ, $f_pullback + end + end + + # mutating version is not guaranteed to actually mutate + # so we can simply use the non-mutating version instead and avoid having to worry about + # storing copies and restoring state + f! = Symbol(f, :!) + f!_pullback = Symbol(f!, :_pullback) + @eval begin + @is_primitive DefaultCtx ReverseMode Tuple{typeof($f!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm} + Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual{<:AbstractTensorMap}, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) = + Mooncake.rrule!!(Mooncake.zero_fcodual($f), A_dA, alg_dalg) + end +end + +@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!( + ::CoDual{typeof(svd_trunc)}, + A_dA::CoDual{<:AbstractTensorMap}, + alg_dalg::CoDual{<:MatrixAlgebraKit.TruncatedAlgorithm} + ) + A, dA = arrayify(A_dA) + alg = primal(alg_dalg) + + USVᴴ = svd_compact(A, alg.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind) + + USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ)) + dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(tangent(USVᴴtrunc_dUSVᴴtrunc)))) + + function svd_trunc_pullback((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real}) + abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) || + @warn "Gradient for `svd_trunc` ignores non-zero tangents for truncation error" + MatrixAlgebraKit.svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind) + return ntuple(Returns(NoRData()), 3) + end + + return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_pullback +end + +@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm} +Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual{<:AbstractTensorMap}, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) = + Mooncake.rrule!!(Mooncake.zero_fcodual(svd_trunc), A_dA, alg_dalg) From 50f4e70439eddd4e10ef82181a5d7c9ed502408c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 3 Feb 2026 10:17:58 -0500 Subject: [PATCH 52/54] careful about projections --- .../indexmanipulations.jl | 9 ++++--- ext/TensorKitMooncakeExt/linalg.jl | 6 ++--- ext/TensorKitMooncakeExt/utility.jl | 24 +++++++++++++++++ test/mooncake/indexmanipulations.jl | 5 ++++ test/mooncake/tensoroperations.jl | 26 ++++++++++++++----- 5 files changed, 57 insertions(+), 13 deletions(-) diff --git a/ext/TensorKitMooncakeExt/indexmanipulations.jl b/ext/TensorKitMooncakeExt/indexmanipulations.jl index 7e7e6049e..70ab384ee 100644 --- a/ext/TensorKitMooncakeExt/indexmanipulations.jl +++ b/ext/TensorKitMooncakeExt/indexmanipulations.jl @@ -46,8 +46,9 @@ for transform in (:permute, :transpose) # ΔA ip = invperm(linearize(p)) pΔA = _repartition(ip, A) - if scalartype(ΔA) <: Real && (!(scalartype(ΔC) <: Real) || !(scalartype(α) <: Real)) - TC = VectorInterface.promote_scale(ΔC, α) + + TC = VectorInterface.promote_scale(ΔC, α) + if scalartype(ΔA) <: Real && !(TC <: Real) ΔAc = TO.tensoralloc_add(TC, ΔC, pΔA, false, Val(false)) TK.$add_transform!(ΔAc, ΔC, pΔA, conj(α), Zero(), ba...) add!(ΔA, real(ΔAc)) @@ -114,8 +115,8 @@ function Mooncake.rrule!!( ip = invperm(linearize(p)) pΔA = _repartition(ip, A) ilevels = TupleTools.permute(levels, linearize(p)) - if scalartype(ΔA) <: Real && (!(scalartype(ΔC) <: Real) || !(scalartype(α) <: Real)) - TC = VectorInterface.promote_scale(ΔC, α) + TC = VectorInterface.promote_scale(ΔC, α) + if scalartype(ΔA) <: Real && !(TC <: Real) ΔAc = TO.tensoralloc_add(TC, ΔC, pΔA, false, Val(false)) TK.add_braid!(ΔAc, ΔC, pΔA, ilevels, conj(α), Zero(), ba...) add!(ΔA, real(ΔAc)) diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl index 8f5306ac4..6af92f7be 100644 --- a/ext/TensorKitMooncakeExt/linalg.jl +++ b/ext/TensorKitMooncakeExt/linalg.jl @@ -27,11 +27,11 @@ function Mooncake.rrule!!( function mul_pullback(::NoRData) copy!(C, C_cache) - mul!(ΔA, ΔC, B', conj(α), One()) - mul!(ΔB, A', ΔC, conj(α), One()) + project_mul!(ΔA, ΔC, B', conj(α)) + project_mul!(ΔB, A', ΔC, conj(α)) ΔAr = NoRData() ΔBr = NoRData() - Δαr = isnothing(AB) ? NoRData() : inner(AB, ΔC) + Δαr = isnothing(AB) ? NoRData() : project_scalar(α, inner(AB, ΔC)) Δβr = pullback_dβ(ΔC, C, β) ΔCr = pullback_dC!(ΔC, β) diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index 3f50bffa0..bfbeae805 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -2,6 +2,8 @@ _needs_tangent(x) = _needs_tangent(typeof(x)) _needs_tangent(::Type{T}) where {T <: Number} = Mooncake.rdata_type(Mooncake.tangent_type(T)) !== NoRData +# Projection +# ---------- """ project_scalar(x::Number, dx::Number) @@ -11,6 +13,28 @@ For example, we might compute a complex `dx` but only require the real part. project_scalar(x::Number, dx::Number) = oftype(x, dx) project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) +# in-place multiplication and accumulation which might project to (real) +# TODO: this could probably be done without allocating +function project_mul!(C, A, B, α) + TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α)) + return if !(TC <: Real) && scalartype(C) <: Real + add!(C, real(mul!(zerovector(C, TC), A, B, α))) + else + mul!(C, A, B, α, One()) + end +end +function project_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α) + TA = TensorKit.promote_permute(A) + TB = TensorKit.promote_permute(B) + TC = TO.promote_contract(TA, TB, scalartype(α)) + + return if scalartype(C) <: Real && !(TC <: Real) + add!(C, real(TO.tensorcontract!(zerovector(C, TC), A, pA, conjA, B, pB, conjB, pAB, α, Zero()))) + else + TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, One()) + end +end + # IndexTuple utility # ------------------ trivtuple(N) = ntuple(identity, N) diff --git a/test/mooncake/indexmanipulations.jl b/test/mooncake/indexmanipulations.jl index c22b00b26..106c1b4b7 100644 --- a/test/mooncake/indexmanipulations.jl +++ b/test/mooncake/indexmanipulations.jl @@ -101,6 +101,11 @@ eltypes = (Float64, ComplexF64) levels = tuple(randperm(numind(A))) C = randn!(transpose(A, p)) Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, α, β; atol, rtol, mode) + if !(T <: Real) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, real(A), p, α, β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, real(α), β; atol, rtol, mode) + Mooncake.TestUtils.test_rule(rng, TensorKit.add_transpose!, C, A, p, real(α), real(β); atol, rtol, mode) + end A = C end end diff --git a/test/mooncake/tensoroperations.jl b/test/mooncake/tensoroperations.jl index 6802e4a73..df9d2a50b 100644 --- a/test/mooncake/tensoroperations.jl +++ b/test/mooncake/tensoroperations.jl @@ -100,12 +100,26 @@ eltypes = (Float64, ComplexF64) TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); atol, rtol, mode ) - T <: Complex && Mooncake.TestUtils.test_rule( - rng, TensorKit.blas_contract!, - C, A, pA, B, pB, pAB, real(α), real(β), - TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); - atol, rtol, mode - ) + if !(T <: Real) + Mooncake.TestUtils.test_rule( + rng, TensorKit.blas_contract!, + C, A, pA, B, pB, pAB, real(α), real(β), + TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + atol, rtol, mode + ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.blas_contract!, + C, real(A), pA, B, pB, pAB, real(α), real(β), + TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + atol, rtol, mode + ) + Mooncake.TestUtils.test_rule( + rng, TensorKit.blas_contract!, + C, A, pA, real(B), pB, pAB, real(α), real(β), + TensorOperations.DefaultBackend(), TensorOperations.DefaultAllocator(); + atol, rtol, mode + ) + end end end From 22593e54fb7b15b0f9a00c68006c46e813afb3de Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 3 Feb 2026 13:04:41 -0500 Subject: [PATCH 53/54] disable mooncake tests on Apple --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 8f58d7dc8..ad7b4006e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,7 +57,7 @@ istestfile(fn) = endswith(fn, ".jl") && !contains(fn, "setup") # somehow AD tests are unreasonably slow on Apple CI # and ChainRulesTestUtils doesn't like prereleases - if group == "chainrules" + if group == "chainrules" || group == "mooncake" Sys.isapple() && get(ENV, "CI", "false") == "true" && continue isempty(VERSION.prerelease) || continue end From 977bc2c4ea366b5b957dabb7a4b0205764a77686 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 3 Feb 2026 16:22:27 -0500 Subject: [PATCH 54/54] add missing diagonal constructor --- ext/TensorKitMooncakeExt/tangent.jl | 32 ++++++++--------------------- src/tensors/diagonal.jl | 7 +++++-- 2 files changed, 13 insertions(+), 26 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index 5e3543f66..3a9cd9b5d 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -124,10 +124,8 @@ function Mooncake.frule!!( return if FieldName === 1 || FieldName === :data dval = tangent(t).data Dual(val, dval) - elseif FieldName === 2 || FieldName === :space + else # cannot be invalid fieldname since already called `getfield` Dual(val, NoFData()), getfield_pullback - else - throw(ArgumentError(lazy"Invalid fieldname `$FieldName`")) end end @@ -140,10 +138,8 @@ function Mooncake.rrule!!( return if FieldName === 1 || FieldName === :data dval = Mooncake.tangent(t).data CoDual(val, dval), getfield_pullback - elseif FieldName === 2 || FieldName === :space + else # cannot be invalid fieldname since already called `getfield` Mooncake.zero_fcodual(val), getfield_pullback - else - throw(ArgumentError(lazy"Invalid fieldname `$FieldName`")) end end @@ -157,10 +153,8 @@ Base.@constprop :aggressive function Mooncake.frule!!( return if primal(name) === 1 || primal(name) === :data dval = tangent(t).data Dual(val, dval) - elseif primal(name) === 2 || primal(name) === :space + else # cannot be invalid fieldname since already called `getfield` Dual(val, NoFData()) - else - throw(ArgumentError(lazy"Invalid fieldname `$(primal(name))`")) end end @@ -173,10 +167,8 @@ Base.@constprop :aggressive function Mooncake.rrule!!( return if primal(name) === 1 || primal(name) === :data dval = Mooncake.tangent(t).data CoDual(val, dval), getfield_pullback - elseif primal(name) === 2 || primal(name) === :space + else # cannot be invalid fieldname since already called `getfield` Mooncake.zero_fcodual(val), getfield_pullback - else - throw(ArgumentError(lazy"Invalid fieldname `$(primal(name))`")) end end @@ -188,10 +180,8 @@ Base.@constprop :aggressive function Mooncake.frule!!( return if primal(name) === 1 || primal(name) === :data dval = tangent(t).data Dual(val, dval) - elseif primal(name) === 2 || primal(name) === :space + else # cannot be invalid fieldname since already called `getfield` Dual(val, NoFData()) - else - throw(ArgumentError(lazy"Invalid fieldname `$(primal(name))`")) end end @@ -204,10 +194,8 @@ Base.@constprop :aggressive function Mooncake.rrule!!( return if primal(name) === 1 || primal(name) === :data dval = Mooncake.tangent(t).data CoDual(val, dval), getfield_pullback - elseif primal(name) === 2 || primal(name) === :space + else # cannot be invalid fieldname since already called `getfield` Mooncake.zero_fcodual(val), getfield_pullback - else - throw(ArgumentError(lazy"Invalid fieldname `$(primal(name))`")) end end @@ -223,10 +211,8 @@ function Mooncake.frule!!( return if FieldName === 1 || FieldName === :data dval = tangent(t).data Dual(val, dval) - elseif FieldName === 2 || FieldName === :space + else # cannot be invalid fieldname since already called `getfield` Dual(val, NoFData()) - else - throw(ArgumentError(lazy"Invalid fieldname `$FieldName`")) end end @@ -239,10 +225,8 @@ function Mooncake.rrule!!( return if FieldName === 1 || FieldName === :data dval = Mooncake.tangent(t).data CoDual(val, dval), getfield_pullback - elseif FieldName === 2 || FieldName === :space + else # cannot be invalid fieldname since already called `getfield` Mooncake.zero_fcodual(val), getfield_pullback - else - throw(ArgumentError(lazy"Invalid fieldname `$FieldName`")) end end diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index 447a4b8eb..2461f2cab 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -45,8 +45,7 @@ function DiagonalTensorMap{T}(::UndefInitializer, V::TensorMapSpace) where {T} return DiagonalTensorMap{T}(undef, domain(V)) end function DiagonalTensorMap{T}(::UndefInitializer, V::ProductSpace) where {T} - length(V) == 1 || - throw(ArgumentError("DiagonalTensorMap requires `numin(d) == numout(d) == 1`")) + length(V) == 1 || throw(ArgumentError("DiagonalTensorMap requires `numin(d) == numout(d) == 1`")) return DiagonalTensorMap{T}(undef, only(V)) end function DiagonalTensorMap{T}(::UndefInitializer, V::S) where {T, S <: IndexSpace} @@ -63,6 +62,10 @@ end function DiagonalTensorMap(data::DenseVector{T}, V::IndexSpace) where {T} return DiagonalTensorMap{T}(data, V) end +function DiagonalTensorMap(data::DenseVector{T}, V::TensorMapSpace) where {T} + (numin(V) == numout(V) == 1) || throw(ArgumentError("DiagonalTensorMap requires `numin(d) == numout(d) == 1`")) + return DiagonalTensorMap{T}(data, V[1]) +end function DiagonalTensorMap(t::AbstractTensorMap{T, S, 1, 1}) where {T, S} isa(t, DiagonalTensorMap) && return t