Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
58b1526
add `arrayify` for adjoint tensor
lkdvos Jan 17, 2026
1e54e1e
add vectorinterface rules
lkdvos Jan 17, 2026
28e6c1d
add tensoroperations rules
lkdvos Jan 17, 2026
94112b8
add indexmanipulations rules
lkdvos Jan 18, 2026
c31996a
add mul rules
lkdvos Jan 20, 2026
ec0fe09
temporarily disable Fibonacci (complex) spaces
lkdvos Jan 20, 2026
4642ff7
bump TupleTools compat
lkdvos Jan 20, 2026
2f98e0d
add twist! rule
lkdvos Jan 21, 2026
10a50cb
add flip rule
lkdvos Jan 21, 2026
d7c050a
vector spaces arent vector spaces!
lkdvos Jan 21, 2026
c1a1e8b
insert and remove units
lkdvos Jan 21, 2026
fac47ed
mark a bunch of things as non-differentiable
lkdvos Jan 21, 2026
9b90eb6
rewrite rule for `tensortrace!` in terms of `trace_permute!`
lkdvos Jan 21, 2026
bb287fe
dont need rules for `tensoradd!`
lkdvos Jan 21, 2026
a8f8a20
add planaroperations
lkdvos Jan 22, 2026
8e1993e
rewrite rule `tensorcontract` in terms of `blas_contract!`
lkdvos Jan 22, 2026
70cdc55
add rule `tr`
lkdvos Jan 22, 2026
f85c946
give up on planartrace for now
lkdvos Jan 22, 2026
8117c76
add rule `inv`
lkdvos Jan 22, 2026
f550cfd
is_primitive in namespace
lkdvos Jan 22, 2026
a821745
share more code
lkdvos Jan 22, 2026
a17a55d
split AD tests to reduce CI pressure
lkdvos Jan 22, 2026
b32a71c
add missing imports
lkdvos Jan 22, 2026
3bb332e
remove the use of the internal `Mooncake._rdata`
lkdvos Jan 26, 2026
079740a
add comments about `NoRData()`
lkdvos Jan 26, 2026
553ee2b
add TODO
lkdvos Jan 26, 2026
53c3c34
correctly implement `_needs_tangent`
lkdvos Jan 29, 2026
c4c8cb9
update to Mooncake 0.5
lkdvos Jan 29, 2026
d3afdfe
add TensorMap tangent type
lkdvos Jan 29, 2026
4963393
fix stupid tolerance mistake
lkdvos Jan 29, 2026
1691db9
enable complex tests
lkdvos Jan 29, 2026
b648a30
add tangent type test
lkdvos Jan 29, 2026
ea8cf9e
correct arrayify
lkdvos Jan 29, 2026
0f6891b
fix indexmanipulations
lkdvos Jan 29, 2026
96bca51
bump versions
lkdvos Jan 29, 2026
aca99b2
deal with more complex sector shenanigans
lkdvos Jan 29, 2026
8c84953
properly accumulate
lkdvos Jan 29, 2026
15bf332
nicer _needs_tangent
lkdvos Jan 30, 2026
07f74b0
remove source
lkdvos Jan 30, 2026
497c8f6
fix TensorOperations
lkdvos Jan 30, 2026
185b2e6
remove duplicate method
lkdvos Jan 30, 2026
b5d7ab8
fix arg order
lkdvos Jan 30, 2026
b5c4a6c
add missing ChainRules import
lkdvos Jan 30, 2026
d8d5144
add JET compat
lkdvos Jan 31, 2026
89df24d
some cleanup
lkdvos Feb 1, 2026
c677596
more handling of scalartypes
lkdvos Feb 1, 2026
7ecf6f6
more testing
lkdvos Feb 1, 2026
d5ad4af
add specialization for `MAK.zero!`
lkdvos Feb 2, 2026
3a97abc
add tests on factorizations
lkdvos Feb 2, 2026
3c48a34
add DiagonalTensorMap tangent type
lkdvos Feb 2, 2026
842179f
specialize SVD pullback implementations
lkdvos Feb 2, 2026
50f4e70
careful about projections
lkdvos Feb 3, 2026
22593e5
disable mooncake tests on Apple
lkdvos Feb 3, 2026
977bc2c
add missing diagonal constructor
lkdvos Feb 3, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ jobs:
- symmetries
- tensors
- other
- autodiff
- mooncake
- chainrules
os:
- ubuntu-latest
- macOS-latest
Expand All @@ -55,7 +56,8 @@ jobs:
- symmetries
- tensors
- other
- autodiff
- mooncake
- chainrules
os:
- ubuntu-latest
- macOS-latest
Expand Down
16 changes: 10 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorKit"
uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
authors = ["Jutho Haegeman, Lukas Devos"]
version = "0.16.0"
authors = ["Jutho Haegeman, Lukas Devos"]

[deps]
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
Expand All @@ -22,8 +22,8 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[extensions]
TensorKitAdaptExt = "Adapt"
Expand All @@ -34,6 +34,7 @@ TensorKitMooncakeExt = "Mooncake"

[compat]
Adapt = "4"
AllocCheck = "0.2.3"
Aqua = "0.6, 0.7, 0.8"
ArgParse = "1.2.0"
CUDA = "5.9"
Expand All @@ -42,10 +43,11 @@ ChainRulesTestUtils = "1"
Combinatorics = "1"
FiniteDifferences = "0.12"
GPUArrays = "11.3.1"
JET = "0.9, 0.10, 0.11"
LRUCache = "1.0.2"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.6.3"
Mooncake = "0.4.183"
MatrixAlgebraKit = "0.6.4"
Mooncake = "0.5"
OhMyThreads = "0.8.0"
Printf = "1"
Random = "1"
Expand All @@ -56,14 +58,15 @@ TensorKitSectors = "0.3.5"
TensorOperations = "5.1"
Test = "1"
TestExtras = "0.2,0.3"
TupleTools = "1.1"
TupleTools = "1.5"
VectorInterface = "0.4.8, 0.5"
Zygote = "0.7"
cuTENSOR = "2"
julia = "1.10"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand All @@ -72,6 +75,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Expand All @@ -82,4 +86,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[targets]
test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"]
test = ["ArgParse", "Adapt", "Aqua", "AllocCheck", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake", "JET"]
14 changes: 11 additions & 3 deletions ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
module TensorKitMooncakeExt

using Mooncake
using Mooncake: @zero_derivative, DefaultCtx, ReverseMode, NoRData, CoDual, arrayify, primal
using Mooncake: @zero_derivative, @is_primitive,
DefaultCtx, MinimalCtx, ReverseMode, NoFData, NoRData, CoDual, Dual,
arrayify, primal, tangent
using TensorKit
import TensorKit as TK
using VectorInterface
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
import TensorOperations as TO
using VectorInterface: One, Zero
using MatrixAlgebraKit
using TupleTools

using Random: AbstractRNG

include("utility.jl")
include("tangent.jl")
include("linalg.jl")
include("indexmanipulations.jl")
include("vectorinterface.jl")
include("tensoroperations.jl")
include("planaroperations.jl")
include("factorizations.jl")

end
63 changes: 63 additions & 0 deletions ext/TensorKitMooncakeExt/factorizations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
for f in (:svd_compact, :svd_full)
f_pullback = Symbol(f, :_pullback)
@eval begin
@is_primitive DefaultCtx ReverseMode Tuple{typeof($f), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractTensorMap}, alg_dalg::CoDual)
A, dA = arrayify(A_dA)
alg = primal(alg_dalg)

USVᴴ = $f(A, primal(alg_dalg))
USVᴴ_dUSVᴴ = Mooncake.zero_fcodual(USVᴴ)
dUSVᴴ = last.(arrayify.(USVᴴ, tangent(USVᴴ_dUSVᴴ)))

function $f_pullback(::NoRData)
MatrixAlgebraKit.svd_pullback!(dA, A, USVᴴ, dUSVᴴ)
MatrixAlgebraKit.zero!.(dUSVᴴ)
return ntuple(Returns(NoRData()), 3)
end

return USVᴴ_dUSVᴴ, $f_pullback
end
end

# mutating version is not guaranteed to actually mutate
# so we can simply use the non-mutating version instead and avoid having to worry about
# storing copies and restoring state
f! = Symbol(f, :!)
f!_pullback = Symbol(f!, :_pullback)
@eval begin
@is_primitive DefaultCtx ReverseMode Tuple{typeof($f!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm}
Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual{<:AbstractTensorMap}, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) =
Mooncake.rrule!!(Mooncake.zero_fcodual($f), A_dA, alg_dalg)
end
end

@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(
::CoDual{typeof(svd_trunc)},
A_dA::CoDual{<:AbstractTensorMap},
alg_dalg::CoDual{<:MatrixAlgebraKit.TruncatedAlgorithm}
)
A, dA = arrayify(A_dA)
alg = primal(alg_dalg)

USVᴴ = svd_compact(A, alg.alg)
USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind)

USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ))
dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(tangent(USVᴴtrunc_dUSVᴴtrunc))))

function svd_trunc_pullback((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) ||
@warn "Gradient for `svd_trunc` ignores non-zero tangents for truncation error"
MatrixAlgebraKit.svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind)
return ntuple(Returns(NoRData()), 3)
end

return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_pullback
end

@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm}
Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual{<:AbstractTensorMap}, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) =
Mooncake.rrule!!(Mooncake.zero_fcodual(svd_trunc), A_dA, alg_dalg)
Loading
Loading