From 1dbdd10636732c120237607457650ea8cc1228c2 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Sun, 3 May 2026 17:32:52 -0400 Subject: [PATCH 1/4] periodic indexing for BP environments --- src/environments/bp_environments.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/environments/bp_environments.jl b/src/environments/bp_environments.jl index ea7ae9301..b8a3b297d 100644 --- a/src/environments/bp_environments.jl +++ b/src/environments/bp_environments.jl @@ -145,12 +145,36 @@ end Base.eltype(::Type{BPEnv{T}}) where {T} = T Base.size(env::BPEnv, args...) = size(env.messages, args...) +function Base.getindex(env::BPEnv, dir::Int, row::Int, col::Int) + return env.messages[dir, mod1(row, size(env, 2)), mod1(col, size(env, 3))] +end +Base.getindex(env::BPEnv, dir::Int, idx::CartesianIndex{2}) = env[dir, idx[1], idx[2]] +Base.getindex(env::BPEnv, idx::CartesianIndex{3}) = env[idx[1], idx[2], idx[3]] Base.getindex(env::BPEnv, args...) = Base.getindex(env.messages, args...) + +function Base.setindex!(env::BPEnv, val, dir::Int, row::Int, col::Int) + env.messages[dir, mod1(row, size(env, 2)), mod1(col, size(env, 3))] = val + return env +end +function Base.setindex!(env::BPEnv, val, dir::Int, idx::CartesianIndex{2}) + env[dir, idx[1], idx[2]] = val + return env +end +function Base.setindex!(env::BPEnv, val, idx::CartesianIndex{3}) + env[idx[1], idx[2], idx[3]] = val + return env +end +function Base.setindex!(env::BPEnv, val, args...) + Base.setindex!(env.messages, val, args...) + return env +end Base.axes(env::BPEnv, args...) = Base.axes(env.messages, args...) Base.eachindex(env::BPEnv) = eachindex(IndexCartesian(), env.messages) VectorInterface.scalartype(::Type{BPEnv{T}}) where {T} = scalartype(T) TensorKit.spacetype(::Type{BPEnv{T}}) where {T} = spacetype(T) +Base.copy(x::BPEnv) = BPEnv(copy.(x.messages)) + function eachcoordinate(x::BPEnv) return collect(Iterators.product(axes(x, 2), axes(x, 3))) end From fe5ac55d37a24eb99267e53c5e7030ab0c31ece0 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Sun, 3 May 2026 17:33:09 -0400 Subject: [PATCH 2/4] gate application for BP --- src/PEPSKit.jl | 1 + src/algorithms/bp/simpleupdate.jl | 182 ++++++++++++++++++++++++++++++ 2 files changed, 183 insertions(+) create mode 100644 src/algorithms/bp/simpleupdate.jl diff --git a/src/PEPSKit.jl b/src/PEPSKit.jl index 4d2f583a1..6dcee2d43 100644 --- a/src/PEPSKit.jl +++ b/src/PEPSKit.jl @@ -115,6 +115,7 @@ include("algorithms/time_evolution/gaugefix_su.jl") include("algorithms/bp/beliefpropagation.jl") include("algorithms/bp/gaugefix.jl") +include("algorithms/bp/simpleupdate.jl") include("algorithms/transfermatrix.jl") include("algorithms/toolbox.jl") diff --git a/src/algorithms/bp/simpleupdate.jl b/src/algorithms/bp/simpleupdate.jl new file mode 100644 index 000000000..9262e5ac4 --- /dev/null +++ b/src/algorithms/bp/simpleupdate.jl @@ -0,0 +1,182 @@ +# Apply a gate - belief-propagation style +# --------------------------------------- + +const Gate{N, T, S} = AbstractTensorMap{T, S, N, N} + +@doc """ + apply(state, circuit::LocalCircuit, alg::SimpleUpdate, messages::BPEnv) -> state, messages, ϵ + apply!(state, circuit::LocalCircuit, alg::SimpleUpdate, messages::BPEnv) -> state, messages, ϵ + apply(state, (sites, gate)::Pair, alg::SimpleUpdate, messages::BPEnv) -> state, messages, ϵ + apply!(state, (sites, gate)::Pair, alg::SimpleUpdate, messages::BPEnv) -> state, messages, ϵ + +Apply a gate (or a `LocalCircuit` of gates) to `state` using belief-propagation simple update. +This means that the square roots of the surrounding messages are absorbed into the acted-on PEPS tensors, +the gate is applied on the reduced bond, the result is truncated by SVD according to `alg.trunc`, +and the new singular value spectrum replaces the bond message between the acted sites. + +## Returns + +- `state`: updated `InfinitePEPS`. +- `messages`: updated `BPEnv` +- `ϵ`: (maximal) local truncation error. +""" apply(state, site_gate, alg, messages::BPEnv), apply!(state, site_gate, alg, messages::BPEnv) + +apply(state, site_gate, alg, messages::BPEnv) = + apply!(copy(state), site_gate, alg, copy(messages)) + +function apply!(state, circuit::LocalCircuit, alg, messages::BPEnv) + ϵ = zero(real(scalartype(state))) + for (site, gate) in circuit.gates + state, messages, ϵ_local = apply!(state, (site => gate), alg, messages) + ϵ = max(ϵ, ϵ_local) + end + return state, messages, ϵ +end +function apply!( + state, (sites, gate)::Pair{Vector{CartesianIndex{2}}, <:Gate}, alg, messages::BPEnv + ) + if length(sites) == 1 + return apply_gate_1x1!(state, sites, gate, alg, messages) + elseif length(sites) == 2 + diff = sites[2] - sites[1] + if diff == CartesianIndex(-1, 0) || diff == CartesianIndex(1, 0) + return apply_gate_2x1!(state, sites, gate, alg, messages) + elseif diff == CartesianIndex(0, 1) || diff == CartesianIndex(0, -1) + return apply_gate_1x2!(state, sites, gate, alg, messages) + end + end + + error("Not implemented: $sites") +end + +function apply_gate_1x1!( + state, sites::Vector{CartesianIndex{2}}, gate::Gate{1}, + alg::SimpleUpdate, messages::BPEnv + ) + length(sites) == 1 || throw(ArgumentError("invalid sites: $sites")) + site = only(sites) + state[site] = gate * state[site] + return state, messages, zero(real(scalartype(state))) +end + +function apply_gate_1x2!( + state, sites::Vector{CartesianIndex{2}}, gate::Gate{2}, alg::SimpleUpdate, messages::BPEnv + ) + length(sites) == 2 || throw(ArgumentError("invalid sites: $sites")) + + # Note that we must truncate along the canonical arrow direction to not alter the result, + # so we use west - east here and map east - west to the canonical direction. + diff = sites[2] - sites[1] + if diff == CartesianIndex(0, -1) + sites = reverse(sites) + gate = permute(gate, ((2, 1), (4, 3))) + elseif diff != CartesianIndex(0, 1) + throw(ArgumentError("invalid sites: $sites")) + end + + # extract tensors and square roots of the messages + siteL, siteR = sites + A_L = state[mod1.(Tuple(siteL), size(state))...] + A_R = state[mod1.(Tuple(siteR), size(state))...] + sqrtMN_L, isqrtMN_L = sqrt_invsqrt(messages[NORTH, siteL - CartesianIndex(1, 0)]) + sqrtMN_R, isqrtMN_R = sqrt_invsqrt(messages[NORTH, siteR - CartesianIndex(1, 0)]) + sqrtME, isqrtME = sqrt_invsqrt(messages[EAST, siteR + CartesianIndex(0, 1)]) + sqrtMS_L, isqrtMS_L = sqrt_invsqrt(messages[SOUTH, siteL + CartesianIndex(1, 0)]) + sqrtMS_R, isqrtMS_R = sqrt_invsqrt(messages[SOUTH, siteR + CartesianIndex(1, 0)]) + sqrtMW, isqrtMW = sqrt_invsqrt(messages[WEST, siteL - CartesianIndex(0, 1)]) + + # settings + trunc = only(_get_cluster_trunc(alg.trunc, sites, size(state))) + + # absorb message tensors + @tensor T_L[N W S; E P] := A_L[P; N' E S' W'] * sqrtMN_L[N'; N] * + sqrtMS_L[S; S'] * sqrtMW[W; W'] + @tensor T_R[W P; N E S] := A_R[P; N' E' S' W] * sqrtMN_R[N'; N] * + sqrtMS_R[S; S'] * sqrtME[E'; E] + + # separate off the indices that are acted on for efficiency + Q_L, RQ_L = left_orth!(T_L; positive = true) + LQ_R, Q_R = right_orth!(T_R; positive = true) + + # apply gate + @tensor gated[-1 -2; -3 -4] := RQ_L[-1; 1 2] * gate[-2 -4; 2 3] * LQ_R[1 3; -3] + U, S, Vᴴ, ϵ = svd_trunc!(gated; trunc) + + sqrtS = sqrt(S) + U′ = rmul!(U, sqrtS) + Vᴴ′ = lmul!(sqrtS, Vᴴ) + + # extract new PEPS tensors + @tensor A_L′[P; N E S W] := Q_L[N' W' S'; D] * U′[D P; E] * + isqrtMN_L[N'; N] * isqrtMS_L[S; S'] * isqrtMW[W; W'] + @tensor A_R′[P; N E S W] := Q_R[D; N' E' S'] * Vᴴ′[W; D P] * + isqrtMN_R[N'; N] * isqrtMS_R[S; S'] * isqrtME[E'; E] + + # insert tensors + state[mod1.(Tuple(siteL), size(state))...] = A_L′ + state[mod1.(Tuple(siteR), size(state))...] = A_R′ + messages[EAST, siteR] = messages[WEST, siteL] = S + + return state, messages, ϵ +end + +function apply_gate_2x1!( + state, sites::Vector{CartesianIndex{2}}, gate::Gate{2}, alg::SimpleUpdate, messages::BPEnv + ) + length(sites) == 2 || throw(ArgumentError("invalid sites: $sites")) + + # Note that we must truncate along the canonical arrow direction to not alter the result, + # so we use south - north here and map north - south to the canonical direction. + diff = sites[2] - sites[1] + if diff == CartesianIndex(1, 0) + sites = reverse(sites) + gate = permute(gate, ((2, 1), (4, 3))) + elseif diff != CartesianIndex(-1, 0) + throw(ArgumentError("invalid sites: $sites")) + end + + # extract tensors and square roots of the messages + siteB, siteT = sites + A_B = state[mod1.(Tuple(siteB), size(state))...] + A_T = state[mod1.(Tuple(siteT), size(state))...] + sqrtMN, isqrtMN = sqrt_invsqrt(messages[NORTH, siteT - CartesianIndex(1, 0)]) + sqrtME_T, isqrtME_T = sqrt_invsqrt(messages[EAST, siteT + CartesianIndex(0, 1)]) + sqrtME_B, isqrtME_B = sqrt_invsqrt(messages[EAST, siteB + CartesianIndex(0, 1)]) + sqrtMS, isqrtMS = sqrt_invsqrt(messages[SOUTH, siteB + CartesianIndex(1, 0)]) + sqrtMW_T, isqrtMW_T = sqrt_invsqrt(messages[WEST, siteT - CartesianIndex(0, 1)]) + sqrtMW_B, isqrtMW_B = sqrt_invsqrt(messages[WEST, siteB - CartesianIndex(0, 1)]) + + # settings + trunc = only(_get_cluster_trunc(alg.trunc, sites, size(state))) + + # absorb message tensors (leave the inner S ← N bond free) + @tensor T_B[E S W; N P] := A_B[P; N E' S' W'] * sqrtME_B[E'; E] * + sqrtMS[S; S'] * sqrtMW_B[W; W'] + @tensor T_T[S P; N E W] := A_T[P; N' E' S W'] * sqrtMN[N'; N] * + sqrtME_T[E'; E] * sqrtMW_T[W; W'] + + # separate off the indices that are acted on for efficiency + Q_B, RQ_B = left_orth!(T_B; positive = true) + LQ_T, Q_T = right_orth!(T_T; positive = true) + + # apply gate + @tensor gated[-1 -2; -3 -4] := RQ_B[-1; 1 2] * gate[-2 -4; 2 3] * LQ_T[1 3; -3] + U, S, Vᴴ, ϵ = svd_trunc!(gated; trunc) + + sqrtS = sqrt(S) + U′ = rmul!(U, sqrtS) + Vᴴ′ = lmul!(sqrtS, Vᴴ) + + # extract new PEPS tensors + @tensor A_B′[P; N E S W] := Q_B[E' S' W'; D] * U′[D P; N] * + isqrtME_B[E'; E] * isqrtMS[S; S'] * isqrtMW_B[W; W'] + @tensor A_T′[P; N E S W] := Q_T[D; N' E' W'] * Vᴴ′[S; D P] * + isqrtMN[N'; N] * isqrtME_T[E'; E] * isqrtMW_T[W; W'] + + # insert tensors + state[mod1.(Tuple(siteB), size(state))...] = A_B′ + state[mod1.(Tuple(siteT), size(state))...] = A_T′ + messages[NORTH, siteT] = messages[SOUTH, siteB] = S + + return state, messages, ϵ +end From 5c4f4a66f8187f0c3c15d04d4af4b241bd4bf36b Mon Sep 17 00:00:00 2001 From: lkdvos Date: Sun, 3 May 2026 17:33:18 -0400 Subject: [PATCH 3/4] rudimentary tests --- test/bp/simpleupdate.jl | 47 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 test/bp/simpleupdate.jl diff --git a/test/bp/simpleupdate.jl b/test/bp/simpleupdate.jl new file mode 100644 index 000000000..156af01fb --- /dev/null +++ b/test/bp/simpleupdate.jl @@ -0,0 +1,47 @@ +using Test +using Random +using LinearAlgebra +using TensorKit +using Test +using TensorKit +using MPSKitModels: S_exchange +using PEPSKit + +using Random + +Random.seed!(1234) + +# ------------------------------------------------------------------------------------- +# Setup +# ------------------------------------------------------------------------------------- + +elt = Float64 +Dspace = ComplexSpace(2) + +H = real(heisenberg_XYZ(InfiniteSquare(2, 2); Jx = 1.0, Jy = 1.0, Jz = 1.0)) +peps = InfinitePEPS(randn, elt, physicalspace(H), fill(Dspace, 2, 2)) + +O = S_exchange(); + +# Compute BP messages +# ------------------- +messages = BPEnv(peps) +bp_alg = BeliefPropagation(; tol = 1.0e-10, maxiter = 100) +messages, bp_error = leading_boundary(messages, peps, bp_alg) + +E₀ = PEPSKit.expectation_value(peps, H, messages) + +dt = 0.01 +circuit = PEPSKit.trotterize(H, dt) + +su_alg = SimpleUpdate(; trunc = truncrank(10)) + +for _ in 1:100 + normalize!.(peps.A) + normalize!.(messages.messages) + peps, messages, ϵ = PEPSKit.apply!(peps, circuit, su_alg, messages) +end + +E = PEPSKit.expectation_value(peps′, H, messages′) + +leading_boundary(messages′, peps′, bp_alg) From bffc9ffe1263b0f823ea5139d509f22f1ad60aec Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 4 May 2026 16:07:38 -0400 Subject: [PATCH 4/4] WIP: add PEPO support --- src/algorithms/bp/simpleupdate.jl | 96 ++++++++++++++++++++++++++----- 1 file changed, 81 insertions(+), 15 deletions(-) diff --git a/src/algorithms/bp/simpleupdate.jl b/src/algorithms/bp/simpleupdate.jl index 9262e5ac4..2af414256 100644 --- a/src/algorithms/bp/simpleupdate.jl +++ b/src/algorithms/bp/simpleupdate.jl @@ -74,20 +74,33 @@ function apply_gate_1x2!( throw(ArgumentError("invalid sites: $sites")) end - # extract tensors and square roots of the messages - siteL, siteR = sites - A_L = state[mod1.(Tuple(siteL), size(state))...] - A_R = state[mod1.(Tuple(siteR), size(state))...] - sqrtMN_L, isqrtMN_L = sqrt_invsqrt(messages[NORTH, siteL - CartesianIndex(1, 0)]) - sqrtMN_R, isqrtMN_R = sqrt_invsqrt(messages[NORTH, siteR - CartesianIndex(1, 0)]) - sqrtME, isqrtME = sqrt_invsqrt(messages[EAST, siteR + CartesianIndex(0, 1)]) - sqrtMS_L, isqrtMS_L = sqrt_invsqrt(messages[SOUTH, siteL + CartesianIndex(1, 0)]) - sqrtMS_R, isqrtMS_R = sqrt_invsqrt(messages[SOUTH, siteR + CartesianIndex(1, 0)]) - sqrtMW, isqrtMW = sqrt_invsqrt(messages[WEST, siteL - CartesianIndex(0, 1)]) + # extract tensors and messages + siteL = CartesianIndex(mod1.(Tuple(sites[1]), (size(state, 1), size(state, 2)))...) + siteR = CartesianIndex(mod1.(Tuple(sites[2]), (size(state, 1), size(state, 2)))...) + A_L = state[siteL] + A_R = state[siteR] + MN_L = messages[NORTH, siteL - CartesianIndex(1, 0)] + MN_R = messages[NORTH, siteR - CartesianIndex(1, 0)] + ME = messages[EAST, siteR + CartesianIndex(0, 1)] + MS_L = messages[SOUTH, siteL + CartesianIndex(1, 0)] + MS_R = messages[SOUTH, siteR + CartesianIndex(1, 0)] + MW = messages[WEST, siteL - CartesianIndex(0, 1)] # settings trunc = only(_get_cluster_trunc(alg.trunc, sites, size(state))) + state[siteL], state[siteR], messages[EAST, siteR], ϵ = + _apply_gate_1x2((A_L, A_R), gate, (MN_L, MN_R, ME, MS_L, MS_R, MW); trunc) + + return state, messages, ϵ +end + +function _apply_gate_1x2((A_L, A_R)::NTuple{2, PEPSTensor}, gate, (MN_L, MN_R, ME, MS_L, MS_R, MW); trunc) + # compute square roots + (sqrtMN_L, isqrtMN_L), (sqrtMN_R, isqrtMN_R), (sqrtME, isqrtME), + (sqrtMS_L, isqrtMS_L), (sqrtMS_R, isqrtMS_R), (sqrtMW, isqrtMW) = + sqrt_invsqrt.((MN_L, MN_R, ME, MS_L, MS_R, MW)) + # absorb message tensors @tensor T_L[N W S; E P] := A_L[P; N' E S' W'] * sqrtMN_L[N'; N] * sqrtMS_L[S; S'] * sqrtMW[W; W'] @@ -112,12 +125,65 @@ function apply_gate_1x2!( @tensor A_R′[P; N E S W] := Q_R[D; N' E' S'] * Vᴴ′[W; D P] * isqrtMN_R[N'; N] * isqrtMS_R[S; S'] * isqrtME[E'; E] - # insert tensors - state[mod1.(Tuple(siteL), size(state))...] = A_L′ - state[mod1.(Tuple(siteR), size(state))...] = A_R′ - messages[EAST, siteR] = messages[WEST, siteL] = S + return A_L′, A_R′, S, ϵ +end +function _apply_gate_1x2((A_L, A_R)::NTuple{2, PEPOTensor}, gate, (MN_L, MN_R, ME, MS_L, MS_R, MW); trunc, purified::Bool = false) + # compute square roots + (sqrtMN_L, isqrtMN_L), (sqrtMN_R, isqrtMN_R), (sqrtME, isqrtME), + (sqrtMS_L, isqrtMS_L), (sqrtMS_R, isqrtMS_R), (sqrtMW, isqrtMW) = + sqrt_invsqrt.((MN_L, MN_R, ME, MS_L, MS_R, MW)) + + # Stage 1: act on the ket leg (P1). + # Absorb outer messages and place P_bra (P2) on the Q-side, P_ket (P1) on the gate-side. + # separate off the ket indices for efficiency + @tensor T_L[N W S P2; E P1] := A_L[P1 P2; N' E S' W'] * sqrtMN_L[N'; N] * + sqrtMS_L[S; S'] * sqrtMW[W; W'] + Q_L, RQ_L = left_orth!(T_L; positive = true) + @tensor T_R[W P1; N E S P2] := A_R[P1 P2; N' E' S' W] * sqrtMN_R[N'; N] * + sqrtMS_R[S; S'] * sqrtME[E'; E] + LQ_R, Q_R = right_orth!(T_R; positive = true) - return state, messages, ϵ + # apply gate + @tensor gated[-1 -2; -3 -4] := RQ_L[-1; 1 2] * gate[-2 -4; 2 3] * LQ_R[1 3; -3] + U, S, Vᴴ, ϵ = svd_trunc!(gated; trunc) + + sqrtS = sqrt(S) + U′ = rmul!(U, sqrtS) + Vᴴ′ = lmul!(sqrtS, Vᴴ) + + if alg.purified + # extract new PEPO tensors + @tensor A_L′[P1 P2; N E S W] := Q_L[N' W' S' P2; D] * U′[D P1; E] * + isqrtMN_L[N'; N] * isqrtMS_L[S; S'] * isqrtMW[W; W'] + @tensor A_R′[P1 P2; N E S W] := Q_R[D; N' E' S' P2] * Vᴴ′[W; D P1] * + isqrtMN_R[N'; N] * isqrtMS_R[S; S'] * isqrtME[E'; E] + return A_L′, A_R′, S, ϵ + end + + # Stage 2: act on the bra leg (P2) + # No need to reabsorb messages + # separate off the bra indices for efficiency + @tensor T_L[N W S P1; E P2] := Q_L[N W S P2; D] * U′[D P1; E] + Q_L, RQ_L = left_orth!(T_L; positive = true) + @tensor T_R[W P2; N E S P1] := Vᴴ′[W; D P1] * Q_R[D; N E S P2] + LQ_R, Q_R = right_orth!(T_R; positive = true) + + # apply gate + @tensor gated[-1 -2; -3 -4] := RQ_L[-1; 1 2] * gate'[2 3; -2 -4] * LQ_R[1 3; -3] + U, S, Vᴴ, ϵ₂ = svd_trunc!(gated; trunc) + ϵ = max(ϵ, ϵ₂) + + sqrtS = sqrt(S) + U′ = rmul!(U, sqrtS) + Vᴴ′ = lmul!(sqrtS, Vᴴ) + + # extract new PEPS tensors + @tensor A_L′[P1 P2; N E S W] := Q_L[N' W' S' P1; D] * U′[D P2; E] * + isqrtMN_L[N'; N] * isqrtMS_L[S; S'] * isqrtMW[W; W'] + @tensor A_R′[P1 P2; N E S W] := Q_R[D; N' E' S' P1] * Vᴴ′[W; D P2] * + isqrtMN_R[N'; N] * isqrtMS_R[S; S'] * isqrtME[E'; E] + + return A_L′, A_R′, S, ϵ end function apply_gate_2x1!(