From 9e1ebfe7940e15247bb9060cd3cb27b5beb33073 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 18 Feb 2026 07:35:49 -0500 Subject: [PATCH 01/27] More tweaks for GPU support --- ext/TensorKitCUDAExt/cutensormap.jl | 18 +--- src/tensors/abstracttensor.jl | 16 ++-- src/tensors/adjoint.jl | 2 + src/tensors/braidingtensor.jl | 16 ++-- src/tensors/diagonal.jl | 2 +- src/tensors/indexmanipulations.jl | 12 +-- src/tensors/tensor.jl | 3 +- test/cuda/tensors.jl | 128 +++++++++++++--------------- 8 files changed, 98 insertions(+), 99 deletions(-) diff --git a/ext/TensorKitCUDAExt/cutensormap.jl b/ext/TensorKitCUDAExt/cutensormap.jl index 2fefb3a24..a1d2d1b66 100644 --- a/ext/TensorKitCUDAExt/cutensormap.jl +++ b/ext/TensorKitCUDAExt/cutensormap.jl @@ -6,6 +6,9 @@ const AdjointCuTensorMap{T, S, N₁, N₂} = AdjointTensorMap{T, S, N₁, N₂, function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A} return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t)) end +function TensorMap{T, S, N₁, N₂, DA}(t::TensorMap{T, S, N₁, N₂, HA}) where {T, S, N₁, N₂, DA <: CuArray{T}, HA <: Array{T}} + return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t)) +end # project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: CuVector{T}} @@ -101,18 +104,6 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S} return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)] end -function Base.convert( - TT::Type{CuTensorMap{T, S, N₁, N₂}}, - t::AbstractTensorMap{<:Any, S, N₁, N₂} - ) where {T, S, N₁, N₂} - if typeof(t) === TT - return t - else - tnew = TT(undef, space(t)) - return copy!(tnew, t) - end -end - function LinearAlgebra.isposdef(t::CuTensorMap) domain(t) == codomain(t) || throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same")) @@ -138,10 +129,9 @@ function Base.promote_rule( return CuTensorMap{T, S, N₁, N₂} end -TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} = +TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} = CuArray{T, N, CUDA.default_memory} - # CuTensorMap exponentation: function TensorKit.exp!(t::CuTensorMap) domain(t) == codomain(t) || diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 628845526..c517a165e 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -53,9 +53,7 @@ storagetype(t) = storagetype(typeof(t)) function storagetype(::Type{T}) where {T <: AbstractTensorMap} if T isa Union # attempt to be slightly more specific by promoting unions - Ma = storagetype(T.a) - Mb = storagetype(T.b) - return promote_storagetype(Ma, Mb) + return promote_storagetype(T.a, T.b) else # fallback definition by using scalartype return similarstoragetype(scalartype(T)) @@ -103,11 +101,19 @@ similarstoragetype(X::Type, ::Type{T}) where {T <: Number} = # implement on tensors similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT)) -similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} = - similarstoragetype(storagetype(TT), T) +function similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} + return similarstoragetype(storagetype(TT), T) +end +function similarstoragetype(::Type{<:AbstractTensorMap{T, S, N₁, N₂}}, ::Type{TA}) where {T <: Number, TA <: DenseVector, S, N₁, N₂} + return similarstoragetype(TA, T) +end +function similarstoragetype(t::AbstractTensorMap{T, S, N₁, N₂}, ::Type{TA}) where {T <: Number, TA <: DenseVector, S, N₁, N₂} + return similarstoragetype(typeof(t), TA) +end # implement on arrays similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A +similarstoragetype(::Type{A}, ::Type{A}) where {A <: DenseVector{<:Number}} = A Base.@assume_effects :foldable similarstoragetype(::Type{A}) where {A <: AbstractArray{<:Number}} = Core.Compiler.return_type(similar, Tuple{A, Int}) Base.@assume_effects :foldable similarstoragetype(::Type{A}, ::Type{T}) where {A <: AbstractArray, T <: Number} = diff --git a/src/tensors/adjoint.jl b/src/tensors/adjoint.jl index dfc1a4471..382f309b5 100644 --- a/src/tensors/adjoint.jl +++ b/src/tensors/adjoint.jl @@ -22,6 +22,8 @@ Base.adjoint(t::AbstractTensorMap) = AdjointTensorMap(t) space(t::AdjointTensorMap) = adjoint(space(parent(t))) dim(t::AdjointTensorMap) = dim(parent(t)) storagetype(::Type{AdjointTensorMap{T, S, N₁, N₂, TT}}) where {T, S, N₁, N₂, TT} = storagetype(TT) +similarstoragetype(::AdjointTensorMap{T, S, N₁, N₂, TT}, ::Type{T′}) where {T, S, N₁, N₂, TT, T′ <: Number} = similarstoragetype(TT, T′) +similarstoragetype(::AdjointTensorMap{T, S, N₁, N₂, TT}, ::Type{TA}) where {T, S, N₁, N₂, TT, TA <: DenseVector} = similarstoragetype(TT, TA) # Blocks and subblocks #---------------------- diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index d28b2e1df..beccb5fbb 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -189,12 +189,15 @@ end has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false function add_transform!( tdst::AbstractTensorMap, - tsrc::BraidingTensor, (p₁, p₂)::Index2Tuple, + tsrc::BraidingTensor{T, S}, + (p₁, p₂)::Index2Tuple, fusiontreetransform, α::Number, β::Number, backend::AbstractBackend... - ) + ) where {T, S} + tsrc_map = similar(tdst, storagetype(tdst), space(tsrc)) + copy!(tsrc_map, tsrc) return add_transform!( - tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β, + tdst, tsrc_map, (p₁, p₂), fusiontreetransform, α, β, backend... ) end @@ -261,8 +264,11 @@ function planarcontract!( backend, allocator ) # special case only defined for contracting all 4 indices of B (2 contracted + 2 open) - length.(pB) == (2, 2) || - return planarcontract!(C, A, pA, TensorMap(B), pB, pAB, α, β, backend, allocator) + if length.(pB) != (2, 2) + tB′ = TensorMap(B) + tB = TensorMapWithStorage{eltype(B), similarstoragetype(A, eltype(B)), spacetype(tB′), numout(tB′), numin(tB′)}(tB′) + return planarcontract!(C, A, pA, tB, pB, pAB, α, β, backend, allocator) + end spacecheck_contract(C, A, pA, false, B, pB, false, pAB) diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index b2ac4134b..e73ad2787 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -280,7 +280,7 @@ end # ---------------- function TO.tensoradd_type(TC, A::DiagonalTensorMap, ::Index2Tuple{1, 1}, ::Bool) M = similarstoragetype(A, TC) - return DiagonalTensorMap{TC, spacetype(A), M} + return DiagonalTensorMap{scalartype(M), spacetype(A), M} end function TO.tensorcontract_type( diff --git a/src/tensors/indexmanipulations.jl b/src/tensors/indexmanipulations.jl index 3108abb17..e45789b44 100644 --- a/src/tensors/indexmanipulations.jl +++ b/src/tensors/indexmanipulations.jl @@ -17,6 +17,8 @@ for (operation, manipulation) in ( $promote_op(::Type{T}, ::Type{I}) where {T <: Number, I <: Sector} = sectorscalartype(I) <: Integer ? T : sectorscalartype(I) <: Real ? float(T) : complex(T) + $promote_op(::Type{TA}, ::Type{I}) where {TA <: DenseVector, I <: Sector} = + similarstoragetype(TA, $promote_op(eltype(TA), I)) # TODO: currently the manipulations all use sectorscalartype, change to: # $manipulation_scalartype(I) <: Integer ? T : # $manipulation_scalartype(I) <: Real ? float(T) : complex(T) @@ -342,11 +344,11 @@ See also [`insertrightunit`](@ref insertrightunit(::AbstractTensorMap, ::Val{i}) """ function insertleftunit( t::AbstractTensorMap, ::Val{i} = Val(numind(t) + 1); - copy::Bool = false, conj::Bool = false, dual::Bool = false + copy::Bool = false, conj::Bool = false, dual::Bool = false, ) where {i} W = insertleftunit(space(t), Val(i); conj, dual) if t isa TensorMap - return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W) + return TensorMapWithStorage{scalartype(t), storagetype(t)}(copy ? Base.copy(t.data) : t.data, W) else tdst = similar(t, W) for (c, b) in blocks(t) @@ -371,11 +373,11 @@ See also [`insertleftunit`](@ref insertleftunit(::AbstractTensorMap, ::Val{i}) w """ function insertrightunit( t::AbstractTensorMap, ::Val{i} = Val(numind(t)); - copy::Bool = false, conj::Bool = false, dual::Bool = false + copy::Bool = false, conj::Bool = false, dual::Bool = false, ) where {i} W = insertrightunit(space(t), Val(i); conj, dual) if t isa TensorMap - return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W) + return TensorMapWithStorage{scalartype(t), storagetype(t)}(copy ? Base.copy(t.data) : t.data, W) else tdst = similar(t, W) for (c, b) in blocks(t) @@ -400,7 +402,7 @@ and [`insertrightunit`](@ref insertrightunit(::AbstractTensorMap, ::Val{i}) wher function removeunit(t::AbstractTensorMap, ::Val{i}; copy::Bool = false) where {i} W = removeunit(space(t), Val(i)) if t isa TensorMap - return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W) + return TensorMapWithStorage{scalartype(t), storagetype(t)}(copy ? Base.copy(t.data) : t.data, W) else tdst = similar(t, W) for (c, b) in blocks(t) diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 342c83186..65e6adae6 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -21,7 +21,6 @@ struct TensorMap{T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} <: Abstrac end return TensorMap{T, S, N₁, N₂, A}(data, space) end - # constructors from data function TensorMap{T, S, N₁, N₂, A}( data::A, space::TensorMapSpace{S, N₁, N₂} @@ -34,6 +33,8 @@ struct TensorMap{T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} <: Abstrac return new{T, S, N₁, N₂, A}(data, space) end end +# constructors from another TensorMap -- no-op +TensorMap{T, S, N₁, N₂, A}(t::TensorMap{T, S, N₁, N₂, A}) where {T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} = t """ Tensor{T, S, N, A<:DenseVector{T}} = TensorMap{T, S, N, 0, A} diff --git a/test/cuda/tensors.jl b/test/cuda/tensors.jl index 738440bef..7b958d915 100644 --- a/test/cuda/tensors.jl +++ b/test/cuda/tensors.jl @@ -14,6 +14,7 @@ for V in spacelist println("---------------------------------------") println("CUDA Tensors with symmetry: $Istr") println("---------------------------------------") + hasbraiding = BraidingStyle(I) isa HasBraiding symmetricbraiding = BraidingStyle(I) isa SymmetricBraiding @timedtestset "Tensors with symmetry: $Istr" verbose = true begin V1, V2, V3, V4, V5 = V @@ -209,8 +210,8 @@ for V in spacelist α = rand(T) @test norm(t, 2) ≈ norm(TensorKit.to_cpu(t), 2) @test dot(t2, t) ≈ dot(TensorKit.to_cpu(t2), TensorKit.to_cpu(t)) - @test TensorKit.to_cpu(α * t) ≈ α * TensorKit.to_cpu(t) - @test TensorKit.to_cpu(t + t) ≈ 2 * TensorKit.to_cpu(t) + @test adapt(Vector{T}, (α * t)) ≈ α * adapt(Vector{T}, t) + @test adapt(Vector{T}, (t + t)) ≈ 2 * adapt(Vector{T}, t) end end @timedtestset "Real and imaginary parts" begin @@ -263,16 +264,22 @@ for V in spacelist symmetricbraiding && @timedtestset "Permutations: test via inner product invariance" begin W = V1 ⊗ V2 ⊗ V3 ⊗ V4 ⊗ V5 t = cuRAND.rand(ComplexF64, W) + ht = adapt(Vector{ComplexF64}, t) t′ = cuRAND.randn!(similar(t)) + ht′ = adapt(Vector{ComplexF64}, t′) + dot_htt′ = dot(ht′, ht) + dot_tt′ = dot(t′, t) + @test dot_tt′ ≈ dot_htt′ + norm_t = norm(t) for k in 0:5 for p in permutations(1:5) p1 = ntuple(n -> p[n], k) p2 = ntuple(n -> p[k + n], 5 - k) t2 = @constinferred permute(t, (p1, p2)) - t2 = permute(t, (p1, p2)) - @test norm(t2) ≈ norm(t) t2′ = permute(t′, (p1, p2)) - @test dot(t2′, t2) ≈ dot(t′, t) ≈ dot(transpose(t2′), transpose(t2)) + @test norm(t2) ≈ norm_t + @test dot(t2′, t2) ≈ dot_tt′ + @test dot(transpose(t2′), transpose(t2)) ≈ dot_tt′ end t3 = @constinferred repartition(t, $k) t3 = repartition(t, k) @@ -293,29 +300,26 @@ for V in spacelist ht2 = permute(TensorKit.to_cpu(t), (p1, p2)) @test ht2 ≈ TensorKit.to_cpu(dt2) end - - dt3 = CUDA.@allowscalar repartition(t, k) + dt3 = repartition(t, k) ht3 = repartition(TensorKit.to_cpu(t), k) @test ht3 ≈ TensorKit.to_cpu(dt3) end end symmetricbraiding && @timedtestset "Full trace: test self-consistency" begin t = cuRAND.rand(ComplexF64, V1 ⊗ V2' ⊗ V2 ⊗ V1') - CUDA.@allowscalar begin - t2 = permute(t, ((1, 2), (4, 3))) - s = @constinferred tr(t2) - @test conj(s) ≈ tr(t2') - if !isdual(V1) - t2 = twist!(t2, 1) - end - if isdual(V2) - t2 = twist!(t2, 2) - end - ss = tr(t2) - @tensor s2 = t[a, b, b, a] - @tensor t3[a, b] := t[a, c, c, b] - @tensor s3 = t3[a, a] + t2 = permute(t, ((1, 2), (4, 3))) + s = @constinferred tr(t2) + @test conj(s) ≈ tr(t2') + if !isdual(V1) + t2 = twist!(t2, 1) + end + if isdual(V2) + t2 = twist!(t2, 2) end + ss = tr(t2) + @tensor s2 = t[a, b, b, a] + @tensor t3[a, b] := t[a, c, c, b] + @tensor s3 = t3[a, a] @test ss ≈ s2 @test ss ≈ s3 end @@ -328,20 +332,16 @@ for V in spacelist end symmetricbraiding && @timedtestset "Trace: test via conversion" begin t = cuRAND.rand(ComplexF64, V1 ⊗ V2' ⊗ V3 ⊗ V2 ⊗ V1' ⊗ V3') - CUDA.@allowscalar begin - @tensor t2[a, b] := t[c, d, b, d, c, a] - @tensor t3[a, b] := ad(t)[c, d, b, d, c, a] - end + @tensor t2[a, b] := t[c, d, b, d, c, a] + @tensor t3[a, b] := ad(t)[c, d, b, d, c, a] @test t3 ≈ ad(t2) end symmetricbraiding && @timedtestset "Trace and contraction" begin t1 = cuRAND.rand(ComplexF64, V1 ⊗ V2 ⊗ V3) t2 = cuRAND.rand(ComplexF64, V2' ⊗ V4 ⊗ V1') - CUDA.@allowscalar begin - t3 = t1 ⊗ t2 - @tensor ta[a, b] := t1[x, y, a] * t2[y, b, x] - @tensor tb[a, b] := t3[x, y, a, y, b, x] - end + t3 = t1 ⊗ t2 + @tensor ta[a, b] := t1[x, y, a] * t2[y, b, x] + @tensor tb[a, b] := t3[x, y, a, y, b, x] @test ta ≈ tb end if BraidingStyle(I) isa Bosonic && hasfusiontensor(I) @@ -360,44 +360,38 @@ for V in spacelist @test TensorKit.to_cpu(dHrA12) ≈ hHrA12 end end - BraidingStyle(I) isa HasBraiding && @timedtestset "Index flipping: test flipping inverse" begin + hasbraiding && @timedtestset "Index flipping: test flipping inverse" begin t = cuRAND.rand(ComplexF64, V1 ⊗ V2 ⊗ V3 ← (V4 ⊗ V5)') for i in 1:5 - CUDA.@allowscalar begin - @test t ≈ flip(flip(t, i), i; inv = true) - @test t ≈ flip(flip(t, i; inv = true), i) - end + @test t ≈ flip(flip(t, i), i; inv = true) + @test t ≈ flip(flip(t, i; inv = true), i) end end - #=@timedtestset "Index flipping: test via explicit flip" begin + symmetricbraiding && "Index flipping: test via explicit flip" begin t = cuRAND.rand(ComplexF64, V1 ⊗ V1' ← V1' ⊗ V1) - F1 = unitary(flip(V1), V1) + F1 = adapt(CuArray{ComplexF64}, unitary(flip(V1), V1)) - CUDA.@allowscalar begin - @tensor tf[a, b; c, d] := F1[a, a'] * t[a', b; c, d] - @test flip(t, 1) ≈ tf - @tensor tf[a, b; c, d] := conj(F1[b, b']) * t[a, b'; c, d] - @test twist!(flip(t, 2), 2) ≈ tf - @tensor tf[a, b; c, d] := F1[c, c'] * t[a, b; c', d] - @test flip(t, 3) ≈ tf - @tensor tf[a, b; c, d] := conj(F1[d, d']) * t[a, b; c, d'] - @test twist!(flip(t, 4), 4) ≈ tf - end + @tensor tf[a, b; c, d] := F1[a, a'] * t[a', b; c, d] + @test flip(t, 1) ≈ tf + @tensor tf[a, b; c, d] := conj(F1[b, b']) * t[a, b'; c, d] + @test twist!(flip(t, 2), 2) ≈ tf + @tensor tf[a, b; c, d] := F1[c, c'] * t[a, b; c', d] + @test flip(t, 3) ≈ tf + @tensor tf[a, b; c, d] := conj(F1[d, d']) * t[a, b; c, d'] + @test twist!(flip(t, 4), 4) ≈ tf end - @timedtestset "Index flipping: test via contraction" begin + symmetricbraiding && @timedtestset "Index flipping: test via contraction" begin t1 = cuRAND.rand(ComplexF64, V1 ⊗ V2 ⊗ V3 ← V4) t2 = cuRAND.rand(ComplexF64, V2' ⊗ V5 ← V4' ⊗ V1) - CUDA.@allowscalar begin - @tensor ta[a, b] := t1[x, y, a, z] * t2[y, b, z, x] - @tensor tb[a, b] := flip(t1, 1)[x, y, a, z] * flip(t2, 4)[y, b, z, x] - @test ta ≈ tb - @tensor tb[a, b] := flip(t1, (2, 4))[x, y, a, z] * flip(t2, (1, 3))[y, b, z, x] - @test ta ≈ tb - @tensor tb[a, b] := flip(t1, (1, 2, 4))[x, y, a, z] * flip(t2, (1, 3, 4))[y, b, z, x] - @tensor tb[a, b] := flip(t1, (1, 3))[x, y, a, z] * flip(t2, (2, 4))[y, b, z, x] - @test flip(ta, (1, 2)) ≈ tb - end - end=# # TODO + @tensor ta[a, b] := t1[x, y, a, z] * t2[y, b, z, x] + @tensor tb[a, b] := flip(t1, 1)[x, y, a, z] * flip(t2, 4)[y, b, z, x] + @test ta ≈ tb + @tensor tb[a, b] := flip(t1, (2, 4))[x, y, a, z] * flip(t2, (1, 3))[y, b, z, x] + @test ta ≈ tb + @tensor tb[a, b] := flip(t1, (1, 2, 4))[x, y, a, z] * flip(t2, (1, 3, 4))[y, b, z, x] + @tensor tb[a, b] := flip(t1, (1, 3))[x, y, a, z] * flip(t2, (2, 4))[y, b, z, x] + @test flip(ta, (1, 2)) ≈ tb + end @timedtestset "Multiplication of isometries: test properties" begin W1 = V1 ⊗ V2 ⊗ V3 W2 = (V4 ⊗ V5)' @@ -551,10 +545,8 @@ for V in spacelist t1 = cuRAND.rand(T, V1, V5') t2 = cuRAND.rand(T, V2 ⊗ V3, V4') t = @constinferred (t1 ⊗ t2) - CUDA.@allowscalar begin - @tensor t′[1 2 3; 4 5] := t1[1; 4] * t2[2 3; 5] - end - @test t ≈ t′ # This should really not be broken + @tensor t′[1 2 3; 4 5] := t1[1; 4] * t2[2 3; 5] + @test t ≈ t′ end end end @@ -567,17 +559,17 @@ end V1, V2, V3, V4, V5 = Vlist1 W1, W2, W3, W4, W5 = Vlist2 for T in (Float32, ComplexF64) - t1 = rand(T, V2 ⊗ V3, (V4 ⊗ V5)') - t2 = rand(T, W2, (W3 ⊗ W4)') + t1 = CUDA.rand(T, V2 ⊗ V3, (V4 ⊗ V5)') + t2 = CUDA.rand(T, W2, (W3 ⊗ W4)') t = @constinferred (t1 ⊠ t2) d1 = dim(codomain(t1)) d2 = dim(codomain(t2)) d3 = dim(domain(t1)) d4 = dim(domain(t2)) - At = convert(Array, t) + At = convert(Array, adapt(Vector{T}, t)) @test reshape(At, (d1, d2, d3, d4)) ≈ - reshape(convert(Array, t1), (d1, 1, d3, 1)) .* - reshape(convert(Array, t2), (1, d2, 1, d4)) + reshape(convert(Array, adapt(Vector{T}, t1)), (d1, 1, d3, 1)) .* + reshape(convert(Array, adapt(Vector{T}, t2)), (1, d2, 1, d4)) end end end From ea787aebff63e6285b9a9393c29bb040b04ad954 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 27 Apr 2026 07:53:59 -0400 Subject: [PATCH 02/27] fix typo --- test/cuda/tensors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cuda/tensors.jl b/test/cuda/tensors.jl index 7b958d915..bc40d381e 100644 --- a/test/cuda/tensors.jl +++ b/test/cuda/tensors.jl @@ -367,7 +367,7 @@ for V in spacelist @test t ≈ flip(flip(t, i; inv = true), i) end end - symmetricbraiding && "Index flipping: test via explicit flip" begin + symmetricbraiding && @timedtestset "Index flipping: test via explicit flip" begin t = cuRAND.rand(ComplexF64, V1 ⊗ V1' ← V1' ⊗ V1) F1 = adapt(CuArray{ComplexF64}, unitary(flip(V1), V1)) From c3d33e2d2170291d10d603170d75a259bc3f85cc Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 11 May 2026 09:46:54 -0400 Subject: [PATCH 03/27] Fix TC once again --- src/tensors/tensoroperations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 3fc79cf0c..f42f7d94f 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -379,7 +379,7 @@ function blas_contract!( bstyle = BraidingStyle(sectortype(C)) bstyle isa SymmetricBraiding || throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead")) - TC = scalartype(C) + TC = storagetype(C) # without this, Anew below has wrong storagetype # check which tensors have to be permuted/copied copyA = !(TO.isblascontractable(A, pA) && scalartype(A) === TC) From eb6b9856f9f9fe198d22685ebece51d8c3ebecbd Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 12 May 2026 05:52:04 -0400 Subject: [PATCH 04/27] Remove unneeded Adjoint methods --- src/tensors/adjoint.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tensors/adjoint.jl b/src/tensors/adjoint.jl index 382f309b5..dfc1a4471 100644 --- a/src/tensors/adjoint.jl +++ b/src/tensors/adjoint.jl @@ -22,8 +22,6 @@ Base.adjoint(t::AbstractTensorMap) = AdjointTensorMap(t) space(t::AdjointTensorMap) = adjoint(space(parent(t))) dim(t::AdjointTensorMap) = dim(parent(t)) storagetype(::Type{AdjointTensorMap{T, S, N₁, N₂, TT}}) where {T, S, N₁, N₂, TT} = storagetype(TT) -similarstoragetype(::AdjointTensorMap{T, S, N₁, N₂, TT}, ::Type{T′}) where {T, S, N₁, N₂, TT, T′ <: Number} = similarstoragetype(TT, T′) -similarstoragetype(::AdjointTensorMap{T, S, N₁, N₂, TT}, ::Type{TA}) where {T, S, N₁, N₂, TT, TA <: DenseVector} = similarstoragetype(TT, TA) # Blocks and subblocks #---------------------- From c4653c8faa488600df7089b999ae452c6785439d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 12 May 2026 05:56:06 -0400 Subject: [PATCH 05/27] Remove unneeded TensorMapWithStorage? --- src/tensors/indexmanipulations.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tensors/indexmanipulations.jl b/src/tensors/indexmanipulations.jl index e45789b44..ac5043b44 100644 --- a/src/tensors/indexmanipulations.jl +++ b/src/tensors/indexmanipulations.jl @@ -348,7 +348,7 @@ function insertleftunit( ) where {i} W = insertleftunit(space(t), Val(i); conj, dual) if t isa TensorMap - return TensorMapWithStorage{scalartype(t), storagetype(t)}(copy ? Base.copy(t.data) : t.data, W) + return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W) else tdst = similar(t, W) for (c, b) in blocks(t) @@ -377,7 +377,7 @@ function insertrightunit( ) where {i} W = insertrightunit(space(t), Val(i); conj, dual) if t isa TensorMap - return TensorMapWithStorage{scalartype(t), storagetype(t)}(copy ? Base.copy(t.data) : t.data, W) + return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W) else tdst = similar(t, W) for (c, b) in blocks(t) @@ -402,7 +402,7 @@ and [`insertrightunit`](@ref insertrightunit(::AbstractTensorMap, ::Val{i}) wher function removeunit(t::AbstractTensorMap, ::Val{i}; copy::Bool = false) where {i} W = removeunit(space(t), Val(i)) if t isa TensorMap - return TensorMapWithStorage{scalartype(t), storagetype(t)}(copy ? Base.copy(t.data) : t.data, W) + return TensorMap{scalartype(t)}(copy ? Base.copy(t.data) : t.data, W) else tdst = similar(t, W) for (c, b) in blocks(t) From b46e23ef393ec8c4cbc7592789d976d6858aa826 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 12 May 2026 06:04:17 -0400 Subject: [PATCH 06/27] Death to to_cpu --- src/tensors/adjoint.jl | 2 -- src/tensors/tensor.jl | 5 --- test/amd/tensors.jl | 72 +++++++++++++++++++++--------------------- test/cuda/tensors.jl | 70 ++++++++++++++++++++-------------------- 4 files changed, 71 insertions(+), 78 deletions(-) diff --git a/src/tensors/adjoint.jl b/src/tensors/adjoint.jl index dfc1a4471..ca484e77b 100644 --- a/src/tensors/adjoint.jl +++ b/src/tensors/adjoint.jl @@ -50,8 +50,6 @@ Base.@propagate_inbounds function subblock(t::AdjointTensorMap, (f₁, f₂)::Tu return permutedims(conj(data), (domainind(tp)..., codomainind(tp)...)) end -to_cpu(t::AdjointTensorMap) = adjoint(to_cpu(adjoint(t))) - # Show #------ function Base.showarg(io::IO, t::AdjointTensorMap, toplevel::Bool) diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 65e6adae6..7a94941c1 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -408,11 +408,6 @@ for randf in (:rand, :randn, :randexp, :randisometry) end end -# Moving arbitrary TensorMaps to CPU -#----------------------------- -to_cpu(t::TensorMapWithStorage{T, Vector{T}}) where {T} = t # no op -to_cpu(t::TensorMap) = convert(TensorMapWithStorage{scalartype(t), similarstoragetype(scalartype(t))}, t) - # Efficient copy constructors #----------------------------- Base.copy(t::TensorMap) = typeof(t)(copy(t.data), t.space) diff --git a/test/amd/tensors.jl b/test/amd/tensors.jl index 0c6898405..2ac522daa 100644 --- a/test/amd/tensors.jl +++ b/test/amd/tensors.jl @@ -97,7 +97,7 @@ for V in spacelist for T in (Int, Float32, ComplexF64) t = @constinferred AMDGPU.rand(T, W) d = convert(Dict, t) - @test TensorKit.to_cpu(t) == convert(TensorMap, d) + @test adapt(Array, t) == convert(TensorMap, d) end end symmetricbraiding && @timedtestset "Basic linear algebra" begin @@ -189,10 +189,10 @@ for V in spacelist t = AMDGPU.rand(T, W) t2 = @constinferred AMDGPU.rand!(similar(t)) α = rand(T) - @test norm(t, 2) ≈ norm(TensorKit.to_cpu(t), 2) - @test dot(t2, t) ≈ dot(TensorKit.to_cpu(t2), TensorKit.to_cpu(t)) - @test TensorKit.to_cpu(α * t) ≈ α * TensorKit.to_cpu(t) - @test TensorKit.to_cpu(t + t) ≈ 2 * TensorKit.to_cpu(t) + @test norm(t, 2) ≈ norm(adapt(Array, t), 2) + @test dot(t2, t) ≈ dot(adapt(Array, t2), adapt(Array, t)) + @test adapt(Array, α * t) ≈ α * adapt(Array, t) + @test adapt(Array, t + t) ≈ 2 * adapt(Array, t) end end @timedtestset "Real and imaginary parts" begin @@ -202,17 +202,17 @@ for V in spacelist tr = @constinferred real(t) @test scalartype(tr) <: Real - @test real(TensorKit.to_cpu(t)) == TensorKit.to_cpu(tr) + @test real(adapt(Array, t)) == adapt(Array, tr) @test storagetype(tr) == ROCVector{real(T), AMDGPU.Mem.HIPBuffer} ti = @constinferred imag(t) @test scalartype(ti) <: Real - @test imag(TensorKit.to_cpu(t)) == TensorKit.to_cpu(ti) + @test imag(adapt(Array, t)) == adapt(Array, ti) @test storagetype(ti) == ROCVector{real(T), AMDGPU.Mem.HIPBuffer} tc = @inferred complex(t) @test scalartype(tc) <: Complex - @test complex(TensorKit.to_cpu(t)) == TensorKit.to_cpu(tc) + @test complex(adapt(Array, t)) == adapt(Array, tc) @test storagetype(tc) == ROCVector{complex(T), AMDGPU.Mem.HIPBuffer} tc2 = @inferred complex(tr, ti) @@ -275,13 +275,13 @@ for V in spacelist p1 = ntuple(n -> p[n], k) p2 = ntuple(n -> p[k + n], 5 - k) dt2 = AMDGPU.@allowscalar permute(t, (p1, p2)) - ht2 = permute(TensorKit.to_cpu(t), (p1, p2)) - @test ht2 == TensorKit.to_cpu(dt2) + ht2 = permute(adapt(Array, t), (p1, p2)) + @test ht2 == adapt(Array, dt2) end dt3 = AMDGPU.@allowscalar repartition(t, k) - ht3 = repartition(TensorKit.to_cpu(t), k) - @test ht3 == TensorKit.to_cpu(dt3) + ht3 = repartition(adapt(Array, t), k) + @test ht3 == adapt(Array, dt3) end end symmetricbraiding && @timedtestset "Full trace: test self-consistency" begin @@ -339,10 +339,10 @@ for V in spacelist @tensor dHrA12[a, s1, s2, c] := drhoL[a, a'] * conj(dA1[a', t1, b]) * dA2[b, t2, c'] * drhoR[c', c] * dH[s1, s2, t1, t2] - @tensor hHrA12[a, s1, s2, c] := TensorKit.to_cpu(drhoL)[a, a'] * conj(TensorKit.to_cpu(dA1)[a', t1, b]) * - TensorKit.to_cpu(dA2)[b, t2, c'] * TensorKit.to_cpu(drhoR)[c', c] * - TensorKit.to_cpu(dH)[s1, s2, t1, t2] - @test TensorKit.to_cpu(dHrA12) ≈ hHrA12 + @tensor hHrA12[a, s1, s2, c] := adapt(Array, drhoL)[a, a'] * conj(adapt(Array, dA1)[a', t1, b]) * + adapt(Array, dA2)[b, t2, c'] * adapt(Array, drhoR)[c', c] * + adapt(Array, dH)[s1, s2, t1, t2] + @test adapt(Array, dHrA12) ≈ hHrA12 end end=# # doesn't yet work because of AdjointTensor BraidingStyle(I) isa HasBraiding && @timedtestset "Index flipping: test flipping inverse" begin @@ -422,31 +422,31 @@ for V in spacelist t1 = AMDGPU.rand(T, W1, W1) t2 = AMDGPU.rand(T, W2, W2) t = AMDGPU.rand(T, W1, W2) - ht1 = TensorKit.to_cpu(t1) - ht2 = TensorKit.to_cpu(t2) - ht = TensorKit.to_cpu(t) - @test TensorKit.to_cpu(t1 * t) ≈ ht1 * ht - @test TensorKit.to_cpu(t1' * t) ≈ ht1' * ht - @test TensorKit.to_cpu(t2 * t') ≈ ht2 * ht' - @test TensorKit.to_cpu(t2' * t') ≈ ht2' * ht' + ht1 = adapt(Array, t1) + ht2 = adapt(Array, t2) + ht = adapt(Array, t) + @test adapt(Array, t1 * t) ≈ ht1 * ht + @test adapt(Array, t1' * t) ≈ ht1' * ht + @test adapt(Array, t2 * t') ≈ ht2 * ht' + @test adapt(Array, t2' * t') ≈ ht2' * ht' #=AMDGPU.@allowscalar begin - @test TensorKit.to_cpu(inv(t1)) ≈ inv(ht1) - @test TensorKit.to_cpu(pinv(t)) ≈ pinv(ht) + @test adapt(Array, inv(t1)) ≈ inv(ht1) + @test adapt(Array, pinv(t)) ≈ pinv(ht) if T == Float32 || T == ComplexF32 continue end - @test TensorKit.to_cpu(t1 \ t) ≈ ht1 \ ht - @test TensorKit.to_cpu(t1' \ t) ≈ ht1' \ ht - @test TensorKit.to_cpu(t2 \ t') ≈ ht2 \ ht' - @test TensorKit.to_cpu(t2' \ t') ≈ ht2' \ ht' + @test adapt(Array, t1 \ t) ≈ ht1 \ ht + @test adapt(Array, t1' \ t) ≈ ht1' \ ht + @test adapt(Array, t2 \ t') ≈ ht2 \ ht' + @test adapt(Array, t2' \ t') ≈ ht2' \ ht' - @test TensorKit.to_cpu(t2 / t) ≈ ht2 / ht - @test TensorKit.to_cpu(t2' / t) ≈ ht2' / ht - @test TensorKit.to_cpu(t1 / t') ≈ ht1 / ht' - @test TensorKit.to_cpu(t1' / t') ≈ ht1' / ht' + @test adapt(Array, t2 / t) ≈ ht2 / ht + @test adapt(Array, t2' / t) ≈ ht2' / ht + @test adapt(Array, t1 / t') ≈ ht1 / ht' + @test adapt(Array, t1' / t') ≈ ht1' / ht' end=# end end @@ -456,11 +456,11 @@ for V in spacelist #=t = project_hermitian!(AMDGPU.randn(T, W, W)) s = dim(W) @test (@constinferred sqrt(t))^2 ≈ t - @test TensorKit.to_cpu(sqrt(t)) ≈ sqrt(TensorKit.to_cpu(t)) + @test adapt(Array, sqrt(t)) ≈ sqrt(adapt(Array, t)) expt = @constinferred exp(t) - @test TensorKit.to_cpu(expt) ≈ exp(TensorKit.to_cpu(t)) + @test adapt(Array, expt) ≈ exp(adapt(Array, t)) @test exp(@constinferred log(project_hermitian!(expt))) ≈ expt - @test TensorKit.to_cpu(log(project_hermitian!(expt))) ≈ log(TensorKit.to_cpu(expt)) + @test adapt(Array, log(project_hermitian!(expt))) ≈ log(adapt(Array, expt)) @test (@constinferred cos(t))^2 + (@constinferred sin(t))^2 ≈ id(storagetype(t), W) diff --git a/test/cuda/tensors.jl b/test/cuda/tensors.jl index bc40d381e..973118cac 100644 --- a/test/cuda/tensors.jl +++ b/test/cuda/tensors.jl @@ -116,7 +116,7 @@ for V in spacelist for T in (Int, Float32, ComplexF64) t = @constinferred cuRAND.rand(T, W) d = convert(Dict, t) - @test convert(Dict, TensorKit.to_cpu(t)) == d + @test convert(Dict, adapt(Array, t)) == d end end symmetricbraiding && @timedtestset "Basic linear algebra" begin @@ -208,8 +208,8 @@ for V in spacelist t = cuRAND.rand(T, W) t2 = @constinferred cuRAND.rand!(similar(t)) α = rand(T) - @test norm(t, 2) ≈ norm(TensorKit.to_cpu(t), 2) - @test dot(t2, t) ≈ dot(TensorKit.to_cpu(t2), TensorKit.to_cpu(t)) + @test norm(t, 2) ≈ norm(adapt(Array, t), 2) + @test dot(t2, t) ≈ dot(adapt(Array, t2), adapt(Array, t)) @test adapt(Vector{T}, (α * t)) ≈ α * adapt(Vector{T}, t) @test adapt(Vector{T}, (t + t)) ≈ 2 * adapt(Vector{T}, t) end @@ -221,17 +221,17 @@ for V in spacelist tr = @constinferred real(t) @test scalartype(tr) <: Real - @test real(TensorKit.to_cpu(t)) == TensorKit.to_cpu(tr) + @test real(adapt(Array, t)) == adapt(Array, tr) @test storagetype(tr) == CuVector{real(T), CUDA.DeviceMemory} ti = @constinferred imag(t) @test scalartype(ti) <: Real - @test imag(TensorKit.to_cpu(t)) == TensorKit.to_cpu(ti) + @test imag(adapt(Array, t)) == adapt(Array, ti) @test storagetype(ti) == CuVector{real(T), CUDA.DeviceMemory} tc = @inferred complex(t) @test scalartype(tc) <: Complex - @test complex(TensorKit.to_cpu(t)) == TensorKit.to_cpu(tc) + @test complex(adapt(Array, t)) == adapt(Array, tc) @test storagetype(tc) == CuVector{complex(T), CUDA.DeviceMemory} tc2 = @inferred complex(tr, ti) @@ -244,7 +244,7 @@ for V in spacelist W = V1 ⊗ V2 t = @constinferred cuRAND.randn(W ← W) @test typeof(convert(typeof(t), t')) == typeof(t) - @test typeof(TensorKit.to_cpu(t')) == typeof(TensorKit.to_cpu(t)') + @test typeof(adapt(Array, t')) == typeof(adapt(Array, t)') tc = complex(t) @test convert(typeof(tc), t) == tc @test typeof(convert(typeof(tc), t)) == typeof(tc) @@ -297,12 +297,12 @@ for V in spacelist p1 = ntuple(n -> p[n], k) p2 = ntuple(n -> p[k + n], 5 - k) dt2 = permute(t, (p1, p2)) - ht2 = permute(TensorKit.to_cpu(t), (p1, p2)) - @test ht2 ≈ TensorKit.to_cpu(dt2) + ht2 = permute(adapt(Array, t), (p1, p2)) + @test ht2 ≈ adapt(Array, dt2) end dt3 = repartition(t, k) - ht3 = repartition(TensorKit.to_cpu(t), k) - @test ht3 ≈ TensorKit.to_cpu(dt3) + ht3 = repartition(adapt(Array, t), k) + @test ht3 ≈ adapt(Array, dt3) end end symmetricbraiding && @timedtestset "Full trace: test self-consistency" begin @@ -354,10 +354,10 @@ for V in spacelist @tensor dHrA12[a, s1, s2, c] := drhoL[a, a'] * conj(dA1[a', t1, b]) * dA2[b, t2, c'] * drhoR[c', c] * dH[s1, s2, t1, t2] - @tensor hHrA12[a, s1, s2, c] := TensorKit.to_cpu(drhoL)[a, a'] * conj(TensorKit.to_cpu(dA1)[a', t1, b]) * - TensorKit.to_cpu(dA2)[b, t2, c'] * TensorKit.to_cpu(drhoR)[c', c] * - TensorKit.to_cpu(dH)[s1, s2, t1, t2] - @test TensorKit.to_cpu(dHrA12) ≈ hHrA12 + @tensor hHrA12[a, s1, s2, c] := adapt(Array, drhoL)[a, a'] * conj(adapt(Array, dA1)[a', t1, b]) * + adapt(Array, dA2)[b, t2, c'] * adapt(Array, drhoR)[c', c] * + adapt(Array, dH)[s1, s2, t1, t2] + @test adapt(Array, dHrA12) ≈ hHrA12 end end hasbraiding && @timedtestset "Index flipping: test flipping inverse" begin @@ -429,30 +429,30 @@ for V in spacelist t1 = cuRAND.rand(T, W1, W1) t2 = cuRAND.rand(T, W2, W2) t = cuRAND.rand(T, W1, W2) - ht1 = TensorKit.to_cpu(t1) - ht2 = TensorKit.to_cpu(t2) - ht = TensorKit.to_cpu(t) - @test TensorKit.to_cpu(t1 * t) ≈ ht1 * ht - @test TensorKit.to_cpu(t1' * t) ≈ ht1' * ht - @test TensorKit.to_cpu(t2 * t') ≈ ht2 * ht' - @test TensorKit.to_cpu(t2' * t') ≈ ht2' * ht' + ht1 = adapt(Array, t1) + ht2 = adapt(Array, t2) + ht = adapt(Array, t) + @test adapt(Array, t1 * t) ≈ ht1 * ht + @test adapt(Array, t1' * t) ≈ ht1' * ht + @test adapt(Array, t2 * t') ≈ ht2 * ht' + @test adapt(Array, t2' * t') ≈ ht2' * ht' - @test TensorKit.to_cpu(inv(t1)) ≈ inv(ht1) - @test TensorKit.to_cpu(pinv(t)) ≈ pinv(ht) + @test adapt(Array, inv(t1)) ≈ inv(ht1) + @test adapt(Array, pinv(t)) ≈ pinv(ht) if T == Float32 || T == ComplexF32 continue end - @test TensorKit.to_cpu(t1 \ t) ≈ ht1 \ ht - @test TensorKit.to_cpu(t1' \ t) ≈ ht1' \ ht - @test TensorKit.to_cpu(t2 \ t') ≈ ht2 \ ht' - @test TensorKit.to_cpu(t2' \ t') ≈ ht2' \ ht' + @test adapt(Array, t1 \ t) ≈ ht1 \ ht + @test adapt(Array, t1' \ t) ≈ ht1' \ ht + @test adapt(Array, t2 \ t') ≈ ht2 \ ht' + @test adapt(Array, t2' \ t') ≈ ht2' \ ht' - @test TensorKit.to_cpu(t2 / t) ≈ ht2 / ht - @test TensorKit.to_cpu(t2' / t) ≈ ht2' / ht - @test TensorKit.to_cpu(t1 / t') ≈ ht1 / ht' - @test TensorKit.to_cpu(t1' / t') ≈ ht1' / ht' + @test adapt(Array, t2 / t) ≈ ht2 / ht + @test adapt(Array, t2' / t) ≈ ht2' / ht + @test adapt(Array, t1 / t') ≈ ht1 / ht' + @test adapt(Array, t1' / t') ≈ ht1' / ht' end end symmetricbraiding && @timedtestset "Tensor functions" begin @@ -461,14 +461,14 @@ for V in spacelist t = project_hermitian!(cuRAND.randn(T, W, W)) s = dim(W) #@test (@constinferred sqrt(t))^2 ≈ t - #@test TensorKit.to_cpu(sqrt(t)) ≈ sqrt(TensorKit.to_cpu(t)) + #@test adapt(Array, sqrt(t)) ≈ sqrt(adapt(Array, t)) expt = @constinferred exp(t) - @test TensorKit.to_cpu(expt) ≈ exp(TensorKit.to_cpu(t)) + @test adapt(Array, expt) ≈ exp(adapt(Array, t)) # log doesn't work on CUDA yet (scalar indexing) #@test exp(@constinferred log(project_hermitian!(expt))) ≈ expt - #@test TensorKit.to_cpu(log(project_hermitian!(expt))) ≈ log(TensorKit.to_cpu(expt)) + #@test adapt(Array, log(project_hermitian!(expt))) ≈ log(adapt(Array, expt)) #=@test (@constinferred cos(t))^2 + (@constinferred sin(t))^2 ≈ id(storagetype(t), W) From 593206637f1518b0c616dc3c579dc170cec8ca87 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 12 May 2026 10:25:58 -0400 Subject: [PATCH 07/27] Remove unneeded similarstoragetype method --- src/tensors/abstracttensor.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index c517a165e..56c3da839 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -113,7 +113,6 @@ end # implement on arrays similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A -similarstoragetype(::Type{A}, ::Type{A}) where {A <: DenseVector{<:Number}} = A Base.@assume_effects :foldable similarstoragetype(::Type{A}) where {A <: AbstractArray{<:Number}} = Core.Compiler.return_type(similar, Tuple{A, Int}) Base.@assume_effects :foldable similarstoragetype(::Type{A}, ::Type{T}) where {A <: AbstractArray, T <: Number} = From 46e50374ed119978696fadf3c49eb796af0ef71e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 12 May 2026 10:26:36 -0400 Subject: [PATCH 08/27] Add in TensorMap constructor --- ext/TensorKitCUDAExt/truncation.jl | 4 ++-- src/tensors/braidingtensor.jl | 3 +-- src/tensors/tensor.jl | 1 + 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl index 019ded97b..666718f4b 100644 --- a/ext/TensorKitCUDAExt/truncation.jl +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -10,7 +10,7 @@ function MatrixAlgebraKit.findtruncated( fill!(v, dim(c)) end - perm = sortperm(parent(values); strategy.by, strategy.rev) + perm = isempty(parent(values)) ? () : sortperm(parent(values); strategy.by, strategy.rev) cumulative_dim = cumsum(Base.permute!(parent(dims), perm)) result = similar(values, Bool) @@ -36,7 +36,7 @@ function MatrixAlgebraKit.findtruncated( end end - perm = sortperm(parent(values); by = abs, rev = false) + perm = isempty(parent(values)) ? () : sortperm(parent(values); by = abs, rev = false) cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), perm)) result = similar(values, Bool) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index beccb5fbb..fc9262b58 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -265,8 +265,7 @@ function planarcontract!( ) # special case only defined for contracting all 4 indices of B (2 contracted + 2 open) if length.(pB) != (2, 2) - tB′ = TensorMap(B) - tB = TensorMapWithStorage{eltype(B), similarstoragetype(A, eltype(B)), spacetype(tB′), numout(tB′), numin(tB′)}(tB′) + tB = TensorMapWithStorage{eltype(B), similarstoragetype(A, eltype(B)), spacetype(tB′), numout(tB′), numin(tB′)}(B) return planarcontract!(C, A, pA, tB, pB, pAB, α, β, backend, allocator) end diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 7a94941c1..531c58638 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -35,6 +35,7 @@ struct TensorMap{T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} <: Abstrac end # constructors from another TensorMap -- no-op TensorMap{T, S, N₁, N₂, A}(t::TensorMap{T, S, N₁, N₂, A}) where {T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} = t +TensorMap{T, S, N₁, N₂, A}(t::TensorMap{T, S, N₁, N₂}) where {T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} = TensorMap(A(t.data), space(t)) """ Tensor{T, S, N, A<:DenseVector{T}} = TensorMap{T, S, N, 0, A} From 9fdc185b3f7d89c90ca224e1695eedbc2320c49f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 12 May 2026 12:56:32 -0400 Subject: [PATCH 09/27] Restore former braiding tensor methods --- src/tensors/braidingtensor.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index fc9262b58..ca00295f8 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -189,15 +189,13 @@ end has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false function add_transform!( tdst::AbstractTensorMap, - tsrc::BraidingTensor{T, S}, + tsrc::BraidingTensor{T, S, A}, (p₁, p₂)::Index2Tuple, fusiontreetransform, α::Number, β::Number, backend::AbstractBackend... - ) where {T, S} - tsrc_map = similar(tdst, storagetype(tdst), space(tsrc)) - copy!(tsrc_map, tsrc) + ) where {T, S, A} return add_transform!( - tdst, tsrc_map, (p₁, p₂), fusiontreetransform, α, β, + tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β, backend... ) end @@ -265,8 +263,7 @@ function planarcontract!( ) # special case only defined for contracting all 4 indices of B (2 contracted + 2 open) if length.(pB) != (2, 2) - tB = TensorMapWithStorage{eltype(B), similarstoragetype(A, eltype(B)), spacetype(tB′), numout(tB′), numin(tB′)}(B) - return planarcontract!(C, A, pA, tB, pB, pAB, α, β, backend, allocator) + return planarcontract!(C, A, pA, TensorMap(B), pB, pAB, α, β, backend, allocator) end spacecheck_contract(C, A, pA, false, B, pB, false, pAB) From e948370907cd9fc5166ae70322299d997be0fcac Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 13 May 2026 03:58:48 -0400 Subject: [PATCH 10/27] Fix type issue for sortperm --- ext/TensorKitCUDAExt/truncation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl index 666718f4b..42d623730 100644 --- a/ext/TensorKitCUDAExt/truncation.jl +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -10,7 +10,7 @@ function MatrixAlgebraKit.findtruncated( fill!(v, dim(c)) end - perm = isempty(parent(values)) ? () : sortperm(parent(values); strategy.by, strategy.rev) + perm = isempty(parent(values)) ? Int64[] : sortperm(parent(values); strategy.by, strategy.rev) cumulative_dim = cumsum(Base.permute!(parent(dims), perm)) result = similar(values, Bool) @@ -36,7 +36,7 @@ function MatrixAlgebraKit.findtruncated( end end - perm = isempty(parent(values)) ? () : sortperm(parent(values); by = abs, rev = false) + perm = isempty(parent(values)) ? Int64[] : sortperm(parent(values); by = abs, rev = false) cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), perm)) result = similar(values, Bool) From 28dca1ea4d3f1ebb42983153b7c6208ce4ed9b41 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 13 May 2026 08:23:25 -0400 Subject: [PATCH 11/27] Remove stale type params --- src/tensors/braidingtensor.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index ca00295f8..e519a4df3 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -189,11 +189,11 @@ end has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false function add_transform!( tdst::AbstractTensorMap, - tsrc::BraidingTensor{T, S, A}, + tsrc::BraidingTensor, (p₁, p₂)::Index2Tuple, fusiontreetransform, α::Number, β::Number, backend::AbstractBackend... - ) where {T, S, A} + ) return add_transform!( tdst, TensorMap(tsrc), (p₁, p₂), fusiontreetransform, α, β, backend... From 259309ea65fca86eb711e4dcacbce57d57382714 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 14 May 2026 15:00:35 +0200 Subject: [PATCH 12/27] Apply suggestions from code review Co-authored-by: Lukas Devos --- src/tensors/braidingtensor.jl | 8 ++++---- src/tensors/indexmanipulations.jl | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index e519a4df3..ce0b99986 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -191,7 +191,7 @@ function add_transform!( tdst::AbstractTensorMap, tsrc::BraidingTensor, (p₁, p₂)::Index2Tuple, - fusiontreetransform, + tsrc::BraidingTensor, (p₁, p₂)::Index2Tuple, α::Number, β::Number, backend::AbstractBackend... ) return add_transform!( @@ -262,9 +262,9 @@ function planarcontract!( backend, allocator ) # special case only defined for contracting all 4 indices of B (2 contracted + 2 open) - if length.(pB) != (2, 2) - return planarcontract!(C, A, pA, TensorMap(B), pB, pAB, α, β, backend, allocator) - end + length.(pB) == (2, 2) || + planarcontract!(C, A, pA, TensorMap(B), pB, pAB, α, β, backend, allocator) + spacecheck_contract(C, A, pA, false, B, pB, false, pAB) diff --git a/src/tensors/indexmanipulations.jl b/src/tensors/indexmanipulations.jl index ac5043b44..f7e2f9ecf 100644 --- a/src/tensors/indexmanipulations.jl +++ b/src/tensors/indexmanipulations.jl @@ -344,7 +344,7 @@ See also [`insertrightunit`](@ref insertrightunit(::AbstractTensorMap, ::Val{i}) """ function insertleftunit( t::AbstractTensorMap, ::Val{i} = Val(numind(t) + 1); - copy::Bool = false, conj::Bool = false, dual::Bool = false, + copy::Bool = false, conj::Bool = false, dual::Bool = false ) where {i} W = insertleftunit(space(t), Val(i); conj, dual) if t isa TensorMap @@ -373,7 +373,7 @@ See also [`insertleftunit`](@ref insertleftunit(::AbstractTensorMap, ::Val{i}) w """ function insertrightunit( t::AbstractTensorMap, ::Val{i} = Val(numind(t)); - copy::Bool = false, conj::Bool = false, dual::Bool = false, + copy::Bool = false, conj::Bool = false, dual::Bool = false ) where {i} W = insertrightunit(space(t), Val(i); conj, dual) if t isa TensorMap From f23ce5c3d791d23744430cc79480f9b5b994bbe6 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 14 May 2026 11:06:36 -0400 Subject: [PATCH 13/27] Fix bad result of suggestion --- src/tensors/braidingtensor.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index ce0b99986..cf875cdbe 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -189,8 +189,6 @@ end has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false function add_transform!( tdst::AbstractTensorMap, - tsrc::BraidingTensor, - (p₁, p₂)::Index2Tuple, tsrc::BraidingTensor, (p₁, p₂)::Index2Tuple, α::Number, β::Number, backend::AbstractBackend... ) From 3504de358b34aaa69e516243714cdd34bb807a2d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 14 May 2026 13:12:29 -0400 Subject: [PATCH 14/27] Another fix? --- src/tensors/braidingtensor.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index cf875cdbe..f12123384 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -261,8 +261,7 @@ function planarcontract!( ) # special case only defined for contracting all 4 indices of B (2 contracted + 2 open) length.(pB) == (2, 2) || - planarcontract!(C, A, pA, TensorMap(B), pB, pAB, α, β, backend, allocator) - + return planarcontract!(C, A, pA, TensorMap(B), pB, pAB, α, β, backend, allocator) spacecheck_contract(C, A, pA, false, B, pB, false, pAB) From 830d9551a7b7eb25ed6af9350b49427f4aab9d17 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 16 May 2026 04:32:21 -0400 Subject: [PATCH 15/27] Force inds to move back to the CPU --- ext/TensorKitCUDAExt/truncation.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl index 42d623730..1dcfb14fc 100644 --- a/ext/TensorKitCUDAExt/truncation.jl +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -44,6 +44,18 @@ function MatrixAlgebraKit.findtruncated( return result end +function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::S) where {S <: MatrixAlgebraKit.TruncationStrategy} + # returning a CuSectorVector wrecks things in truncate_{co}domain + # because of scalar indexing + return CUDA.CUDACore.Adapt.adapt(Vector{Int}, MatrixAlgebraKit.findtruncated(values, strategy)) +end + +function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByValue) + atol = TensorKit.Factorizations.rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol) + strategy′ = trunctol(; atol, strategy.by, strategy.keep_below) + return SectorDict(c => CUDA.CUDACore.Adapt.adapt(Vector{Int}, MatrixAlgebraKit.findtruncated_svd(d, strategy′)) for (c, d) in pairs(values)) +end + # Needed until MatrixAlgebraKit patch hits... function MatrixAlgebraKit._ind_intersect(A::CuVector{Bool}, B::CuVector{Int}) result = fill!(similar(A), false) From 1fe40a7143a6eb19db84d87890f08dbed0919ed8 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 16 May 2026 04:32:51 -0400 Subject: [PATCH 16/27] Return to glorious scalartype --- src/tensors/abstracttensor.jl | 6 ------ src/tensors/tensoroperations.jl | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 56c3da839..edbcfe662 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -104,12 +104,6 @@ similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoraget function similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} return similarstoragetype(storagetype(TT), T) end -function similarstoragetype(::Type{<:AbstractTensorMap{T, S, N₁, N₂}}, ::Type{TA}) where {T <: Number, TA <: DenseVector, S, N₁, N₂} - return similarstoragetype(TA, T) -end -function similarstoragetype(t::AbstractTensorMap{T, S, N₁, N₂}, ::Type{TA}) where {T <: Number, TA <: DenseVector, S, N₁, N₂} - return similarstoragetype(typeof(t), TA) -end # implement on arrays similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index f42f7d94f..3fc79cf0c 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -379,7 +379,7 @@ function blas_contract!( bstyle = BraidingStyle(sectortype(C)) bstyle isa SymmetricBraiding || throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead")) - TC = storagetype(C) # without this, Anew below has wrong storagetype + TC = scalartype(C) # check which tensors have to be permuted/copied copyA = !(TO.isblascontractable(A, pA) && scalartype(A) === TC) From 55d25b5e77f07e3b28c1bb0274d9959ed4707f5d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 16 May 2026 10:41:20 -0400 Subject: [PATCH 17/27] Restore DiagonalTensorMap ctor --- src/tensors/diagonal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensors/diagonal.jl b/src/tensors/diagonal.jl index e73ad2787..b2ac4134b 100644 --- a/src/tensors/diagonal.jl +++ b/src/tensors/diagonal.jl @@ -280,7 +280,7 @@ end # ---------------- function TO.tensoradd_type(TC, A::DiagonalTensorMap, ::Index2Tuple{1, 1}, ::Bool) M = similarstoragetype(A, TC) - return DiagonalTensorMap{scalartype(M), spacetype(A), M} + return DiagonalTensorMap{TC, spacetype(A), M} end function TO.tensorcontract_type( From a08bd232e5c1c98f7f880cd5be0295c17209b904 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 16 May 2026 10:41:30 -0400 Subject: [PATCH 18/27] Resolve trunc ambiguity --- ext/TensorKitCUDAExt/truncation.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl index 1dcfb14fc..fa82c7751 100644 --- a/ext/TensorKitCUDAExt/truncation.jl +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -50,6 +50,12 @@ function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::S) return CUDA.CUDACore.Adapt.adapt(Vector{Int}, MatrixAlgebraKit.findtruncated(values, strategy)) end +function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByOrder) + # returning a CuSectorVector wrecks things in truncate_{co}domain + # because of scalar indexing + return CUDA.CUDACore.Adapt.adapt(Vector{Int}, MatrixAlgebraKit.findtruncated(values, strategy)) +end + function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByValue) atol = TensorKit.Factorizations.rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol) strategy′ = trunctol(; atol, strategy.by, strategy.keep_below) From 921ffc65729269cf056a403b3218b0c681fafbf7 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 16 May 2026 11:53:47 -0400 Subject: [PATCH 19/27] Remove extra CUDA ctor --- ext/TensorKitCUDAExt/cutensormap.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/ext/TensorKitCUDAExt/cutensormap.jl b/ext/TensorKitCUDAExt/cutensormap.jl index a1d2d1b66..02ca2d5d6 100644 --- a/ext/TensorKitCUDAExt/cutensormap.jl +++ b/ext/TensorKitCUDAExt/cutensormap.jl @@ -6,9 +6,6 @@ const AdjointCuTensorMap{T, S, N₁, N₂} = AdjointTensorMap{T, S, N₁, N₂, function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A} return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t)) end -function TensorMap{T, S, N₁, N₂, DA}(t::TensorMap{T, S, N₁, N₂, HA}) where {T, S, N₁, N₂, DA <: CuArray{T}, HA <: Array{T}} - return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t)) -end # project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: CuVector{T}} From ed079694fdbadfc9d9c566d1efb762b1ee136dc5 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 16 May 2026 11:55:35 -0400 Subject: [PATCH 20/27] Restore chopped argument --- src/tensors/braidingtensor.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tensors/braidingtensor.jl b/src/tensors/braidingtensor.jl index f12123384..d28b2e1df 100644 --- a/src/tensors/braidingtensor.jl +++ b/src/tensors/braidingtensor.jl @@ -190,6 +190,7 @@ has_shared_permute(t::BraidingTensor, ::Index2Tuple) = false function add_transform!( tdst::AbstractTensorMap, tsrc::BraidingTensor, (p₁, p₂)::Index2Tuple, + fusiontreetransform, α::Number, β::Number, backend::AbstractBackend... ) return add_transform!( From 285ac16641f0a118dfa421651977513d884a5ac0 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 16 May 2026 13:18:38 -0400 Subject: [PATCH 21/27] Also remove no longer needed method --- src/tensors/indexmanipulations.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tensors/indexmanipulations.jl b/src/tensors/indexmanipulations.jl index f7e2f9ecf..3108abb17 100644 --- a/src/tensors/indexmanipulations.jl +++ b/src/tensors/indexmanipulations.jl @@ -17,8 +17,6 @@ for (operation, manipulation) in ( $promote_op(::Type{T}, ::Type{I}) where {T <: Number, I <: Sector} = sectorscalartype(I) <: Integer ? T : sectorscalartype(I) <: Real ? float(T) : complex(T) - $promote_op(::Type{TA}, ::Type{I}) where {TA <: DenseVector, I <: Sector} = - similarstoragetype(TA, $promote_op(eltype(TA), I)) # TODO: currently the manipulations all use sectorscalartype, change to: # $manipulation_scalartype(I) <: Integer ? T : # $manipulation_scalartype(I) <: Real ? float(T) : complex(T) From ae018d1f06b02782500e6f35450c0143420eba13 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sun, 17 May 2026 01:22:29 -0400 Subject: [PATCH 22/27] Remove forced Int eltype --- ext/TensorKitCUDAExt/truncation.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl index fa82c7751..73676b579 100644 --- a/ext/TensorKitCUDAExt/truncation.jl +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -47,19 +47,19 @@ end function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::S) where {S <: MatrixAlgebraKit.TruncationStrategy} # returning a CuSectorVector wrecks things in truncate_{co}domain # because of scalar indexing - return CUDA.CUDACore.Adapt.adapt(Vector{Int}, MatrixAlgebraKit.findtruncated(values, strategy)) + return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) end function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByOrder) # returning a CuSectorVector wrecks things in truncate_{co}domain # because of scalar indexing - return CUDA.CUDACore.Adapt.adapt(Vector{Int}, MatrixAlgebraKit.findtruncated(values, strategy)) + return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) end function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByValue) atol = TensorKit.Factorizations.rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol) strategy′ = trunctol(; atol, strategy.by, strategy.keep_below) - return SectorDict(c => CUDA.CUDACore.Adapt.adapt(Vector{Int}, MatrixAlgebraKit.findtruncated_svd(d, strategy′)) for (c, d) in pairs(values)) + return SectorDict(c => CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated_svd(d, strategy′)) for (c, d) in pairs(values)) end # Needed until MatrixAlgebraKit patch hits... From 248f8a3141a37201fde5122b3e1a1a88bfe5202e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 18 May 2026 09:18:07 -0400 Subject: [PATCH 23/27] Get rid of no-op ctor --- src/tensors/tensor.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 531c58638..bd6609163 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -33,8 +33,6 @@ struct TensorMap{T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} <: Abstrac return new{T, S, N₁, N₂, A}(data, space) end end -# constructors from another TensorMap -- no-op -TensorMap{T, S, N₁, N₂, A}(t::TensorMap{T, S, N₁, N₂, A}) where {T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} = t TensorMap{T, S, N₁, N₂, A}(t::TensorMap{T, S, N₁, N₂}) where {T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} = TensorMap(A(t.data), space(t)) """ From 06d1ac5a22bf7b66fcf63f3399ed243e2d3724f0 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 18 May 2026 10:03:33 -0400 Subject: [PATCH 24/27] Try to resolve ambiguity --- ext/TensorKitCUDAExt/truncation.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl index 73676b579..31f46ccd4 100644 --- a/ext/TensorKitCUDAExt/truncation.jl +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -56,6 +56,12 @@ function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::Ma return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) end +function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::TensorKit.Factorizations.TruncationSpace) + # returning a CuSectorVector wrecks things in truncate_{co}domain + # because of scalar indexing + return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) +end + function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByValue) atol = TensorKit.Factorizations.rtol_to_atol(values, strategy.p, strategy.atol, strategy.rtol) strategy′ = trunctol(; atol, strategy.by, strategy.keep_below) From 8efbb641fc0752e58dc2529e56c15bae1deada9a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 19 May 2026 03:28:24 -0400 Subject: [PATCH 25/27] Cover all truncation strategies --- ext/TensorKitCUDAExt/truncation.jl | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl index 31f46ccd4..dac852444 100644 --- a/ext/TensorKitCUDAExt/truncation.jl +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -50,16 +50,12 @@ function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::S) return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) end -function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByOrder) - # returning a CuSectorVector wrecks things in truncate_{co}domain - # because of scalar indexing - return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) -end - -function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::TensorKit.Factorizations.TruncationSpace) - # returning a CuSectorVector wrecks things in truncate_{co}domain - # because of scalar indexing - return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) +for strat in (:(MatrixAlgebraKit.TruncationByOrder), :(MatrixAlgebraKit.TruncationByError), :(MatrixAlgebraKit.TruncationIntersection), :(TensorKit.Factorizations.TruncationSpace)) + @eval function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::$strat) + # returning a CuSectorVector wrecks things in truncate_{co}domain + # because of scalar indexing + return CUDA.CUDACore.Adapt.adapt(Vector, MatrixAlgebraKit.findtruncated(values, strategy)) + end end function MatrixAlgebraKit.findtruncated_svd(values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByValue) From fa16c1a8f3bc8b427eae275248a3f4ad5133b8cd Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 19 May 2026 13:58:37 +0200 Subject: [PATCH 26/27] Short-circuit logic in `findtruncated` Co-authored-by: Lukas Devos --- ext/TensorKitCUDAExt/truncation.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl index dac852444..e9d5ac9a7 100644 --- a/ext/TensorKitCUDAExt/truncation.jl +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -10,7 +10,9 @@ function MatrixAlgebraKit.findtruncated( fill!(v, dim(c)) end - perm = isempty(parent(values)) ? Int64[] : sortperm(parent(values); strategy.by, strategy.rev) + isempty(parent(values)) && return similar(values, Bool) + + perm = sortperm(parent(values); strategy.by, strategy.rev) cumulative_dim = cumsum(Base.permute!(parent(dims), perm)) result = similar(values, Bool) @@ -36,7 +38,9 @@ function MatrixAlgebraKit.findtruncated( end end - perm = isempty(parent(values)) ? Int64[] : sortperm(parent(values); by = abs, rev = false) + isempty(parent(values)) && return similar(values, Bool) + + perm = sortperm(parent(values); by = abs, rev = false) cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), perm)) result = similar(values, Bool) From 0d844fbfbe650d74d8b2028896c9c58d1e5ce55a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 19 May 2026 08:34:05 -0400 Subject: [PATCH 27/27] Formatter --- ext/TensorKitCUDAExt/truncation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/TensorKitCUDAExt/truncation.jl b/ext/TensorKitCUDAExt/truncation.jl index e9d5ac9a7..2633b3483 100644 --- a/ext/TensorKitCUDAExt/truncation.jl +++ b/ext/TensorKitCUDAExt/truncation.jl @@ -11,7 +11,7 @@ function MatrixAlgebraKit.findtruncated( end isempty(parent(values)) && return similar(values, Bool) - + perm = sortperm(parent(values); strategy.by, strategy.rev) cumulative_dim = cumsum(Base.permute!(parent(dims), perm)) @@ -39,7 +39,7 @@ function MatrixAlgebraKit.findtruncated( end isempty(parent(values)) && return similar(values, Bool) - + perm = sortperm(parent(values); by = abs, rev = false) cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), perm))