diff --git a/benchmark/continuous_transition_bench.jl b/benchmark/continuous_transition_bench.jl new file mode 100644 index 000000000..f96df7aee --- /dev/null +++ b/benchmark/continuous_transition_bench.jl @@ -0,0 +1,196 @@ +#!/usr/bin/env julia +#= +ContinuousTransition Rules Benchmark Script + +Run this script to benchmark the ContinuousTransition rules performance. +Can be executed on different branches to compare optimizations. + +Usage: + julia --project=. benchmark/continuous_transition_bench.jl + julia --project=. benchmark/continuous_transition_bench.jl quick # Quick mode (only small dims) + +Output: Performance table showing timings for each rule and dimension. +=# + +using Pkg +Pkg.activate(dirname(@__DIR__)) + +using BenchmarkTools +using ReactiveMP +using BayesBase +using ExponentialFamily +using Random +using LinearAlgebra +using Distributions +using Printf + +import ReactiveMP: CTMeta, @call_rule, @call_marginalrule + +# ============================================================================ +# Test Data Generation +# ============================================================================ + +function create_benchmark_data(dx, dy) + rng = MersenneTwister(42) + da = dx * dy + + transformation = a -> reshape(a, dy, dx) + meta = CTMeta(transformation) + + Lx = rand(rng, dx, dx) + Ly = rand(rng, dy, dy) + La = rand(rng, da, da) + + μx, Σx = rand(rng, dx), Lx * Lx' + dx * I + μy, Σy = rand(rng, dy), Ly * Ly' + dy * I + μa, Σa = rand(rng, da), La * La' + da * I + + q_y = MvNormalMeanCovariance(μy, Σy) + q_x = MvNormalMeanCovariance(μx, Σx) + q_a = MvNormalMeanCovariance(μa, Σa) + q_W = Wishart(dy + 1, Matrix{Float64}(I, dy, dy)) + q_y_x = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx]) + + m_y = MvNormalMeanCovariance(μy, Σy) + m_x = MvNormalMeanCovariance(μx, Σx) + + return ( + meta = meta, + q_y = q_y, q_x = q_x, q_a = q_a, q_W = q_W, + q_y_x = q_y_x, + m_y = m_y, m_x = m_x + ) +end + +# ============================================================================ +# Benchmark Functions +# ============================================================================ + +function bench_a_structured(data) + @call_rule ContinuousTransition(:a, Marginalisation) ( + q_y_x = data.q_y_x, + q_a = data.q_a, + q_W = data.q_W, + meta = data.meta + ) +end + +function bench_a_meanfield(data) + @call_rule ContinuousTransition(:a, Marginalisation) ( + q_y = data.q_y, + q_x = data.q_x, + q_a = data.q_a, + q_W = data.q_W, + meta = data.meta + ) +end + +function bench_marginal_y_x(data) + @call_marginalrule ContinuousTransition(:y_x) ( + m_y = data.m_y, + m_x = data.m_x, + q_a = data.q_a, + q_W = data.q_W, + meta = data.meta + ) +end + +# ============================================================================ +# Benchmark Runner +# ============================================================================ + +function run_benchmarks(; quick_mode=false) + println() + println("=" ^ 80) + println(" ContinuousTransition Rules Benchmark") + println(" Branch: ", strip(read(`git rev-parse --abbrev-ref HEAD`, String))) + println(" Commit: ", strip(read(`git rev-parse --short HEAD`, String))) + println("=" ^ 80) + println() + + if quick_mode + test_dims = [(10, 10), (20, 20)] + println(" Mode: QUICK (limited dimensions)") + else + test_dims = [(5, 5), (10, 10), (20, 20), (30, 30), (40, 40)] + println(" Mode: FULL") + end + println() + + # Results storage + results = Dict{String, Vector{Tuple{Int, Int, Float64}}}( + "a_structured" => [], + "a_meanfield" => [], + "marginal_y_x" => [] + ) + + for (dx, dy) in test_dims + println("-" ^ 60) + @printf(" Benchmarking: dx=%d, dy=%d (da=%d)\n", dx, dy, dx*dy) + println("-" ^ 60) + + data = create_benchmark_data(dx, dy) + + # Warm-up calls + bench_a_structured(data) + bench_a_meanfield(data) + bench_marginal_y_x(data) + + # Benchmark a.jl structured + t = @belapsed bench_a_structured($data) + push!(results["a_structured"], (dx, dy, t * 1e6)) + @printf(" a.jl Structured: %10.2f μs\n", t * 1e6) + + # Benchmark a.jl mean-field + t = @belapsed bench_a_meanfield($data) + push!(results["a_meanfield"], (dx, dy, t * 1e6)) + @printf(" a.jl Mean-field: %10.2f μs\n", t * 1e6) + + # Benchmark marginals.jl + t = @belapsed bench_marginal_y_x($data) + push!(results["marginal_y_x"], (dx, dy, t * 1e6)) + @printf(" marginals.jl y_x: %10.2f μs\n", t * 1e6) + + println() + end + + # Print summary table + println("=" ^ 80) + println(" SUMMARY TABLE (times in μs)") + println("=" ^ 80) + println() + + # Header + @printf(" %-12s", "Dimensions") + for rule in ["a_structured", "a_meanfield", "marginal_y_x"] + @printf(" | %14s", replace(rule, "_" => " ")) + end + println() + println(" " * "-" ^ 12, " | ", "-" ^ 14, " | ", "-" ^ 14, " | ", "-" ^ 14) + + # Data rows + for i in eachindex(test_dims) + dx, dy = test_dims[i] + @printf(" %4d × %-4d ", dx, dy) + @printf(" | %14.2f", results["a_structured"][i][3]) + @printf(" | %14.2f", results["a_meanfield"][i][3]) + @printf(" | %14.2f", results["marginal_y_x"][i][3]) + println() + end + + println() + println("=" ^ 80) + println(" Benchmark Complete") + println("=" ^ 80) + println() + + return results +end + +# ============================================================================ +# Main +# ============================================================================ + +quick_mode = length(ARGS) > 0 && ARGS[1] == "quick" +run_benchmarks(quick_mode=quick_mode) + diff --git a/benchmark/rules/continuous_transition/continuous_transition.jl b/benchmark/rules/continuous_transition/continuous_transition.jl new file mode 100644 index 000000000..bf400af6f --- /dev/null +++ b/benchmark/rules/continuous_transition/continuous_transition.jl @@ -0,0 +1,117 @@ +using BenchmarkTools +using ReactiveMP +using BayesBase +using ExponentialFamily +using Random +using LinearAlgebra +using Distributions +using StableRNGs + +import ReactiveMP: CTMeta, Marginal, Message, @call_rule, @call_marginalrule + +""" +Creates test data for ContinuousTransition benchmarks. +Returns distributions and meta needed to call the rules. +""" +function create_ct_benchmark_data(dx, dy) + rng = StableRNGs(42) + da = dx * dy # For linear transformation a -> reshape(a, dy, dx) + + # Transformation function + transformation = a -> reshape(a, dy, dx) + meta = CTMeta(transformation) + + # Create covariance matrices + Lx = rand(rng, dx, dx) + Ly = rand(rng, dy, dy) + La = rand(rng, da, da) + + μx, Σx = rand(rng, dx), Lx * Lx' + dx * I + μy, Σy = rand(rng, dy), Ly * Ly' + dy * I + μa, Σa = rand(rng, da), La * La' + da * I + + # Create distributions for mean-field factorization + q_y = MvNormalMeanCovariance(μy, Σy) + q_x = MvNormalMeanCovariance(μx, Σx) + q_a = MvNormalMeanCovariance(μa, Σa) + q_W = Wishart(dy + 1, Matrix{Float64}(I, dy, dy)) + + # Create joint distribution for structured factorization + q_y_x = MvNormalMeanCovariance([μy; μx], [Σy zeros(dy, dx); zeros(dx, dy) Σx]) + + # Create messages for marginal rule + m_y = MvNormalMeanCovariance(μy, Σy) + m_x = MvNormalMeanCovariance(μx, Σx) + + return ( + meta = meta, + q_y = q_y, q_x = q_x, q_a = q_a, q_W = q_W, + q_y_x = q_y_x, + m_y = m_y, m_x = m_x + ) +end + +""" +Adds ContinuousTransition rule benchmarks to the suite. +""" +function add_continuous_transition_rule_benchmarks(SUITE) + SUITE["ContinuousTransition"] = BenchmarkGroup() + + add_continuous_transition_a_benchmarks(SUITE["ContinuousTransition"]) + add_continuous_transition_marginals_benchmarks(SUITE["ContinuousTransition"]) +end + +function add_continuous_transition_a_benchmarks(SUITE) + SUITE["a"] = BenchmarkGroup(["Rules", "ContinuousTransition"]) + + # Test dimensions: (dx, dy) + test_dims = [(5, 5), (10, 10), (20, 20), (30, 30)] + + for (dx, dy) in test_dims + data = create_ct_benchmark_data(dx, dy) + + # Structured VMP: q(y,x) joint + SUITE["a"]["Structured"]["dx=$(dx), dy=$(dy)"] = @benchmarkable begin + @call_rule ContinuousTransition(:a, Marginalisation) ( + q_y_x = $data.q_y_x, + q_a = $data.q_a, + q_W = $data.q_W, + meta = $data.meta + ) + end + + # Mean-field VMP: q(y)q(x)q(a)q(W) + SUITE["a"]["Mean-field"]["dx=$(dx), dy=$(dy)"] = @benchmarkable begin + @call_rule ContinuousTransition(:a, Marginalisation) ( + q_y = $data.q_y, + q_x = $data.q_x, + q_a = $data.q_a, + q_W = $data.q_W, + meta = $data.meta + ) + end + end +end + +function add_continuous_transition_marginals_benchmarks(SUITE) + SUITE["marginals"] = BenchmarkGroup(["Rules", "ContinuousTransition"]) + + # Test dimensions: (dx, dy) + test_dims = [(5, 5), (10, 10), (20, 20), (30, 30)] + + for (dx, dy) in test_dims + data = create_ct_benchmark_data(dx, dy) + + # y_x marginal rule + SUITE["marginals"]["y_x"]["dx=$(dx), dy=$(dy)"] = @benchmarkable begin + @call_marginalrule ContinuousTransition(:y_x) ( + m_y = $data.m_y, + m_x = $data.m_x, + q_a = $data.q_a, + q_W = $data.q_W, + meta = $data.meta + ) + end + end +end + diff --git a/src/nodes/predefined/continuous_transition.jl b/src/nodes/predefined/continuous_transition.jl index 9845b40b8..35eebb077 100644 --- a/src/nodes/predefined/continuous_transition.jl +++ b/src/nodes/predefined/continuous_transition.jl @@ -127,12 +127,18 @@ end g1 = -mA * Vyx' g2 = g1' + + # Optimized: factor out inner summation to reduce complexity from O(dy²) to O(dy) + # Step 1: For each i, compute H[i] = Σⱼ mW[j,i] * Fs[j] + H = [sum(mW[j, i] * Fs[j] for j in 1:dy) for i in 1:dy] + + # Step 2: Compute traces trWSU, trkronxxWSU = zero(eltype(ma)), zero(eltype(ma)) xxt = mx * mx' - for (i, j) in Iterators.product(1:dy, 1:dy) - FjVaFi = Fs[j] * Va * Fs[i]' - trWSU += mW[j, i] * tr(FjVaFi) - trkronxxWSU += mW[j, i] * tr(xxt * FjVaFi) + for i in 1:dy + HVaFi = H[i] * Va * Fs[i]' + trWSU += tr(HVaFi) + trkronxxWSU += tr(xxt * HVaFi) end AE = n / 2 * log2π - mean(logdet, q_W) + (tr(mW * (mA * Vx * mA' + g1 + g2 + Vy + (mA * mx - my) * (mA * mx - my)')) + trWSU + trkronxxWSU) / 2 @@ -151,12 +157,17 @@ end n = div(ndims(q_y), 2) mA = ctcompanion_matrix(ma, sqrt.(var(q_a)), meta) + # Optimized: factor out inner summation to reduce complexity from O(dy²) to O(dy) + # Step 1: For each i, compute H[i] = Σⱼ mW[j,i] * Fs[j] + H = [sum(mW[j, i] * Fs[j] for j in 1:dy) for i in 1:dy] + + # Step 2: Compute traces trWSU, trkronxxWSU = zero(eltype(ma)), zero(eltype(ma)) xxt = mx * mx' - for (i, j) in Iterators.product(1:dy, 1:dy) - FjVaFi = Fs[j] * Va * Fs[i]' - trWSU += mW[j, i] * tr(FjVaFi) - trkronxxWSU += mW[j, i] * tr(xxt * FjVaFi) + for i in 1:dy + HVaFi = H[i] * Va * Fs[i]' + trWSU += tr(HVaFi) + trkronxxWSU += tr(xxt * HVaFi) end AE = n / 2 * log2π - mean(logdet, q_W) + (tr(mW * (mA * Vx * mA' + Vy + (mA * mx - my) * (mA * mx - my)')) + trWSU + trkronxxWSU) / 2 diff --git a/src/rules/continuous_transition/a.jl b/src/rules/continuous_transition/a.jl index d3dd2f87b..ca62760e9 100644 --- a/src/rules/continuous_transition/a.jl +++ b/src/rules/continuous_transition/a.jl @@ -17,11 +17,15 @@ Vxymxy = rank1update(Vyx', mx, my) Vxmx = rank1update(Vx, mx) + + # Optimized: factor out inner summation to reduce complexity from O(dy²) to O(dy) + # Step 1: For each i, compute H[i] = Σⱼ mW[j,i] * Fs[j] + H = [sum(mW[j, i] * Fs[j] for j in 1:dy) for i in 1:dy] + + # Step 2: Compute xi and W for i in 1:dy xi += Fs[i]' * Vxymxy * mW[:, i] - for j in 1:dy - W += mW[j, i] * Fs[i]' * Vxmx * Fs[j] - end + W += Fs[i]' * Vxmx * H[i] end return MvNormalWeightedMeanPrecision(xi, W) @@ -42,11 +46,14 @@ end mxmy = mx * my' Vxmx = rank1update(Vx, mx) + # Optimized: factor out inner summation to reduce complexity from O(dy²) to O(dy) + # Step 1: For each i, compute H[i] = Σⱼ mW[j,i] * Fs[j] + H = [sum(mW[j, i] * Fs[j] for j in 1:dy) for i in 1:dy] + + # Step 2: Compute xi and W for i in 1:dy xi += Fs[i]' * mxmy * mW[:, i] - for j in 1:dy - W += mW[j, i] * Fs[i]' * Vxmx * Fs[j] - end + W += Fs[i]' * Vxmx * H[i] end return MvNormalWeightedMeanPrecision(xi, W) diff --git a/src/rules/continuous_transition/marginals.jl b/src/rules/continuous_transition/marginals.jl index abd21539b..dce5c15eb 100644 --- a/src/rules/continuous_transition/marginals.jl +++ b/src/rules/continuous_transition/marginals.jl @@ -24,9 +24,14 @@ function continuous_tranition_marginal(m_y::MultivariateNormalDistributionsFamil W_21 = negate_inplace!(mA' * mW) + # Optimized: factor out inner summation to reduce complexity from O(dy²) to O(dy) + # Step 1: For each i, compute H[i] = Σⱼ mW[j,i] * Fs[j] + H = [sum(mW[j, i] * Fs[j] for j in 1:dy) for i in 1:dy] + + # Step 2: Compute Ξ Ξ = Wx - for (i, j) in Iterators.product(1:dy, 1:dy) - Ξ += mW[j, i] * Fs[j] * Va * Fs[i]' + for i in 1:dy + Ξ += H[i] * Va * Fs[i]' end W_22 = Ξ + mA' * mW * mA