diff --git a/docs/src/changelog.md b/docs/src/changelog.md index 9147c4da..3283f9de 100644 --- a/docs/src/changelog.md +++ b/docs/src/changelog.md @@ -24,6 +24,8 @@ When releasing a new version, move the "Unreleased" changes to a new version sec ### Changed +- The Mooncake rules for truncated decompositions with `TruncatedAlgorithm` now use the pullbacks that make use of the full decomposition. ([#171](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/171)) + ### Deprecated ### Removed diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index e4ec256f..3a113c20 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -10,9 +10,9 @@ using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback! using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback! using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback! +using MatrixAlgebraKit: TruncatedAlgorithm using LinearAlgebra - Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any} @@ -167,14 +167,24 @@ for (f!, f, f_full, pb, adj) in ( end end -for (f!, f, f_ne!, f_ne, pb, adj) in ( - (:eig_trunc!, :eig_trunc, :eig_trunc_no_error!, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint), - (:eigh_trunc!, :eigh_trunc, :eigh_trunc_no_error!, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint), - ) +_warn_pullback_truncerror(dϵ::Real; tol = MatrixAlgebraKit.defaulttol(dϵ)) = + abs(dϵ) ≤ tol || @warn "Pullback ignores non-zero tangents for truncation error" + +for f in (:eig, :eigh) + f_trunc = Symbol(f, :_trunc) + f_trunc! = Symbol(f_trunc, :!) + f_full = Symbol(f, :_full) + f_full! = Symbol(f_full, :!) + f_pullback! = Symbol(f, :_pullback!) + f_trunc_pullback! = Symbol(f_trunc, :_pullback!) + f_adjoint! = Symbol(f, :_adjoint!) + f_trunc_no_error = Symbol(f_trunc, :_no_error) + f_trunc_no_error! = Symbol(f_trunc_no_error, :!) + @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} - function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual) + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) DV = Mooncake.primal(DV_dDV) @@ -182,54 +192,115 @@ for (f!, f, f_ne!, f_ne, pb, adj) in ( Ac = copy(A) DVc = copy.(DV) alg = Mooncake.primal(alg_dalg) - output = $f!(A, DV, alg) + output = $f_trunc!(A, DV, alg) # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal # of ComplexF32) into the correct **forwards** data type (since we are now in the forward # pass). For many types this is done automatically when the forward step returns, but # not for nested structs with various fields (like Diagonal{Complex}) - output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) - function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real} + output_codual = Mooncake.zero_fcodual(output) + function $f_adjoint!(dy::Tuple{NoRData, NoRData, <:Real}) copy!(A, Ac) Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) - abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error" + _warn_pullback_truncerror(dy[3]) D′, dD′ = arrayify(Dtrunc, dDtrunc_) V′, dV′ = arrayify(Vtrunc, dVtrunc_) - $pb(dA, A, (D′, V′), (dD′, dV′)) + $f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′)) copy!(DV[1], DVc[1]) copy!(DV[2], DVc[2]) zero!(dD′) zero!(dV′) return NoRData(), NoRData(), NoRData(), NoRData() end - return output_codual, $adj + return output_codual, $f_adjoint! end - function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) + function Mooncake.rrule!!(::CoDual{typeof($f_trunc!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + DV_dDV_arr = arrayify.(Mooncake.primal(DV_dDV), Mooncake.tangent(DV_dDV)) + DV, dDV = first.(DV_dDV_arr), last.(DV_dDV_arr) + alg = Mooncake.primal(alg_dalg) + + # store state prior to primal call + Ac = copy(A) + DVc = copy.(DV) + + # compute primal - capture full DV and ind + DV = $f_full!(A, DV, alg.alg) + DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DV, alg.trunc) + ϵ = MatrixAlgebraKit.truncation_error(diagview(DV[1]), ind) + + # pack output - note that we allocate new dDVtrunc because these aren't overwritten in the input + DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ)) + + # define pullback + dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc)))) + function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real}) + _warn_pullback_truncerror(dϵ) + + # compute pullbacks + $f_pullback!(dA, Ac, DVc, dDVtrunc, ind) + zero!.(dDVtrunc) # since this is allocated in this function this is probably not required + + # restore state + copy!(A, Ac) + copy!.(DV, DVc) + + return ntuple(Returns(NoRData()), 4) + end + + return DVtrunc_dDVtrunc, $f_adjoint! + end + function Mooncake.rrule!!(::CoDual{typeof($f_trunc)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) alg = Mooncake.primal(alg_dalg) - output = $f(A, alg) + output = $f_trunc(A, alg) # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal # of ComplexF32) into the correct **forwards** data type (since we are now in the forward # pass). For many types this is done automatically when the forward step returns, but # not for nested structs with various fields (like Diagonal{Complex}) output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) - function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real} + function $f_adjoint!(dy::Tuple{NoRData, NoRData, T}) where {T <: Real} Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) - abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error" + _warn_pullback_truncerror(dy[3]) D, dD = arrayify(Dtrunc, dDtrunc_) V, dV = arrayify(Vtrunc, dVtrunc_) - $pb(dA, A, (D, V), (dD, dV)) + $f_trunc_pullback!(dA, A, (D, V), (dD, dV)) zero!(dD) zero!(dV) return NoRData(), NoRData(), NoRData() end - return output_codual, $adj + return output_codual, $f_adjoint! end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne), Any, MatrixAlgebraKit.AbstractAlgorithm} - function Mooncake.rrule!!(::CoDual{typeof($f_ne!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual) + function Mooncake.rrule!!(::CoDual{typeof($f_trunc)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + + # compute primal - capture full DV and ind + DV = $f_full(A, alg.alg) + DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DV, alg.trunc) + ϵ = MatrixAlgebraKit.truncation_error(diagview(DV[1]), ind) + + # pack output + DVtrunc_dDVtrunc = Mooncake.zero_fcodual((DVtrunc..., ϵ)) + + # define pullback + dDVtrunc = last.(arrayify.(DVtrunc, Base.front(Mooncake.tangent(DVtrunc_dDVtrunc)))) + function $f_adjoint!((_, _, dϵ)::Tuple{NoRData, NoRData, Real}) + _warn_pullback_truncerror(dϵ) + $f_pullback!(dA, A, DV, dDVtrunc, ind) + zero!.(dDVtrunc) # since this is allocated in this function this is probably not required + return ntuple(Returns(NoRData()), 3) + end + + return DVtrunc_dDVtrunc, $f_adjoint! + end + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) alg = Mooncake.primal(alg_dalg) @@ -237,48 +308,104 @@ for (f!, f, f_ne!, f_ne, pb, adj) in ( dDV = Mooncake.tangent(DV_dDV) Ac = copy(A) DVc = copy.(DV) - output = $f_ne!(A, DV, alg) + output = $f_trunc_no_error!(A, DV, alg) # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal # of ComplexF32) into the correct **forwards** data type (since we are now in the forward # pass). For many types this is done automatically when the forward step returns, but # not for nested structs with various fields (like Diagonal{Complex}) output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) - function $adj(::NoRData) + function $f_adjoint!(::NoRData) copy!(A, Ac) Dtrunc, Vtrunc = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual) D′, dD′ = arrayify(Dtrunc, dDtrunc_) V′, dV′ = arrayify(Vtrunc, dVtrunc_) - $pb(dA, A, (D′, V′), (dD′, dV′)) + $f_pullback!(dA, A, (D′, V′), (dD′, dV′)) copy!(DV[1], DVc[1]) copy!(DV[2], DVc[2]) zero!(dD′) zero!(dV′) return NoRData(), NoRData(), NoRData(), NoRData() end - return output_codual, $adj + return output_codual, $f_adjoint! + end + function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + DV_dDV_arr = arrayify.(Mooncake.primal(DV_dDV), Mooncake.tangent(DV_dDV)) + DV, dDV = first.(DV_dDV_arr), last.(DV_dDV_arr) + alg = Mooncake.primal(alg_dalg) + + # store state prior to primal call + Ac = copy(A) + DVc = copy.(DV) + + # compute primal - capture full DV and ind + DV = $f_full!(A, DV, alg.alg) + DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DV, alg.trunc) + + # pack output - note that we allocate new dDVtrunc because these aren't overwritten in the input + DVtrunc_dDVtrunc = Mooncake.zero_fcodual(DVtrunc) + + # define pullback + dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc))) + function $f_adjoint!(::NoRData) + # compute pullbacks + $f_pullback!(dA, Ac, DVc, dDVtrunc, ind) + zero!.(dDVtrunc) # since this is allocated in this function this is probably not required + + # restore state + copy!(A, Ac) + copy!.(DV, DVc) + + return ntuple(Returns(NoRData()), 4) + end + + return DVtrunc_dDVtrunc, $f_adjoint! end - function Mooncake.rrule!!(::CoDual{typeof($f_ne)}, A_dA::CoDual, alg_dalg::CoDual) + function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) alg = Mooncake.primal(alg_dalg) - output = $f_ne(A, alg) + output = $f_trunc_no_error(A, alg) # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal # of ComplexF32) into the correct **forwards** data type (since we are now in the forward # pass). For many types this is done automatically when the forward step returns, but # not for nested structs with various fields (like Diagonal{Complex}) output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) - function $adj(::NoRData) + function $f_adjoint!(::NoRData) Dtrunc, Vtrunc = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual) D, dD = arrayify(Dtrunc, dDtrunc_) V, dV = arrayify(Vtrunc, dVtrunc_) - $pb(dA, A, (D, V), (dD, dV)) + $f_trunc_pullback!(dA, A, (D, V), (dD, dV)) zero!(dD) zero!(dV) return NoRData(), NoRData(), NoRData() end - return output_codual, $adj + return output_codual, $f_adjoint! + end + function Mooncake.rrule!!(::CoDual{typeof($f_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + + # compute primal - capture full DV and ind + DV = $f_full(A, alg.alg) + DVtrunc, ind = MatrixAlgebraKit.truncate($f_trunc!, DV, alg.trunc) + + # pack output + DVtrunc_dDVtrunc = Mooncake.zero_fcodual(DVtrunc) + + # define pullback + dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc))) + function $f_adjoint!(::NoRData) + $f_pullback!(dA, A, DV, dDVtrunc, ind) + zero!.(dDVtrunc) # since this is allocated in this function this is probably not required + return ntuple(Returns(NoRData()), 3) + end + + return DVtrunc_dDVtrunc, $f_adjoint! end end end @@ -419,7 +546,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS copy!(A, Ac) Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) - abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error" + _warn_pullback_truncerror(dy[4]) U′, dU′ = arrayify(Utrunc, dUtrunc_) S′, dS′ = arrayify(Strunc, dStrunc_) Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_) @@ -437,6 +564,45 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS end return output_codual, svd_trunc_adjoint end +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + USVᴴ_dUSVᴴ_arr = arrayify.(Mooncake.primal(USVᴴ_dUSVᴴ), Mooncake.tangent(USVᴴ_dUSVᴴ)) + USVᴴ, dUSVᴴ = first.(USVᴴ_dUSVᴴ_arr), last.(USVᴴ_dUSVᴴ_arr) + alg = Mooncake.primal(alg_dalg) + + # store state prior to primal call + Ac = copy(A) + USVᴴc = copy.(USVᴴ) + + # compute primal - capture full USVᴴ and ind + USVᴴ = svd_compact!(A, USVᴴ, alg.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind) + + # pack output - note that we allocate new dUSVᴴtrunc because these aren't actually + # overwritten in the input! + USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ)) + + # define pullback + dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))) + function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real}) + _warn_pullback_truncerror(dϵ) + + # compute pullbacks + svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind) + zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required + zero!.(dUSVᴴ) + + # restore state + copy!(A, Ac) + copy!.(USVᴴ, USVᴴc) + + return ntuple(Returns(NoRData()), 4) + end + + return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint +end @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual) @@ -452,7 +618,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real} Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) - abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error" + _warn_pullback_truncerror(dy[4]) U, dU = arrayify(Utrunc, dUtrunc_) S, dS = arrayify(Strunc, dStrunc_) Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) @@ -464,6 +630,30 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C end return output_codual, svd_trunc_adjoint end +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + + # compute primal - capture full USVᴴ and ind + USVᴴ = svd_compact(A, alg.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind) + + # pack output + USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ)) + + # define pullback + dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))) + function svd_trunc_adjoint((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real}) + _warn_pullback_truncerror(dϵ) + svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind) + zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required + return ntuple(Returns(NoRData()), 3) + end + + return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint +end @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) @@ -504,6 +694,42 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U end return output_codual, svd_trunc_adjoint end +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + USVᴴ_dUSVᴴ_arr = arrayify.(Mooncake.primal(USVᴴ_dUSVᴴ), Mooncake.tangent(USVᴴ_dUSVᴴ)) + USVᴴ, dUSVᴴ = first.(USVᴴ_dUSVᴴ_arr), last.(USVᴴ_dUSVᴴ_arr) + alg = Mooncake.primal(alg_dalg) + + # store state prior to primal call + Ac = copy(A) + USVᴴc = copy.(USVᴴ) + + # compute primal - capture full USVᴴ and ind + USVᴴ = svd_compact!(A, USVᴴ, alg.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + + # pack output - note that we allocate new dUSVᴴtrunc because these aren't actually + # overwritten in the input! + USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual(USVᴴtrunc) + + # define pullback + dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))) + function svd_trunc_adjoint(::NoRData) + # compute pullbacks + svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind) + zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required + zero!.(dUSVᴴ) + + # restore state + copy!(A, Ac) + copy!.(USVᴴ, USVᴴc) + + return ntuple(Returns(NoRData()), 4) + end + + return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint +end @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual) @@ -530,5 +756,27 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al end return output_codual, svd_trunc_adjoint end +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual{<:TruncatedAlgorithm}) + # unpack variables + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + + # compute primal - capture full USVᴴ and ind + USVᴴ = svd_compact(A, alg.alg) + USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + + # pack output + USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual(USVᴴtrunc) + + # define pullback + dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc))) + function svd_trunc_adjoint(::NoRData) + svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind) + zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required + return ntuple(Returns(NoRData()), 3) + end + + return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint +end end