From 0979b243374019931c30e64e55d3d3a0af86f4c8 Mon Sep 17 00:00:00 2001 From: Richard Samuelson Date: Tue, 31 Mar 2026 13:47:27 -0400 Subject: [PATCH 01/11] Add type `ChordalGMRF`. --- Project.toml | 9 +- benchmarks/Project.toml | 2 + benchmarks/autodiff_comparison.jl | 125 ++++- .../gaussian_approximation_comparison.jl | 247 +++++++++ benchmarks/logpdf_comparison.jl | 215 ++++++++ src/GaussianMarkovRandomFields.jl | 2 + .../condition/gaussian_approximation.jl | 101 +++- src/autodiff/gaussian_approximation.jl | 113 +++- src/autodiff/logpdf.jl | 2 + src/chordal_gmrf.jl | 86 +++ src/piracy.jl | 503 ++++++++++++++++++ src/solvers/selinv.jl | 6 +- test/autodiff/runtests.jl | 1 + .../test_gaussian_approximation_chordal.jl | 356 +++++++++++++ .../test_gaussian_approximation_chordal.jl | 194 +++++++ 15 files changed, 1938 insertions(+), 24 deletions(-) create mode 100644 benchmarks/gaussian_approximation_comparison.jl create mode 100644 benchmarks/logpdf_comparison.jl create mode 100644 src/chordal_gmrf.jl create mode 100644 src/piracy.jl create mode 100644 test/autodiff/test_gaussian_approximation_chordal.jl create mode 100644 test/gaussian_approximation/test_gaussian_approximation_chordal.jl diff --git a/Project.toml b/Project.toml index 7fdae095..3714c1c6 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,8 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Ferrite = "c061ca5d-56c9-439f-9c0e-210fe06d3992" FerriteGmsh = "4f95f4f8-b27c-4ae5-9a39-ea55e634e36b" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GeoInterface = "cf35fbd7-0cd7-5166-be24-54bfbe79505f" GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326" Gmsh = "705231aa-382f-11e9-3f0c-b7cb4346fdeb" @@ -20,6 +22,7 @@ LibGEOS = "a90b1aa1-3769-5649-ba7e-abc5a9d163eb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +MatrixDepot = "b51810bb-c9f3-55da-ae3c-350fc1fbce05" NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SelectedInversion = "043bf095-3f01-458a-9f1c-8cf4448fe908" @@ -28,10 +31,10 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" Shapefile = "8e980c4a-a4fe-5da2-b3a7-4b4b0353a2f4" @@ -40,7 +43,6 @@ SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] GaussianMarkovRandomFieldsAutoDiff = ["ForwardDiff", "Zygote"] @@ -75,6 +77,7 @@ LinearAlgebra = "<0.0.1, 1" LinearMaps = "3.11" LinearSolve = "2, 3" Makie = "0.19 - 0.22" +MatrixDepot = "1.0.15" NearestNeighbors = "0.4" Pardiso = "1" Random = "<0.0.1, 1" @@ -92,7 +95,7 @@ StatsModels = "0.7" Symbolics = "5" Tensors = "1.16" Test = "<0.0.1, 1" -Zygote = "0.6" +Zygote = "0.6, 0.7" julia = "1.10" [extras] diff --git a/benchmarks/Project.toml 
b/benchmarks/Project.toml index 07b751eb..cc7061d0 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -1,5 +1,6 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -7,6 +8,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" GaussianMarkovRandomFields = "d5f06795-35bb-4323-9f0b-405ef76cfc5b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +MatrixDepot = "b51810bb-c9f3-55da-ae3c-350fc1fbce05" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/benchmarks/autodiff_comparison.jl b/benchmarks/autodiff_comparison.jl index a5b255ff..d3cf519d 100644 --- a/benchmarks/autodiff_comparison.jl +++ b/benchmarks/autodiff_comparison.jl @@ -22,6 +22,8 @@ using Random using Zygote, Enzyme, FiniteDiff +using CliqueTrees.Multifrontal: symbolic, chordal + println("="^80) println("AUTODIFF BACKEND COMPARISON: HIGH-DIMENSIONAL HYPERPARAMETER SPACE") println("="^80) @@ -37,7 +39,7 @@ println(" - $n mean parameters (one per time point)") println(" - 1 precision parameter (τ)") # Workflow: θ → GMRF → gaussian_approximation → logpdf -function benchmark_workflow(θ::Vector{Float64}, y::Vector{Int}, x_eval::Vector{Float64}) +function benchmark_workflow(θ::Vector{Float64}, y::PoissonObservations, x_eval::Vector{Float64}) # Extract hyperparameters μ = θ[1:n] # Mean field (100 params) log_τ = θ[n + 1] # Log precision @@ -61,6 +63,31 @@ function benchmark_workflow(θ::Vector{Float64}, y::Vector{Int}, x_eval::Vector{ return logpdf(posterior, x_eval) end +# ChordalGMRF workflow (only supports Zygote) +function benchmark_workflow_chordal(θ::Vector{Float64}, y::PoissonObservations, x_eval::Vector{Float64}) + # Extract hyperparameters + μ = θ[1:n] # Mean field (100 params) + log_τ = θ[n + 1] # Log precision + τ = exp(log_τ) + + # Build precision matrix using RW1 latent model + rw1_model = RW1Model(n) + Q = sparse(precision_matrix(rw1_model; τ = τ)) + + # Prior ChordalGMRF with custom mean + prior = ChordalGMRF(μ, Q) + + # Poisson observation likelihood + obs_model = ExponentialFamily(Poisson) + obs_lik = obs_model(y) + + # Gaussian approximation + posterior = gaussian_approximation(prior, obs_lik) + + # Evaluate log-density + return logpdf(posterior, x_eval) +end + # Generate test data println("\nGenerating test data...") Random.seed!(123) @@ -73,19 +100,24 @@ Random.seed!(123) # Simulate observations (Poisson counts from smooth latent field) x_true = μ_true .+ cumsum(randn(n)) .* sqrt(1 / τ_true) .* 0.5 x_true .-= mean(x_true) # Center -y_obs = rand.(Poisson.(exp.(x_true .+ 0.5))) +y_counts = rand.(Poisson.(exp.(x_true .+ 0.5))) +y_obs = PoissonObservations(y_counts) x_eval = randn(n) .+ 0.3 # Initial parameter values (perturbed from truth) θ_init = θ_true .+ randn(n_params) .* 0.1 -println(" ✓ Generated $(length(y_obs)) Poisson observations") +println(" ✓ Generated $(length(y_counts)) Poisson observations") println(" ✓ Initial parameters: $(n_params)-dimensional vector") -# Verify workflow works -println("\nVerifying workflow...") +# Verify workflows work +println("\nVerifying workflows...") f_val = benchmark_workflow(θ_init, y_obs, x_eval) -println(" ✓ Function value: $(@sprintf("%.4f", f_val))") 
+println(" ✓ GMRF function value: $(@sprintf("%.4f", f_val))") + +f_val_chordal = benchmark_workflow_chordal(θ_init, y_obs, x_eval) +println(" ✓ ChordalGMRF function value: $(@sprintf("%.4f", f_val_chordal))") +println(" ✓ Difference: $(@sprintf("%.2e", abs(f_val - f_val_chordal)))") # Define backends backends = [ @@ -140,18 +172,61 @@ for (name, backend) in backends end end +# ChordalGMRF benchmark (Zygote only) +println("\n" * "="^80) +println("BENCHMARKING ChordalGMRF (Zygote only)") +println("="^80) + +println("\nChordalGMRF + Zygote:") +println("-"^40) + +try + # Warmup + print(" Warming up... ") + grad_chordal = DifferentiationInterface.gradient( + θ -> benchmark_workflow_chordal(θ, y_obs, x_eval), + AutoZygote(), + θ_init + ) + println("✓") + + # Benchmark + print(" Benchmarking... ") + bench_chordal = @benchmark DifferentiationInterface.gradient( + θ -> benchmark_workflow_chordal(θ, y_obs, x_eval), + AutoZygote(), + $θ_init + ) samples = 10 seconds = 30 + + results["ChordalGMRF+Zygote"] = ( + gradient = grad_chordal, + time = minimum(bench_chordal.times) / 1.0e6, + bench = bench_chordal, + ) + + println("✓") + println(" Time (min): $(@sprintf("%.2f", results["ChordalGMRF+Zygote"].time)) ms") + println(" Time (median): $(@sprintf("%.2f", median(bench_chordal.times) / 1.0e6)) ms") + println(" Allocations: $(bench_chordal.allocs)") + println(" Memory: $(@sprintf("%.2f", bench_chordal.memory / 1.0e6)) MB") + +catch e + println(" ✗ Failed: $e") + results["ChordalGMRF+Zygote"] = nothing +end + # Summary comparison println("\n" * "="^80) println("SUMMARY") println("="^80) -if all(v !== nothing for v in values(results)) +if results["FiniteDiff"] !== nothing # Verify gradients match println("\nGradient verification (comparing to FiniteDiff):") fd_grad = results["FiniteDiff"].gradient - for name in ["Zygote", "Enzyme"] - if results[name] !== nothing + for name in ["Zygote", "Enzyme", "ChordalGMRF+Zygote"] + if get(results, name, nothing) !== nothing grad = results[name].gradient abs_error = abs.(grad - fd_grad) max_error = maximum(abs_error) @@ -164,7 +239,7 @@ if all(v !== nothing for v in values(results)) end # Performance comparison table - println("\nPerformance comparison:") + println("\nPerformance comparison (GMRF backends):") println(" " * "─"^76) println(@sprintf(" %-20s %12s %12s %12s %12s", "Backend", "Time (ms)", "Speedup", "Allocs", "Memory (MB)")) println(" " * "─"^76) @@ -191,10 +266,38 @@ if all(v !== nothing for v in values(results)) enzyme_vs_zygote = results["Zygote"].time / results["Enzyme"].time winner = enzyme_vs_zygote > 1 ? 
"Enzyme" : "Zygote" ratio = max(enzyme_vs_zygote, 1.0 / enzyme_vs_zygote) - println("\n 🏆 $winner is fastest: $(@sprintf("%.1f", ratio))× faster than the other") + println("\n GMRF winner: $winner ($(@sprintf("%.1f", ratio))× faster)") end end +# ChordalGMRF vs GMRF comparison (Zygote only) +if get(results, "ChordalGMRF+Zygote", nothing) !== nothing && get(results, "Zygote", nothing) !== nothing + println("\n" * "="^80) + println("GMRF vs ChordalGMRF COMPARISON (Zygote)") + println("="^80) + + r_gmrf = results["Zygote"] + r_chordal = results["ChordalGMRF+Zygote"] + + println("\n " * "─"^76) + println(@sprintf(" %-20s %12s %12s %12s %12s", "Implementation", "Time (ms)", "Speedup", "Allocs", "Memory (MB)")) + println(" " * "─"^76) + + println(@sprintf(" %-20s %12.2f %12s %12d %12.2f", + "GMRF", r_gmrf.time, "1.0×", r_gmrf.bench.allocs, r_gmrf.bench.memory / 1.0e6)) + + chordal_speedup = r_gmrf.time / r_chordal.time + println(@sprintf(" %-20s %12.2f %12s %12d %12.2f", + "ChordalGMRF", r_chordal.time, @sprintf("%.1f×", chordal_speedup), + r_chordal.bench.allocs, r_chordal.bench.memory / 1.0e6)) + + println(" " * "─"^76) + + winner = chordal_speedup > 1 ? "ChordalGMRF" : "GMRF" + ratio = max(chordal_speedup, 1.0 / chordal_speedup) + println("\n Winner: $winner ($(@sprintf("%.1f", ratio))× faster)") +end + println("\n" * "="^80) println("BENCHMARK COMPLETE") println("="^80) diff --git a/benchmarks/gaussian_approximation_comparison.jl b/benchmarks/gaussian_approximation_comparison.jl new file mode 100644 index 00000000..6f3a12b3 --- /dev/null +++ b/benchmarks/gaussian_approximation_comparison.jl @@ -0,0 +1,247 @@ +#!/usr/bin/env julia +""" +Benchmark: Compare GMRF vs ChordalGMRF for gaussian_approximation + +Tests both correctness (results match) and performance on SSMC matrices +with Poisson observation likelihoods. + +Usage: + cd benchmarks + julia --project=. 
gaussian_approximation_comparison.jl +""" + +using GaussianMarkovRandomFields +using BenchmarkTools +using Distributions: logpdf, Poisson +using SparseArrays +using LinearAlgebra +using LinearSolve +using Printf +using Random +using MatrixDepot +using Zygote + +using CliqueTrees.Multifrontal: symbolic, chordal + +println("="^80) +println("GAUSSIAN APPROXIMATION COMPARISON: GMRF vs ChordalGMRF") +println("="^80) + +# Helper to make a matrix positive definite +function make_posdef(A::SparseMatrixCSC) + # Symmetrize and add diagonal dominance + S = (A + A') / 2 + d = vec(sum(abs, S; dims=2)) + return S + spdiagm(0 => d .+ 1.0) +end + +# Handle Symmetric wrapper from MatrixDepot +make_posdef(A::Symmetric) = make_posdef(sparse(A)) + +# Test matrices from SSMC (larger for meaningful benchmarks) +test_matrices = [ + ("HB/bcsstk15", "Structural, n=3948"), + ("HB/bcsstk16", "Structural, n=4884"), + ("HB/bcsstk17", "Structural, n=10974"), + ("HB/bcsstk18", "Structural, n=11948"), +] + +println("\nTest matrices:") +for (name, desc) in test_matrices + println(" - $name ($desc)") +end + +results = [] + +for (matrix_name, desc) in test_matrices + println("\n" * "="^80) + println("Matrix: $matrix_name ($desc)") + println("="^80) + + # Load and prepare matrix + try + A_raw = matrixdepot(matrix_name) + Q = make_posdef(A_raw) + n = size(Q, 1) + + println(" Size: $n × $n") + println(" Nonzeros: $(nnz(Q))") + + # Create mean vector and synthetic Poisson observations + Random.seed!(42) + μ = zeros(n) + + # Generate Poisson counts (moderate values to avoid numerical issues) + latent = randn(n) * 0.5 + y_counts = rand.(Poisson.(exp.(latent .+ 1.0))) + y = PoissonObservations(y_counts) + + # Create observation likelihood + obs_model = ExponentialFamily(Poisson) + obs_lik = obs_model(y) + + # Create GMRF prior (baseline) + println("\n Creating GMRF prior...") + gmrf_prior = GMRF(μ, Q, LinearSolve.CHOLMODFactorization()) + + # Create ChordalGMRF prior + println(" Creating ChordalGMRF prior...") + chordal_prior = ChordalGMRF(μ, Q) + + # Run gaussian_approximation + println("\n Running gaussian_approximation...") + posterior_gmrf = gaussian_approximation(gmrf_prior, obs_lik) + posterior_chordal = gaussian_approximation(chordal_prior, obs_lik) + + # Correctness check + println("\n Correctness check:") + mean_gmrf = mean(posterior_gmrf) + mean_chordal = mean(posterior_chordal) + Q_gmrf = precision_matrix(posterior_gmrf) + Q_chordal = precision_matrix(posterior_chordal) + + mean_diff = norm(mean_gmrf - mean_chordal) + mean_rel_diff = mean_diff / (norm(mean_gmrf) + 1e-10) + Q_diff = norm(Q_gmrf - Q_chordal) + Q_rel_diff = Q_diff / (norm(Q_gmrf) + 1e-10) + + println(" Mean abs diff: $(@sprintf("%.2e", mean_diff))") + println(" Mean rel diff: $(@sprintf("%.2e", mean_rel_diff))") + println(" Precision abs diff: $(@sprintf("%.2e", Q_diff))") + println(" Precision rel diff: $(@sprintf("%.2e", Q_rel_diff))") + + correct = mean_rel_diff < 1e-6 && Q_rel_diff < 1e-6 + println(" Match: $(correct ? "✓ YES" : "✗ NO")") + + # Performance benchmark + println("\n Performance benchmark:") + + # Benchmark GMRF + print(" GMRF... ") + bench_gmrf = @benchmark gaussian_approximation($gmrf_prior, $obs_lik) samples=10 seconds=10 + time_gmrf = minimum(bench_gmrf.times) / 1e6 + println("$(@sprintf("%.3f", time_gmrf)) ms") + + # Benchmark ChordalGMRF + print(" ChordalGMRF... 
") + bench_chordal = @benchmark gaussian_approximation($chordal_prior, $obs_lik) samples=10 seconds=10 + time_chordal = minimum(bench_chordal.times) / 1e6 + println("$(@sprintf("%.3f", time_chordal)) ms") + + speedup = time_gmrf / time_chordal + println(" Speedup: $(@sprintf("%.2f", speedup))×") + + # Gradient correctness check (gradient of sum(posterior mean) w.r.t. prior mean) + println("\n Gradient correctness check (w.r.t. prior mean):") + + function loss_gmrf(μ_prior) + prior = GMRF(μ_prior, Q, LinearSolve.CHOLMODFactorization()) + post = gaussian_approximation(prior, obs_lik) + return sum(mean(post)) + end + + function loss_chordal(μ_prior) + prior = ChordalGMRF(μ_prior, Q) + post = gaussian_approximation(prior, obs_lik) + return sum(mean(post)) + end + + grad_gmrf = Zygote.gradient(loss_gmrf, μ)[1] + grad_chordal = Zygote.gradient(loss_chordal, μ)[1] + grad_abs_diff = norm(grad_gmrf - grad_chordal) + grad_rel_diff = grad_abs_diff / (norm(grad_gmrf) + 1e-10) + + println(" Absolute diff: $(@sprintf("%.2e", grad_abs_diff))") + println(" Relative diff: $(@sprintf("%.2e", grad_rel_diff))") + + grad_correct = grad_rel_diff < 1e-6 + println(" Match: $(grad_correct ? "✓ YES" : "✗ NO")") + + # Gradient performance benchmark + println("\n Gradient performance benchmark:") + + print(" GMRF... ") + bench_grad_gmrf = @benchmark Zygote.gradient($loss_gmrf, $μ) samples=10 seconds=10 + time_grad_gmrf = minimum(bench_grad_gmrf.times) / 1e6 + println("$(@sprintf("%.3f", time_grad_gmrf)) ms") + + print(" ChordalGMRF... ") + bench_grad_chordal = @benchmark Zygote.gradient($loss_chordal, $μ) samples=10 seconds=10 + time_grad_chordal = minimum(bench_grad_chordal.times) / 1e6 + println("$(@sprintf("%.3f", time_grad_chordal)) ms") + + grad_speedup = time_grad_gmrf / time_grad_chordal + println(" Speedup: $(@sprintf("%.2f", grad_speedup))×") + + push!(results, ( + name=matrix_name, + n=n, + nnz=nnz(Q), + correct=correct, + grad_correct=grad_correct, + time_gmrf=time_gmrf, + time_chordal=time_chordal, + speedup=speedup, + time_grad_gmrf=time_grad_gmrf, + time_grad_chordal=time_grad_chordal, + grad_speedup=grad_speedup, + )) + + catch e + println(" ✗ Failed: $(typeof(e).name.name): $(sprint(showerror, e; context=:limit=>true))") + push!(results, (name=matrix_name, n=0, nnz=0, correct=false, grad_correct=false, + time_gmrf=NaN, time_chordal=NaN, speedup=NaN, + time_grad_gmrf=NaN, time_grad_chordal=NaN, grad_speedup=NaN)) + end +end + +# Summary table +println("\n" * "="^80) +println("SUMMARY: FORWARD PASS") +println("="^80) + +println("\n" * "-"^95) +@printf("%-20s %8s %10s %8s %12s %12s %10s\n", + "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup") +println("-"^95) + +for r in results + correct_str = r.correct ? "✓" : "✗" + @printf("%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", + r.name, r.n, r.nnz, correct_str, r.time_gmrf, r.time_chordal, r.speedup) +end +println("-"^95) + +# Gradient summary table +println("\n" * "="^80) +println("SUMMARY: GRADIENT (Zygote)") +println("="^80) + +println("\n" * "-"^95) +@printf("%-20s %8s %10s %8s %12s %12s %10s\n", + "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup") +println("-"^95) + +for r in results + correct_str = r.grad_correct ? 
"✓" : "✗" + @printf("%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", + r.name, r.n, r.nnz, correct_str, r.time_grad_gmrf, r.time_grad_chordal, r.grad_speedup) +end +println("-"^95) + +# Overall stats +valid_results = filter(r -> !isnan(r.speedup), results) +if !isempty(valid_results) + avg_speedup = sum(r.speedup for r in valid_results) / length(valid_results) + avg_grad_speedup = sum(r.grad_speedup for r in valid_results) / length(valid_results) + all_correct = all(r.correct for r in valid_results) + all_grad_correct = all(r.grad_correct for r in valid_results) + + println("\nOverall:") + println(" Forward - All match: $(all_correct ? "✓ YES" : "✗ NO"), Avg speedup: $(@sprintf("%.2f", avg_speedup))×") + println(" Gradient - All match: $(all_grad_correct ? "✓ YES" : "✗ NO"), Avg speedup: $(@sprintf("%.2f", avg_grad_speedup))×") +end + +println("\n" * "="^80) +println("BENCHMARK COMPLETE") +println("="^80) diff --git a/benchmarks/logpdf_comparison.jl b/benchmarks/logpdf_comparison.jl new file mode 100644 index 00000000..332ee0f4 --- /dev/null +++ b/benchmarks/logpdf_comparison.jl @@ -0,0 +1,215 @@ +#!/usr/bin/env julia +""" +Benchmark: Compare GMRF vs ChordalGMRF for logpdf computation + +Tests both correctness (results match) and performance on SSMC matrices. + +Usage: + cd benchmarks + julia --project=. logpdf_comparison.jl +""" + +using GaussianMarkovRandomFields +using BenchmarkTools +using Distributions: logpdf +using SparseArrays +using LinearAlgebra +using LinearSolve +using Printf +using Random +using MatrixDepot +using Zygote + +using CliqueTrees.Multifrontal: symbolic, chordal + +println("="^80) +println("LOGPDF COMPARISON: GMRF vs ChordalGMRF") +println("="^80) + +# Helper to make a matrix positive definite +function make_posdef(A::SparseMatrixCSC) + # Symmetrize and add diagonal dominance + S = (A + A') / 2 + d = vec(sum(abs, S; dims=2)) + return S + spdiagm(0 => d .+ 1.0) +end + +# Handle Symmetric wrapper from MatrixDepot +make_posdef(A::Symmetric) = make_posdef(sparse(A)) + +# Test matrices from SSMC (larger for ~100ms target) +test_matrices = [ + ("HB/bcsstk14", "Structural, n=1806"), + ("HB/bcsstk15", "Structural, n=3948"), + ("HB/bcsstk16", "Structural, n=4884"), + ("HB/bcsstk17", "Structural, n=10974"), +] + +println("\nTest matrices:") +for (name, desc) in test_matrices + println(" - $name ($desc)") +end + +results = [] + +for (matrix_name, desc) in test_matrices + println("\n" * "="^80) + println("Matrix: $matrix_name ($desc)") + println("="^80) + + # Load and prepare matrix + try + A_raw = matrixdepot(matrix_name) + Q = make_posdef(A_raw) + n = size(Q, 1) + + println(" Size: $n × $n") + println(" Nonzeros: $(nnz(Q))") + + # Create mean vector and evaluation point + Random.seed!(42) + μ = randn(n) + z = randn(n) + + # Create GMRF (baseline) + println("\n Creating GMRF...") + gmrf = GMRF(μ, Q, LinearSolve.CHOLMODFactorization()) + + # Create ChordalGMRF + println(" Creating ChordalGMRF...") + chordal_gmrf = ChordalGMRF(μ, Q) + + # Correctness check + println("\n Correctness check:") + lpdf_gmrf = logpdf(gmrf, z) + lpdf_chordal = logpdf(chordal_gmrf, z) + abs_diff = abs(lpdf_gmrf - lpdf_chordal) + rel_diff = abs_diff / (abs(lpdf_gmrf) + 1e-10) + + println(" GMRF logpdf: $(@sprintf("%.8f", lpdf_gmrf))") + println(" ChordalGMRF logpdf: $(@sprintf("%.8f", lpdf_chordal))") + println(" Absolute diff: $(@sprintf("%.2e", abs_diff))") + println(" Relative diff: $(@sprintf("%.2e", rel_diff))") + + correct = rel_diff < 1e-8 + println(" Match: $(correct ? 
"✓ YES" : "✗ NO")") + + # Performance benchmark + println("\n Performance benchmark:") + + # Benchmark GMRF + print(" GMRF... ") + bench_gmrf = @benchmark logpdf($gmrf, $z) samples=20 seconds=5 + time_gmrf = minimum(bench_gmrf.times) / 1e6 + println("$(@sprintf("%.3f", time_gmrf)) ms") + + # Benchmark ChordalGMRF + print(" ChordalGMRF... ") + bench_chordal = @benchmark logpdf($chordal_gmrf, $z) samples=20 seconds=5 + time_chordal = minimum(bench_chordal.times) / 1e6 + println("$(@sprintf("%.3f", time_chordal)) ms") + + speedup = time_gmrf / time_chordal + println(" Speedup: $(@sprintf("%.2f", speedup))×") + + # Gradient correctness check + println("\n Gradient correctness check (w.r.t. z):") + grad_gmrf = Zygote.gradient(x -> logpdf(gmrf, x), z)[1] + grad_chordal = Zygote.gradient(x -> logpdf(chordal_gmrf, x), z)[1] + grad_abs_diff = norm(grad_gmrf - grad_chordal) + grad_rel_diff = grad_abs_diff / (norm(grad_gmrf) + 1e-10) + + println(" Absolute diff: $(@sprintf("%.2e", grad_abs_diff))") + println(" Relative diff: $(@sprintf("%.2e", grad_rel_diff))") + + grad_correct = grad_rel_diff < 1e-8 + println(" Match: $(grad_correct ? "✓ YES" : "✗ NO")") + + # Gradient performance benchmark + println("\n Gradient performance benchmark:") + + print(" GMRF... ") + bench_grad_gmrf = @benchmark Zygote.gradient(x -> logpdf($gmrf, x), $z) samples=20 seconds=5 + time_grad_gmrf = minimum(bench_grad_gmrf.times) / 1e6 + println("$(@sprintf("%.3f", time_grad_gmrf)) ms") + + print(" ChordalGMRF... ") + bench_grad_chordal = @benchmark Zygote.gradient(x -> logpdf($chordal_gmrf, x), $z) samples=20 seconds=5 + time_grad_chordal = minimum(bench_grad_chordal.times) / 1e6 + println("$(@sprintf("%.3f", time_grad_chordal)) ms") + + grad_speedup = time_grad_gmrf / time_grad_chordal + println(" Speedup: $(@sprintf("%.2f", grad_speedup))×") + + push!(results, ( + name=matrix_name, + n=n, + nnz=nnz(Q), + correct=correct, + grad_correct=grad_correct, + time_gmrf=time_gmrf, + time_chordal=time_chordal, + speedup=speedup, + time_grad_gmrf=time_grad_gmrf, + time_grad_chordal=time_grad_chordal, + grad_speedup=grad_speedup, + )) + + catch e + println(" ✗ Failed: $(typeof(e).name.name): $(sprint(showerror, e; context=:limit=>true))") + push!(results, (name=matrix_name, n=0, nnz=0, correct=false, grad_correct=false, + time_gmrf=NaN, time_chordal=NaN, speedup=NaN, + time_grad_gmrf=NaN, time_grad_chordal=NaN, grad_speedup=NaN)) + end +end + +# Summary table +println("\n" * "="^80) +println("SUMMARY: FORWARD PASS") +println("="^80) + +println("\n" * "-"^95) +@printf("%-20s %8s %10s %8s %12s %12s %10s\n", + "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup") +println("-"^95) + +for r in results + correct_str = r.correct ? "✓" : "✗" + @printf("%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", + r.name, r.n, r.nnz, correct_str, r.time_gmrf, r.time_chordal, r.speedup) +end +println("-"^95) + +# Gradient summary table +println("\n" * "="^80) +println("SUMMARY: GRADIENT (Zygote)") +println("="^80) + +println("\n" * "-"^95) +@printf("%-20s %8s %10s %8s %12s %12s %10s\n", + "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup") +println("-"^95) + +for r in results + correct_str = r.grad_correct ? 
"✓" : "✗" + @printf("%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", + r.name, r.n, r.nnz, correct_str, r.time_grad_gmrf, r.time_grad_chordal, r.grad_speedup) +end +println("-"^95) + +# Overall stats +valid_results = filter(r -> !isnan(r.speedup), results) +if !isempty(valid_results) + avg_speedup = sum(r.speedup for r in valid_results) / length(valid_results) + avg_grad_speedup = sum(r.grad_speedup for r in valid_results) / length(valid_results) + all_correct = all(r.correct for r in valid_results) + all_grad_correct = all(r.grad_correct for r in valid_results) + + println("\nOverall:") + println(" Forward - All match: $(all_correct ? "✓ YES" : "✗ NO"), Avg speedup: $(@sprintf("%.2f", avg_speedup))×") + println(" Gradient - All match: $(all_grad_correct ? "✓ YES" : "✗ NO"), Avg speedup: $(@sprintf("%.2f", avg_grad_speedup))×") +end + +println("\n" * "="^80) +println("BENCHMARK COMPLETE") +println("="^80) diff --git a/src/GaussianMarkovRandomFields.jl b/src/GaussianMarkovRandomFields.jl index 03d68c05..8e20ffc4 100644 --- a/src/GaussianMarkovRandomFields.jl +++ b/src/GaussianMarkovRandomFields.jl @@ -1,10 +1,12 @@ module GaussianMarkovRandomFields +include("piracy.jl") include("typedefs.jl") include("utils/utils.jl") include("linear_maps/linear_maps.jl") include("preconditioners/preconditioners.jl") include("gmrf.jl") +include("chordal_gmrf.jl") include("metagmrf.jl") include("solvers/solvers.jl") include("autoregressive/autoregressive.jl") diff --git a/src/arithmetic/condition/gaussian_approximation.jl b/src/arithmetic/condition/gaussian_approximation.jl index 4a9b32b9..1574d335 100644 --- a/src/arithmetic/condition/gaussian_approximation.jl +++ b/src/arithmetic/condition/gaussian_approximation.jl @@ -1,9 +1,14 @@ using LinearAlgebra using SparseArrays using LinearMaps +using CliqueTrees.Multifrontal: chordal, ChordalCholesky, triangular +using CliqueTrees.Multifrontal.Differential: ldivsym export gaussian_approximation +# Sparse-preserving subtraction for Hermitian matrices +hermdiff(A::Hermitian, B) = Hermitian(parent(A) - B, Symbol(A.uplo)) + function neg_log_posterior(prior_gmrf::AbstractGMRF, obs_lik::ObservationLikelihood, x) return -logpdf(prior_gmrf, x) - loglik(x, obs_lik) end @@ -155,7 +160,7 @@ function gaussian_approximation( # Apply step with adaptive line search or full step if adaptive_stepsize obj_current = neg_log_posterior(base_gmrf, obs_lik, x_k) - step_accepted = false + accept = false for ls_iter in 1:max_linesearch_iter candidate = x_k - α * step @@ -165,7 +170,7 @@ function gaussian_approximation( # Accept step, increase α toward 1 μ_new = candidate α = sqrt(α) - step_accepted = true + accept = true verbose && ls_iter > 1 && println(" Accepted at α=$(round(α^2, digits = 3)) after $ls_iter backtracks") break else @@ -175,13 +180,13 @@ function gaussian_approximation( if α * norm(step, Inf) < newton_dec_tol / 1000 μ_new = candidate - step_accepted = true + accept = true break end end end - if !step_accepted + if !accept μ_new = x_k - α * step end else @@ -311,3 +316,91 @@ function gaussian_approximation(prior_mgmrf::MetaGMRF, obs_lik::ObservationLikel posterior_gmrf = gaussian_approximation(prior_mgmrf.gmrf, obs_lik; kwargs...) 
 return MetaGMRF(posterior_gmrf, prior_mgmrf.metadata)
 end
+
+# ChordalGMRF dispatch - uses chordal factorization for efficient solves
+function gaussian_approximation(
+        prior_gmrf::ChordalGMRF{T},
+        obs_lik::ObservationLikelihood;
+        x0::Union{Nothing, AbstractVector} = nothing,
+        max_iter::Int = 50,
+        mean_change_tol::Real = 1.0e-4,
+        newton_dec_tol::Real = 1.0e-5,
+        adaptive_stepsize::Bool = true,
+        max_linesearch_iter::Int = 10,
+        verbose::Bool = false
+    ) where {T}
+    S = prior_gmrf.L.S
+    P = prior_gmrf.P
+    Q_prior = prior_gmrf.Q
+
+    x_k = isnothing(x0) ? copy(mean(prior_gmrf)) : copy(x0)
+
+    α = 1.0
+    Q_new = Q_prior
+    F_new = ChordalCholesky{:L, T}(P, S)
+
+    verbose && println("Starting Fisher scoring (ChordalGMRF)...")
+
+    for iter in 1:max_iter
+        H_k = loghessian(x_k, obs_lik)
+        Q_new = hermdiff(Q_prior, H_k)
+        copyto!(F_new, Q_new)
+        cholesky!(F_new)
+        neg_score_k = ∇ₓ_neg_log_posterior(prior_gmrf, obs_lik, x_k)
+        step = F_new \ neg_score_k
+
+        if adaptive_stepsize
+            obj_current = neg_log_posterior(prior_gmrf, obs_lik, x_k)
+            step_accepted = false
+
+            for ls_iter in 1:max_linesearch_iter
+                candidate = x_k - α * step
+                obj_candidate = neg_log_posterior(prior_gmrf, obs_lik, candidate)
+
+                if obj_candidate <= obj_current
+                    x_new = candidate
+                    α = sqrt(α)
+                    step_accepted = true
+                    verbose && ls_iter > 1 && println(" Accepted at α=$(round(α^2, digits = 3)) after $ls_iter backtracks")
+                    break
+                else
+                    α *= 0.1
+                    verbose && println(" Backtrack: α=$(round(α, digits = 4))")
+
+                    if α * norm(step, Inf) < newton_dec_tol / 1000
+                        x_new = candidate
+                        step_accepted = true
+                        break
+                    end
+                end
+            end
+
+            if !step_accepted
+                x_new = x_k - α * step
+            end
+        else
+            x_new = x_k - step
+        end
+
+        newton_decrement = dot(neg_score_k, step)
+        mean_change = norm(x_new - x_k)
+        mean_change_rel = mean_change / max(norm(x_k), 1e-10)
+
+        verbose && println(" Iter $iter: Newton dec = $(round(newton_decrement, sigdigits = 3)), α = $(round(α, digits = 3))")
+
+        if (newton_decrement < newton_dec_tol) || (mean_change < mean_change_tol) || (mean_change_rel < mean_change_tol)
+            verbose && println(" Converged after $iter iterations")
+            return ChordalGMRF(x_new, Q_new, F_new.L, P)
+        end
+
+        x_k = x_new
+    end
+
+    verbose && println(" Reached max_iter = $max_iter without convergence")
+
+    H_k = loghessian(x_k, obs_lik)
+    Q_new = hermdiff(Q_prior, H_k)
+    copyto!(F_new, Q_new)
+    cholesky!(F_new)
+    return ChordalGMRF(x_k, Q_new, F_new.L, P)
+end
diff --git a/src/autodiff/gaussian_approximation.jl b/src/autodiff/gaussian_approximation.jl
index 611b28ff..34eccc32 100644
--- a/src/autodiff/gaussian_approximation.jl
+++ b/src/autodiff/gaussian_approximation.jl
@@ -1,5 +1,6 @@
 using ChainRulesCore
 using LinearAlgebra
+using CliqueTrees.Multifrontal.Differential: ldivsym
 
 """
     _is_zero_tangent(x) -> Bool
@@ -313,12 +314,118 @@ function ChainRulesCore.rrule(
 end
 
 # Also handle case without RuleConfig for simpler usage
+
+# =============================================================================
+# ChordalGMRF rrule
+# =============================================================================
+
+# Extract tangents from ChordalGMRF posterior
+function _extract_posterior_tangents(ȳ, ::ChordalGMRF)
+    μ̄ = _is_zero_tangent(ȳ.μ) ? ZeroTangent() : ȳ.μ
+    Q̄ = _is_zero_tangent(ȳ.Q) ? ZeroTangent() : ȳ.Q
+    return (μ̄, Q̄)
+end
+
+# Add Q̄ contribution to prior tangent for ChordalGMRF
+function _add_precision_tangent(prior_tangent, prior::ChordalGMRF, Q̄)
+    prior_μ̄ = (prior_tangent isa Tangent && hasproperty(prior_tangent, :μ)) ?
prior_tangent.μ : NoTangent() + prior_Q̄_existing = (prior_tangent isa Tangent && hasproperty(prior_tangent, :Q)) ? prior_tangent.Q : NoTangent() + combined_Q̄ = _is_zero_tangent(prior_Q̄_existing) ? Q̄ : prior_Q̄_existing + Q̄ + return Tangent{typeof(prior)}(; + μ = prior_μ̄, + Q = combined_Q̄, + L = NoTangent(), + P = NoTangent(), + ) +end + +""" + ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(gaussian_approximation), + prior_gmrf::ChordalGMRF, obs_lik::ObservationLikelihood; kwargs...) + +Automatic differentiation rule for `gaussian_approximation` with `ChordalGMRF` using the Implicit Function Theorem. + +Uses the chordal structure from CliqueTrees for efficient operations. The posterior precision +is `Q_post = Q_prior - loghessian(x*, obs_lik)`. + +# Mathematical Approach +Uses IFT on the optimality condition ∇ₓ neg_log_posterior(x*) = 0: +- Forward: Solve via Fisher scoring +- Backward: + 1. Extract (μ̄, Q̄) from posterior tangent + 2. Backprop Q̄ through loghessian to get x* and obs_lik contributions + 3. Combine with μ̄ and solve Q_post λ = total_x̄ using chordal Cholesky + 4. VJP through ∇ₓ_neg_log_posterior with -λ to get prior and obs_lik tangents + 5. Add direct Q̄ contribution to prior +""" function ChainRulesCore.rrule( + config::RuleConfig{>:HasReverseMode}, ::typeof(gaussian_approximation), - prior_gmrf::Union{GMRF, ConstrainedGMRF}, + prior_gmrf::ChordalGMRF, obs_lik::ObservationLikelihood; kwargs... ) - # Delegate to the RuleConfig version with default config - return rrule(NoRuleConfig(), gaussian_approximation, prior_gmrf, obs_lik; kwargs...) + # === Forward pass === + posterior = gaussian_approximation(prior_gmrf, obs_lik; kwargs...) + + # Mode in unpermuted coordinates + x_star = mean(posterior) + P = posterior.P + S = posterior.L.S + + # === Pullback === + function chordal_gaussian_approximation_pullback(ȳ) + # Handle zero tangent case + if ȳ isa ZeroTangent + return (NoTangent(), ZeroTangent(), NoTangent()) + end + + # Extract tangent components + μ̄, Q̄ = _extract_posterior_tangents(ȳ, posterior) + + # --- Q̄ path: backprop through Q_post = Q_prior - loghessian(x*, obs_lik) --- + if _is_zero_tangent(Q̄) + x̄_from_Q = ZeroTangent() + obs_lik_tangent_from_Q = NoTangent() + else + # Use rrule_via_ad on loghessian; Q_post = Q_prior - loghessian(...) + _, hess_pullback = rrule_via_ad(config, loghessian, x_star, obs_lik) + # -Q̄ because Q_post = Q_prior - loghessian(...) + _, x̄_from_Q, obs_lik_tangent_from_Q = hess_pullback(-Q̄) + end + + # --- μ̄ path: μ_post = x*, so μ̄ flows directly to x* --- + x̄_from_μ = _is_zero_tangent(μ̄) ? ZeroTangent() : μ̄ + + # --- Combine x* tangents --- + if _is_zero_tangent(x̄_from_μ) && _is_zero_tangent(x̄_from_Q) + x̄_total = zeros(eltype(x_star), length(x_star)) + elseif _is_zero_tangent(x̄_from_μ) + x̄_total = collect(x̄_from_Q) + elseif _is_zero_tangent(x̄_from_Q) + x̄_total = collect(x̄_from_μ) + else + x̄_total = collect(x̄_from_μ) .+ collect(x̄_from_Q) + end + + # --- IFT: solve Q_post * λ = x̄_total using chordal Cholesky --- + λ = ldivsym(precision_matrix(posterior), posterior.L, P, x̄_total) + + # --- VJP through ∇ₓ_neg_log_posterior at x* --- + _, ∇_pullback = rrule_via_ad(config, ∇ₓ_neg_log_posterior, prior_gmrf, obs_lik, x_star) + _, prior_tangent, obs_lik_tangent, _ = ∇_pullback(-λ) # Minus sign from IFT + + # --- Add direct Q̄ contribution to prior (Q_post = Q_prior - ...) 
---
+        if !_is_zero_tangent(Q̄)
+            prior_tangent = _add_precision_tangent(prior_tangent, prior_gmrf, Q̄)
+        end
+
+        # --- Combine obs_lik tangents from both paths ---
+        obs_lik_combined = _add_namedtuples(obs_lik_tangent, obs_lik_tangent_from_Q)
+
+        return (NoTangent(), prior_tangent, obs_lik_combined)
+    end
+
+    return posterior, chordal_gaussian_approximation_pullback
+end
+
diff --git a/src/autodiff/logpdf.jl b/src/autodiff/logpdf.jl
index e266abe8..71f5311c 100644
--- a/src/autodiff/logpdf.jl
+++ b/src/autodiff/logpdf.jl
@@ -98,3 +98,5 @@ function ChainRulesCore.rrule(::typeof(logpdf), x::AbstractGMRF, z::AbstractVect
         )
     end
 end
+
+ChainRulesCore.@opt_out rrule(::typeof(logpdf), ::ChordalGMRF, ::AbstractVector)
diff --git a/src/chordal_gmrf.jl b/src/chordal_gmrf.jl
new file mode 100644
index 00000000..a5986fc9
--- /dev/null
+++ b/src/chordal_gmrf.jl
@@ -0,0 +1,86 @@
+using CliqueTrees.Multifrontal: ChordalTriangular, Permutation, ChordalSymbolic, symbolic, chordal, selinv as mselinv, logdet
+using LinearAlgebra: Hermitian, cholesky, diag, ldiv!, axpy!, dot
+using SparseArrays: SparseMatrixCSC
+using Random: AbstractRNG, randn!
+
+export ChordalGMRF
+
+struct ChordalGMRF{T <: Real, Herm <: HermSparse{T}, Tri <: ChordalTriangular{:N, :L, T}, Prm <: Permutation, Mea <: AbstractVector{T}} <: AbstractGMRF{T, Herm}
+    μ::Mea
+    Q::Herm
+    L::Tri
+    P::Prm
+end
+
+function ChordalGMRF(μ::AbstractVector, Q::SparseMatrixCSC; kw...)
+    H = Hermitian(Q, :L)
+    P, S = symbolic(H; kw...)
+    L = cholesky(chordal(H, P, S))
+    return ChordalGMRF(μ, H, L, P)
+end
+
+function Base.length(d::ChordalGMRF)
+    return length(d.μ)
+end
+
+function mean(d::ChordalGMRF)
+    return d.μ
+end
+
+function precision_map(d::ChordalGMRF)
+    return d.Q
+end
+
+function precision_matrix(d::ChordalGMRF)
+    return d.Q
+end
+
+function logdetcov(d::ChordalGMRF)
+    return -logdet(precision_matrix(d), d.L, d.P)
+end
+
+function sqmahal(d::ChordalGMRF, x::AbstractVector)
+    r = x - d.μ
+    return dot(r, precision_matrix(d), r)
+end
+
+function gradlogpdf(d::ChordalGMRF, x::AbstractVector)
+    return precision_matrix(d) * (d.μ - x)
+end
+
+function var(d::ChordalGMRF)
+    Σ = mselinv(precision_matrix(d), d.L, d.P)
+    return diag(Σ)
+end
+
+function _rand!(rng::AbstractRNG, d::ChordalGMRF{T}, x::AbstractVector) where {T}
+    randn!(rng, x)
+    return axpy!(1, d.μ, copyto!(x, d.P \ ldiv!(d.L', d.P * x)))
+end
+
+function Base.show(io::IO, d::ChordalGMRF{T}) where {T}
+    return print(io, "ChordalGMRF{$T}(n=$(length(d)))")
+end
+
+function Base.show(io::IO, ::MIME"text/plain", d::ChordalGMRF{T}) where {T}
+    println(io, "ChordalGMRF{$T} with $(length(d)) variables")
+
+    μ = d.μ
+
+    if length(μ) <= 6
+        print(io, " Mean: $μ")
+    else
+        print(io, " Mean: [$(μ[1]), $(μ[2]), $(μ[3]), ..., $(μ[end-2]), $(μ[end-1]), $(μ[end])]")
+    end
+end
+
+# ChainRulesCore rrule for ChordalGMRF constructor
+# ChordalGMRF is defined by (μ, Q). L and P are derived - gradients never flow through them.
+using ChainRulesCore: ChainRulesCore, NoTangent
+
+function ChainRulesCore.rrule(::typeof(ChordalGMRF), μ::AbstractVector, Q::SparseMatrixCSC; kw...)
+    result = ChordalGMRF(μ, Q; kw...)
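+    # P (fill-reducing permutation) and L (numeric factor) are rebuilt
+    # deterministically inside the forward call above, so the pullback only
+    # needs to route the cotangents ȳ.μ and ȳ.Q back to the two arguments.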
+ ChordalGMRF_pullback(ȳ) = (NoTangent(), ȳ.μ, ȳ.Q) + return result, ChordalGMRF_pullback +end + diff --git a/src/piracy.jl b/src/piracy.jl new file mode 100644 index 00000000..9621147c --- /dev/null +++ b/src/piracy.jl @@ -0,0 +1,503 @@ +# | | | +# )_) )_) )_) +# )___))___))___) +# )____)____)_____) +# _____|____|____|____ +# ---------\ /--------- +# ^^^^^ ^^^^^^^^^^^^^^^^^^^^^ +# ^^^^ ^^^^ ^^^ ^^ +# ^^^^ ^^^ +# +# Type piracy to enable autodiff for Hermitian/Symmetric sparse matrices. +# These changes have been submitted as PRs to ChainRulesCore, ChainRules, and Zygote. +# This file can be removed once those PRs are merged and released. + +using ChainRulesCore +using ChainRulesCore: ProjectTo, project_type, _projection_mismatch, NoTangent, ZeroTangent, AbstractZero, @thunk, unthunk +using LinearAlgebra +using LinearAlgebra: Hermitian, Symmetric, Adjoint, Transpose, AdjOrTrans, dot, rmul!, tril, triu +using SparseArrays +using SparseArrays: SparseMatrixCSC, nzrange, rowvals, getcolptr, nonzeros +using Zygote + +##### +##### Type aliases +##### + +const HermSparse{T, I} = Hermitian{T, SparseMatrixCSC{T, I}} +const SymSparse{T, I} = Symmetric{T, SparseMatrixCSC{T, I}} +const HermOrSymSparse{T, I} = Union{HermSparse{T, I}, SymSparse{T, I}} + +const DenseMat{T} = Union{StridedMatrix{T}, AdjOrTrans{T, <:StridedVecOrMat{T}}} +const DenseVecOrMat{T} = Union{DenseMat{T}, StridedVector{T}} + +##### +##### Zygote: accum for HermOrSymSparse +##### + +Zygote.accum(x::HermOrSymSparse, y::HermOrSymSparse) = x + y + +##### +##### ChainRulesCore: ProjectTo for HermOrSymSparse +##### + +const SparseProjectToData{T, I} = NamedTuple{ + (:element, :axes, :rowval, :nzranges, :colptr), + Tuple{ + ProjectTo{T, NamedTuple{(), Tuple{}}}, + Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, + Vector{I}, + Vector{UnitRange{Int64}}, + Vector{I}, + }, +} + +const SparseProjectTo{T, I} = ProjectTo{SparseMatrixCSC, SparseProjectToData{T, I}} + +const HermSparseProjectTo{T, I} = ProjectTo{ + Hermitian, + NamedTuple{ + (:uplo, :parent), + Tuple{Symbol, SparseProjectTo{T, I}}, + }, +} + +const SymSparseProjectTo{T, I} = ProjectTo{ + Symmetric, + NamedTuple{ + (:uplo, :parent), + Tuple{Symbol, SparseProjectTo{T, I}}, + }, +} + +function ChainRulesCore.ProjectTo(x::HermSparse{T}) where {T<:Number} + return ProjectTo{Hermitian}(; + uplo=Symbol(x.uplo), + parent=ProjectTo(parent(x)), + ) +end + +function ChainRulesCore.ProjectTo(x::SymSparse{T}) where {T<:Number} + return ProjectTo{Symmetric}(; + uplo=Symbol(x.uplo), + parent=ProjectTo(parent(x)), + ) +end + +function project!(A::SparseMatrixCSC{T, I}, B::SparseMatrixCSC{<:Any, J}, uplo::Char) where {T, I, J} + @assert size(A) == size(B) + + @inbounds for j in axes(A, 2) + p = getcolptr(A)[j] + pstop = getcolptr(A)[j + 1] + q = getcolptr(B)[j] + qstop = getcolptr(B)[j + 1] + + while p < pstop + i = rowvals(A)[p] + + if (uplo == 'L' && i >= j) || (uplo == 'U' && i <= j) + while q < qstop && rowvals(B)[q] < i + q += one(J) + end + + if q < qstop && rowvals(B)[q] == i + nonzeros(A)[p] = nonzeros(B)[q] + else + nonzeros(A)[p] = zero(T) + end + end + + p += one(I) + end + end + + return A +end + +function project!(A::HermOrSymSparse, B::HermOrSymSparse) + if A.uplo == B.uplo + project!(parent(A), parent(B), A.uplo) + elseif A.uplo == 'L' + project!(parent(A), tril(B), A.uplo) + else + project!(parent(A), triu(B), A.uplo) + end + + return A +end + +function sparse_from_project(P::SparseProjectTo{T, I}) where {T, I} + m, n = map(length, P.axes) + return SparseMatrixCSC(m, n, P.colptr, 
P.rowval, zeros(T, length(P.rowval))) +end + +function sparse_from_project(P::HermSparseProjectTo) + return Hermitian(sparse_from_project(P.parent), P.uplo) +end + +function sparse_from_project(P::SymSparseProjectTo) + return Symmetric(sparse_from_project(P.parent), P.uplo) +end + +function checkpatternsym(n, Acolptr::Vector{IA}, Bcolptr::Vector{IB}, Arowval::AbstractVector, Browval::AbstractVector, uplo::Char) where {IA, IB} + for j in 1:n + pa = Acolptr[j] + pb = Bcolptr[j] + pastop = Acolptr[j + 1] + pbstop = Bcolptr[j + 1] + + while pa < pastop && pb < pbstop + ia = Arowval[pa] + ib = Browval[pb] + + if (uplo == 'L' && ia < j) || (uplo == 'U' && ia > j) + pa += one(IA) + elseif (uplo == 'L' && ib < j) || (uplo == 'U' && ib > j) + pb += one(IB) + elseif ia == ib + pa += one(IA) + pb += one(IB) + else + return false + end + end + + while pa < pastop + ia = Arowval[pa] + + if (uplo == 'L' && ia >= j) || (uplo == 'U' && ia <= j) + return false + end + + pa += one(IA) + end + + while pb < pbstop + ib = Browval[pb] + + if (uplo == 'L' && ib >= j) || (uplo == 'U' && ib <= j) + return false + end + + pb += one(IB) + end + end + + return true +end + +function checkpatternsym(P, dX) + return false +end + +function checkpatternsym(P::Union{HermSparseProjectTo{T, I}, SymSparseProjectTo{T, I}}, dX::HermOrSymSparse{T, I}) where {T, I} + dXP = parent(dX) + return Symbol(dX.uplo) == P.uplo && checkpatternsym(size(dXP, 2), P.parent.colptr, dXP.colptr, P.parent.rowval, dXP.rowval, dX.uplo) +end + +function (P::HermSparseProjectTo{T, I})(dX::HermSparse) where {T, I} + if checkpatternsym(P, dX) + return dX + else + return project!(sparse_from_project(P), dX) + end +end + +function (P::SymSparseProjectTo{T, I})(dX::SymSparse) where {T, I} + if checkpatternsym(P, dX) + return dX + else + return project!(sparse_from_project(P), dX) + end +end + +function (P::HermSparseProjectTo{T, I})(dX::SymSparse{T, I}) where {T <: Real, I} + if checkpatternsym(P, dX) + return Hermitian(parent(dX), P.uplo) + else + return project!(sparse_from_project(P), dX) + end +end + +function (P::SymSparseProjectTo{T, I})(dX::HermSparse{T, I}) where {T <: Real, I} + if checkpatternsym(P, dX) + return Symmetric(parent(dX), P.uplo) + else + return project!(sparse_from_project(P), dX) + end +end + +##### +##### ChainRules: selupd! for computing sparse gradients +##### + +function unwrap(A) + if A isa Adjoint + B = parent(A) + + if B isa Transpose + return (parent(B), Val(:N), Val(:C)) + else + return (B, Val(:T), Val(:C)) + end + elseif A isa Transpose + B = parent(A) + + if B isa Adjoint + return (parent(B), Val(:N), Val(:C)) + else + return (B, Val(:T), Val(:N)) + end + else + return (A, Val(:N), Val(:N)) + end +end + +# SELected UPDate: compute the selected low-rank update +# +# C ← α A Bᴴ + conj(α) B Aᴴ + β C +# +# The update is only applied to the structural nonzeros of C. +function selupd!(C::HermSparse, A::AbstractVecOrMat, B::AbstractVecOrMat, α, β) + selupd!(parent(C), C.uplo, A, adjoint(B), α, β) + selupd!(parent(C), C.uplo, B, adjoint(A), conj(α), 1) + return C +end + +# SELected UPDate: compute the selected low-rank update +# +# C ← α A Bᴴ + α conj(B) Aᵀ + β C +# +# The update is only applied to the structural nonzeros of C. 
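+#
+# Since (α A Bᴴ)ᵀ = α conj(B) Aᵀ, the second term is just the transpose of the
+# first, so the accumulated update is symmetric; the Hermitian method above
+# instead mirrors with the adjoint, (α A Bᴴ)ᴴ = conj(α) B Aᴴ.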
+function selupd!(C::SymSparse, A::AbstractVecOrMat, B::AbstractVecOrMat, α, β) + selupd!(parent(C), C.uplo, A, adjoint(B), α, β) + selupd!(parent(C), C.uplo, adjoint(transpose(B)), transpose(A), α, 1) + return C +end + +# SELected UPDate: compute the selected low-rank update +# +# C ← α A B + β C +# +# The update is only applied to the structural nonzeros of C. +function selupd!(C::SparseMatrixCSC, uplo::Char, A::AbstractVecOrMat, B::AbstractVecOrMat, α, β) + AP, tA, cA = unwrap(A) + BP, tB, cB = unwrap(B) + return selupd_impl!(C, uplo, AP, BP, α, β, tA, cA, tB, cB) +end + +function selupd_impl!(C::SparseMatrixCSC, uplo::Char, A::AbstractVector, B::AbstractVector, α, β, ::Val{tA}, ::Val{cA}, ::Val{tB}, ::Val{cB}) where {tA, cA, tB, cB} + @assert size(C, 1) == size(C, 2) == length(A) == length(B) + + @inbounds for j in axes(C, 2) + Bj = cB === :C ? conj(B[j]) : B[j] + + for p in nzrange(C, j) + i = rowvals(C)[p] + + if (uplo == 'L' && i >= j) || (uplo == 'U' && i <= j) + Ai = cA === :C ? conj(A[i]) : A[i] + + if iszero(β) + nonzeros(C)[p] = α * Ai * Bj + else + nonzeros(C)[p] = β * nonzeros(C)[p] + α * Ai * Bj + end + end + end + end + + return C +end + +function selupd_impl!(C::SparseMatrixCSC, uplo::Char, A::AbstractMatrix, B::AbstractMatrix, α, β, tA::Val{TA}, cA::Val{CA}, tB::Val{TB}, cB::Val{CB}) where {TA, CA, TB, CB} + @assert size(C, 1) == size(C, 2) + + if TA === :N && TB === :N + @assert size(A, 1) == size(C, 1) + @assert size(B, 2) == size(C, 1) + @assert size(A, 2) == size(B, 1) + elseif TA === :N && TB !== :N + @assert size(A, 1) == size(C, 1) + @assert size(B, 1) == size(C, 1) + @assert size(A, 2) == size(B, 2) + elseif TA !== :N && TB === :N + @assert size(A, 2) == size(C, 1) + @assert size(B, 2) == size(C, 1) + @assert size(A, 1) == size(B, 1) + else + @assert size(A, 2) == size(C, 1) + @assert size(B, 1) == size(C, 1) + @assert size(A, 1) == size(B, 2) + end + + if TA === :N + rng = axes(A, 2) + else + rng = axes(A, 1) + end + + if iszero(β) + fill!(nonzeros(C), β) + else + rmul!(nonzeros(C), β) + end + + for k in rng + if TA === :N + Ak = view(A, :, k) + else + Ak = view(A, k, :) + end + + if TB === :N + Bk = view(B, k, :) + else + Bk = view(B, :, k) + end + + selupd_impl!(C, uplo, Ak, Bk, α, 1, tA, cA, tB, cB) + end + + return C +end + +##### +##### ChainRules: rrule/frule implementations +##### + +function mul_rrule_impl(A::HermOrSymSparse, B::DenseVecOrMat, ΔC) + ΔB = A * ΔC + ΔA = if ΔC isa AbstractZero + ZeroTangent() + else + @thunk begin + ΔA = similar(A) + selupd!(ΔA, ΔC, B, 1 / 2, 0) + ΔA + end + end + return ΔA, ΔB +end + +function mul_rrule_impl(A::DenseMat, B::HermSparse, ΔC) + ΔA = ΔC * B + ΔB = if ΔC isa AbstractZero + ZeroTangent() + else + @thunk begin + ΔB = similar(B) + selupd!(ΔB, A', ΔC', 1 / 2, 0) + ΔB + end + end + return ΔA, ΔB +end + +function mul_rrule_impl(A::DenseMat, B::SymSparse, ΔC) + ΔA = ΔC * B + ΔB = if ΔC isa AbstractZero + ZeroTangent() + else + @thunk begin + ΔB = similar(B) + selupd!(ΔB, transpose(ΔC), transpose(A), 1 / 2, 0) + ΔB + end + end + return ΔA, ΔB +end + +function dot_rrule_impl(x::StridedVector, A::HermOrSymSparse, y::StridedVector, Ax::StridedVector, Ay::StridedVector, Δz) + Δx = @thunk Δz * Ay + Δy = @thunk Δz * Ax + + ΔA = if Δz isa AbstractZero + ZeroTangent() + else + @thunk begin + ΔA = similar(A) + selupd!(ΔA, x, y, Δz / 2, 0) + ΔA + end + end + + return Δx, ΔA, Δy +end + +function mul_rrule(A::HermOrSymSparse, B::DenseVecOrMat) + C = A * B + + function pullback(ΔC) + ΔA, ΔB = mul_rrule_impl(A, B, ΔC) + return 
NoTangent(), ΔA, ΔB + end + + return C, pullback ∘ unthunk +end + +function mul_rrule(A::DenseMat, B::HermOrSymSparse) + C = A * B + + function pullback(ΔC) + ΔA, ΔB = mul_rrule_impl(A, B, ΔC) + return NoTangent(), ΔA, ΔB + end + + return C, pullback ∘ unthunk +end + +function dot_rrule(x::StridedVector, A::HermOrSymSparse, y::StridedVector) + Ax = A * x + Ay = A * y + z = dot(x, Ay) + + function pullback(Δz) + Δx, ΔA, Δy = dot_rrule_impl(x, A, y, Ax, Ay, Δz) + return NoTangent(), Δx, ΔA, Δy + end + + return z, pullback ∘ unthunk +end + +function mul_frule_impl(A, B, dA, dB) + return A * B, dA * B + A * dB +end + +function dot_frule_impl(x::StridedVector, A::HermOrSymSparse, y::StridedVector, dx, dA, dy) + return dot(x, A, y), dot(dx, A, y) + dot(x, A, dy) + dot(x, dA, y) +end + +##### +##### ChainRules: frule / rrule dispatches +##### + +for T in (HermSparse, SymSparse) + # A * X + @eval function ChainRulesCore.frule((_, dA, dX)::Tuple, ::typeof(*), A::$T, X::DenseVecOrMat) + return mul_frule_impl(A, X, dA, dX) + end + + @eval function ChainRulesCore.rrule(::typeof(*), A::$T, X::DenseVecOrMat) + return mul_rrule(A, X) + end + + # X * A + @eval function ChainRulesCore.frule((_, dX, dA)::Tuple, ::typeof(*), X::DenseMat, A::$T) + return mul_frule_impl(X, A, dX, dA) + end + + @eval function ChainRulesCore.rrule(::typeof(*), X::DenseMat, A::$T) + return mul_rrule(X, A) + end + + # dot(x, A, y) - vectors only, matching upstream ChainRules + @eval function ChainRulesCore.frule((_, dx, dA, dy)::Tuple, ::typeof(dot), x::StridedVector, A::$T, y::StridedVector) + return dot_frule_impl(x, A, y, dx, dA, dy) + end + + @eval function ChainRulesCore.rrule(::typeof(dot), x::StridedVector, A::$T, y::StridedVector) + return dot_rrule(x, A, y) + end +end diff --git a/src/solvers/selinv.jl b/src/solvers/selinv.jl index c700fca4..0e56f411 100644 --- a/src/solvers/selinv.jl +++ b/src/solvers/selinv.jl @@ -104,12 +104,12 @@ _selinv_impl(linsolve, alg) = error("Full selected inversion not implemented for function _selinv_impl(linsolve, ::LinearSolve.CHOLMODFactorization) factorization = LinearSolve.@get_cacheval(linsolve, :CHOLMODFactorization) - return SelectedInversion.selinv(factorization; depermute = true).Z + return Symmetric(SparseMatrixCSC(SelectedInversion.selinv(factorization; depermute = true).Z)) end function _selinv_impl(linsolve, ::LinearSolve.CholeskyFactorization) factorization = LinearSolve.@get_cacheval(linsolve, :CholeskyFactorization) - return SelectedInversion.selinv(factorization; depermute = true).Z + return Symmetric(SparseMatrixCSC(SelectedInversion.selinv(factorization; depermute = true).Z)) end function _selinv_impl(linsolve, ::LinearSolve.DiagonalFactorization) @@ -119,7 +119,7 @@ end function _selinv_impl(linsolve, ::LinearSolve.LDLtFactorization) factorization = LinearSolve.@get_cacheval(linsolve, :LDLtFactorization) - return SelectedInversion.selinv(factorization; depermute = true).Z + return Symmetric(SparseMatrixCSC(SelectedInversion.selinv(factorization; depermute = true).Z)) end function _selinv_impl(linsolve, ::LinearSolve.PardisoJL) diff --git a/test/autodiff/runtests.jl b/test/autodiff/runtests.jl index 6ec7c85e..0506e8b3 100644 --- a/test/autodiff/runtests.jl +++ b/test/autodiff/runtests.jl @@ -3,3 +3,4 @@ include("test_forwarddiff_extension.jl") include("test_logpdf.jl") include("test_constructors.jl") include("test_gaussian_approximation.jl") +include("test_gaussian_approximation_chordal.jl") diff --git a/test/autodiff/test_gaussian_approximation_chordal.jl 
b/test/autodiff/test_gaussian_approximation_chordal.jl new file mode 100644 index 00000000..4b018d32 --- /dev/null +++ b/test/autodiff/test_gaussian_approximation_chordal.jl @@ -0,0 +1,356 @@ +using Test +using GaussianMarkovRandomFields +using Distributions: logpdf, Poisson, Normal +using SparseArrays +using LinearAlgebra +using Random + +using CliqueTrees.Multifrontal: symbolic, chordal, HermTri, Permutation + +using DifferentiationInterface +using FiniteDiff, ForwardDiff, Zygote + +backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] + +@testset "$backend_name ChordalGMRF autodiff tests" for (backend_name, backend) in backends + # Set seed for reproducibility + Random.seed!(42) + fd_backend = AutoFiniteDiff() + + # Helper function to create simple AR(1) precision matrix + function ar_precision(ρ, k) + return spdiagm(-1 => -ρ * ones(k - 1), 0 => ones(k) .+ ρ^2, 1 => -ρ * ones(k - 1)) + end + + # Test pipeline: hyperparameters → ChordalGMRF → gaussian_approximation → logpdf + function test_gauss_approx_pipeline(θ::Vector, y::Vector, x::Vector, P, S, k::Int) + # Extract hyperparameters + ρ = θ[1] # AR parameter + μ_const = θ[2] # constant mean + + # Create precision matrix + Q = ar_precision(ρ, k) + + # Create constant mean vector + μ = μ_const * ones(k) + + # Use chordal (now differentiable!) + J = chordal(Hermitian(Q, :L), P, S) + L = cholesky(J) + + # Create prior ChordalGMRF (pass original Q, not chordal J) + prior_gmrf = ChordalGMRF(μ, Q, L, P) + + # Create Poisson observation likelihood + obs_model = ExponentialFamily(Poisson) + poisson_obs = PoissonObservations(y) + obs_lik = obs_model(poisson_obs) + + # Find Gaussian approximation + posterior_gmrf = gaussian_approximation(prior_gmrf, obs_lik) + + # Compute logpdf at evaluation point x + return logpdf(posterior_gmrf, x) + end + + @testset "Poisson likelihood with gaussian_approximation" begin + k = 8 + θ = [0.4, 0.5] # [ρ, μ_const] + y = [2, 1, 3, 2, 1, 4, 2, 1] # Poisson count data + x = randn(k) .+ 0.5 # Evaluation point + + # Pre-compute sparsity structure + Q_ref = ar_precision(0.5, k) + P, S = symbolic(Q_ref) + + grad_test = DifferentiationInterface.gradient( + θ -> test_gauss_approx_pipeline(θ, y, x, P, S, k), + backend, + θ + ) + + grad_fd = DifferentiationInterface.gradient( + θ -> test_gauss_approx_pipeline(θ, y, x, P, S, k), + fd_backend, + θ + ) + + abs_error = abs.(grad_test - grad_fd) + rel_error = abs_error ./ (abs.(grad_fd) .+ 1.0e-10) + + @test maximum(abs_error) < 1.0e-3 + @test maximum(rel_error) < 5.0e-2 + end + + @testset "Different hyperparameter values" begin + k = 6 + y = [1, 2, 1, 3, 1, 2] + x = randn(k) .+ 0.3 + + # Pre-compute sparsity structure + Q_ref = ar_precision(0.5, k) + P, S = symbolic(Q_ref) + + # Test different ρ and μ values + for ρ in [0.2, 0.5] + for μ_const in [0.3, 0.8] + θ = [ρ, μ_const] + + grad_test = DifferentiationInterface.gradient( + θ -> test_gauss_approx_pipeline(θ, y, x, P, S, k), + backend, + θ + ) + + grad_fd = DifferentiationInterface.gradient( + θ -> test_gauss_approx_pipeline(θ, y, x, P, S, k), + fd_backend, + θ + ) + + abs_error = abs.(grad_test - grad_fd) + rel_error = abs_error ./ (abs.(grad_fd) .+ 1.0e-10) + + @test maximum(abs_error) < 2.0e-2 # Relaxed for complex optimization + @test maximum(rel_error) < 5.0e-2 + end + end + end + + @testset "Gaussian (conjugate) likelihood" begin + # Test with Gaussian likelihood - should also work through rrule + k = 6 + θ = [0.3, 0.1] + y = randn(k) .* 0.3 .+ 0.2 + x = randn(k) + + # Pre-compute sparsity 
structure + Q_ref = ar_precision(0.5, k) + P, S = symbolic(Q_ref) + + function gaussian_lik_pipeline(θ, y, x, P, S, k) + ρ, μ_const = θ + Q = ar_precision(ρ, k) + μ = μ_const * ones(k) + + J = chordal(Hermitian(Q, :L), P, S) + L = cholesky(J) + prior_gmrf = ChordalGMRF(μ, Q, L, P) + + obs_model = ExponentialFamily(Normal) + obs_lik = obs_model(y; σ = 0.5) + posterior_gmrf = gaussian_approximation(prior_gmrf, obs_lik) + return logpdf(posterior_gmrf, x) + end + + grad_test = DifferentiationInterface.gradient( + θ -> gaussian_lik_pipeline(θ, y, x, P, S, k), + backend, + θ + ) + + grad_fd = DifferentiationInterface.gradient( + θ -> gaussian_lik_pipeline(θ, y, x, P, S, k), + fd_backend, + θ + ) + + abs_error = abs.(grad_test - grad_fd) + rel_error = abs_error ./ (abs.(grad_fd) .+ 1.0e-10) + + @test maximum(abs_error) < 1.0e-3 + @test maximum(rel_error) < 5.0e-2 + end + + @testset "Small system" begin + # Test with very small system + k = 4 + θ = [0.6, 0.4] + y = [1, 2, 1, 1] + x = randn(k) .+ 0.4 + + # Pre-compute sparsity structure + Q_ref = ar_precision(0.5, k) + P, S = symbolic(Q_ref) + + grad_test = DifferentiationInterface.gradient( + θ -> test_gauss_approx_pipeline(θ, y, x, P, S, k), + backend, + θ + ) + + grad_fd = DifferentiationInterface.gradient( + θ -> test_gauss_approx_pipeline(θ, y, x, P, S, k), + fd_backend, + θ + ) + + abs_error = abs.(grad_test - grad_fd) + rel_error = abs_error ./ (abs.(grad_fd) .+ 1.0e-10) + + @test maximum(abs_error) < 1.0e-3 + @test maximum(rel_error) < 5.0e-2 + end + + @testset "Basic logpdf autodiff with chordal" begin + k = 10 + z = randn(k) + + # Pre-compute sparsity structure (not differentiable) + Q_ref = ar_precision(0.5, k) + P, S = symbolic(Q_ref) + + # Test pipeline: use chordal which IS now differentiable + function test_chordal_pipeline(θ::AbstractVector, z::AbstractVector, P, S, k) + ρ = θ[1] + μ_const = θ[2] + + Q = ar_precision(ρ, k) + μ = μ_const * ones(k) + + # Use chordal directly (now differentiable!) 
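+            # P and S come from symbolic(Q_ref) above and encode only the
+            # sparsity pattern, which does not depend on θ; only the numeric
+            # entries of Q carry derivatives through chordal/cholesky.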
+ J = chordal(Hermitian(Q, :L), P, S) + L = cholesky(J) + + gmrf = ChordalGMRF(μ, Q, L, P) + return logpdf(gmrf, z) + end + + θ = [0.5, 0.1] + + # Compute gradients using AD backend + grad_test = DifferentiationInterface.gradient( + θ -> test_chordal_pipeline(θ, z, P, S, k), + backend, + θ + ) + + # Compute gradients using finite differences + grad_fd = DifferentiationInterface.gradient( + θ -> test_chordal_pipeline(θ, z, P, S, k), + fd_backend, + θ + ) + + # Check AD gradients match finite differences + abs_error = abs.(grad_test - grad_fd) + rel_error = abs_error ./ (abs.(grad_fd) .+ 1.0e-10) + + @test maximum(abs_error) < 1.0e-4 + @test maximum(rel_error) < 1.0e-2 + end + + @testset "2D grid precision with chordal" begin + # Build 2D grid precision matrix using spdiagm (Zygote-compatible) + function grid_precision(α, grid_size) + n = grid_size^2 + + # Main diagonal + diag_main = fill(4.0 + α, n) + + # Horizontal neighbors (±1 diagonals), but skip row boundaries + horiz = [-1.0 * (mod(i, grid_size) != 0) for i in 1:(n-1)] + + # Vertical neighbors (±grid_size diagonals) + vert = fill(-1.0, n - grid_size) + + return spdiagm(-grid_size => vert, -1 => horiz, 0 => diag_main, 1 => horiz, grid_size => vert) + end + + grid_size = 4 + n = grid_size^2 + z = randn(n) + + Q_ref = grid_precision(0.5, grid_size) + P, S = symbolic(Q_ref) + + function test_grid_pipeline(θ::AbstractVector, z::AbstractVector, P, S, grid_size) + α = θ[1] + μ_const = θ[2] + + Q = grid_precision(α, grid_size) + n = grid_size^2 + μ = μ_const * ones(n) + + J = chordal(Hermitian(Q, :L), P, S) + L = cholesky(J) + + gmrf = ChordalGMRF(μ, Q, L, P) + return logpdf(gmrf, z) + end + + θ = [0.5, 0.1] + + grad_test = DifferentiationInterface.gradient( + θ -> test_grid_pipeline(θ, z, P, S, grid_size), + backend, + θ + ) + + grad_fd = DifferentiationInterface.gradient( + θ -> test_grid_pipeline(θ, z, P, S, grid_size), + fd_backend, + θ + ) + + abs_error = abs.(grad_test - grad_fd) + rel_error = abs_error ./ (abs.(grad_fd) .+ 1.0e-10) + + @test maximum(abs_error) < 1.0e-4 + @test maximum(rel_error) < 1.0e-2 + end + + @testset "2D grid with Poisson gaussian_approximation" begin + function grid_precision(α, grid_size) + n = grid_size^2 + diag_main = fill(4.0 + α, n) + horiz = [-1.0 * (mod(i, grid_size) != 0) for i in 1:(n-1)] + vert = fill(-1.0, n - grid_size) + return spdiagm(-grid_size => vert, -1 => horiz, 0 => diag_main, 1 => horiz, grid_size => vert) + end + + grid_size = 3 + n = grid_size^2 + y = [2, 1, 3, 2, 1, 4, 2, 1, 2] # Poisson count data + x = randn(n) .+ 0.5 + + Q_ref = grid_precision(0.5, grid_size) + P, S = symbolic(Q_ref) + + function test_grid_gauss_approx(θ, y, x, P, S, grid_size) + α, μ_const = θ + Q = grid_precision(α, grid_size) + n = grid_size^2 + μ = μ_const * ones(n) + + J = chordal(Hermitian(Q, :L), P, S) + L = cholesky(J) + prior_gmrf = ChordalGMRF(μ, Q, L, P) + + obs_model = ExponentialFamily(Poisson) + obs_lik = obs_model(PoissonObservations(y)) + posterior_gmrf = gaussian_approximation(prior_gmrf, obs_lik) + return logpdf(posterior_gmrf, x) + end + + θ = [0.5, 0.3] + + grad_test = DifferentiationInterface.gradient( + θ -> test_grid_gauss_approx(θ, y, x, P, S, grid_size), + backend, + θ + ) + + grad_fd = DifferentiationInterface.gradient( + θ -> test_grid_gauss_approx(θ, y, x, P, S, grid_size), + fd_backend, + θ + ) + + abs_error = abs.(grad_test - grad_fd) + rel_error = abs_error ./ (abs.(grad_fd) .+ 1.0e-10) + + @test maximum(abs_error) < 1.0e-2 + @test maximum(rel_error) < 5.0e-2 + end +end diff --git 
a/test/gaussian_approximation/test_gaussian_approximation_chordal.jl b/test/gaussian_approximation/test_gaussian_approximation_chordal.jl new file mode 100644 index 00000000..32cf7a7f --- /dev/null +++ b/test/gaussian_approximation/test_gaussian_approximation_chordal.jl @@ -0,0 +1,194 @@ +using Test +using GaussianMarkovRandomFields +using LinearAlgebra +using SparseArrays +using Distributions + +@testset "Gaussian Approximation - ChordalGMRF" begin + + @testset "Gaussian Likelihood - Analytical Solution" begin + # For Gaussian likelihood, the Gaussian approximation should be exact + n = 5 + Q_prior = spdiagm(0 => 2.0 * ones(n)) + μ_prior = zeros(n) + prior_gmrf = ChordalGMRF(μ_prior, Q_prior) + + obs_model = ExponentialFamily(Normal) + y = [0.1, 0.2, -0.1, 0.3, -0.2] + + obs_lik = obs_model(y; σ = 0.5) + result = gaussian_approximation(prior_gmrf, obs_lik) + + # Should return a ChordalGMRF + @test result isa ChordalGMRF + @test length(mean(result)) == n + + # For Gaussian case, verify against analytical solution + σ = 0.5 + Q_obs = sparse(I, n, n) / σ^2 + Q_analytical = Q_prior + Q_obs + μ_analytical = Q_analytical \ (Q_prior * μ_prior + Q_obs * y) + + @test precision_matrix(result) ≈ Q_analytical atol = 1e-8 + @test mean(result) ≈ μ_analytical atol = 1e-8 + end + + @testset "Bernoulli Likelihood - Mathematical Properties" begin + # Test with Bernoulli observation model (non-linear) + n = 8 + Q_prior = spdiagm(0 => ones(n), 1 => fill(-0.3, n-1), -1 => fill(-0.3, n-1)) + μ_prior = zeros(n) + prior_gmrf = ChordalGMRF(μ_prior, Q_prior) + + obs_model = ExponentialFamily(Bernoulli) + y = [1, 1, 0, 1, 0, 0, 1, 0] + + obs_lik = obs_model(y) + result = gaussian_approximation(prior_gmrf, obs_lik) + + # Should return a ChordalGMRF + @test result isa ChordalGMRF + @test length(mean(result)) == n + + # Mode should reflect the data pattern + μ_result = mean(result) + @test μ_result[1] > 0 # First observation is 1 + @test μ_result[3] < 0 # Third observation is 0 + end + + @testset "Poisson Likelihood - Mathematical Properties" begin + # Test with Poisson observation model + n = 6 + Q_prior = spdiagm(0 => ones(n)) + μ_prior = zeros(n) + prior_gmrf = ChordalGMRF(μ_prior, Q_prior) + + obs_model = ExponentialFamily(Poisson) + y = PoissonObservations([1, 3, 0, 2, 4, 1]) + + obs_lik = obs_model(y) + result = gaussian_approximation(prior_gmrf, obs_lik) + + # Should return a ChordalGMRF + @test result isa ChordalGMRF + @test length(mean(result)) == n + + # Mode should be reasonable for Poisson data + μ_result = mean(result) + @test all(isfinite.(μ_result)) + + # Higher counts should correspond to higher modes + @test μ_result[5] > μ_result[3] # y[5]=4 > y[3]=0 + @test μ_result[2] > μ_result[3] # y[2]=3 > y[3]=0 + end + + @testset "Consistency with GMRF" begin + # Results should match between GMRF and ChordalGMRF + n = 5 + Q_prior = spdiagm(0 => 2.0 * ones(n), 1 => fill(-0.5, n-1), -1 => fill(-0.5, n-1)) + μ_prior = zeros(n) + + gmrf_prior = GMRF(μ_prior, Q_prior) + chordal_prior = ChordalGMRF(μ_prior, Q_prior) + + obs_model = ExponentialFamily(Poisson) + y = PoissonObservations([2, 5, 1, 3, 4]) + obs_lik = obs_model(y) + + result_gmrf = gaussian_approximation(gmrf_prior, obs_lik) + result_chordal = gaussian_approximation(chordal_prior, obs_lik) + + @test mean(result_gmrf) ≈ mean(result_chordal) atol = 1e-6 + @test precision_matrix(result_gmrf) ≈ precision_matrix(result_chordal) atol = 1e-6 + end + + @testset "Sparse precision - tridiagonal" begin + # Test with tridiagonal precision (common in GMRFs) + n = 10 + 
Q_prior = spdiagm(0 => 2.0 * ones(n), 1 => -ones(n-1), -1 => -ones(n-1)) + μ_prior = zeros(n) + prior_gmrf = ChordalGMRF(μ_prior, Q_prior) + + obs_model = ExponentialFamily(Normal) + y = randn(n) + obs_lik = obs_model(y; σ = 0.5) + + result = gaussian_approximation(prior_gmrf, obs_lik) + + @test result isa ChordalGMRF + @test length(mean(result)) == n + @test all(isfinite.(mean(result))) + end + + @testset "Warm-start with x0" begin + n = 6 + Q_prior = spdiagm(0 => ones(n)) + prior_gmrf = ChordalGMRF(zeros(n), Q_prior) + + obs_model = ExponentialFamily(Poisson) + y = PoissonObservations([5, 12, 2, 8, 15, 3]) + obs_lik = obs_model(y) + + # Cold-start result + result_cold = gaussian_approximation(prior_gmrf, obs_lik) + x_star = mean(result_cold) + + # Warm-start from converged mode + result_warm = gaussian_approximation(prior_gmrf, obs_lik; x0 = x_star) + @test mean(result_warm) ≈ x_star atol = 1e-4 + end + + @testset "Adaptive stepsize - extreme Poisson" begin + # Test case where adaptive stepsize is needed + n = 3 + Q_prior = spdiagm(0 => 0.01 * ones(n)) # Weak prior + prior_gmrf = ChordalGMRF(zeros(n), Q_prior) + + obs_model = ExponentialFamily(Poisson) + y = PoissonObservations([100, 500, 50]) + obs_lik = obs_model(y) + + result = gaussian_approximation(prior_gmrf, obs_lik) + + @test result isa ChordalGMRF + μ_result = mean(result) + @test all(isfinite.(μ_result)) + + # Each component should be close to log of its count + for i in 1:n + @test abs(μ_result[i] - log(y.counts[i])) < 1.5 + end + end + + @testset "Non-convergence path" begin + # Force non-convergence with max_iter = 1 + n = 5 + Q_prior = spdiagm(0 => ones(n)) + prior_gmrf = ChordalGMRF(zeros(n), Q_prior) + + obs_model = ExponentialFamily(Poisson) + y = PoissonObservations([10, 15, 8, 12, 20]) + obs_lik = obs_model(y) + + result = gaussian_approximation(prior_gmrf, obs_lik; max_iter = 1) + + # Should still return a ChordalGMRF + @test result isa ChordalGMRF + @test all(isfinite.(mean(result))) + end + + @testset "Verbose output" begin + n = 4 + Q_prior = spdiagm(0 => ones(n)) + prior_gmrf = ChordalGMRF(zeros(n), Q_prior) + + obs_model = ExponentialFamily(Poisson) + y = PoissonObservations([2, 3, 1, 4]) + obs_lik = obs_model(y) + + # Should not error with verbose=true + result = gaussian_approximation(prior_gmrf, obs_lik; verbose = true) + @test result isa ChordalGMRF + end + +end From 9c1aabbe1f038404e3b792f53f5383b77afdc406 Mon Sep 17 00:00:00 2001 From: Richard Samuelson Date: Tue, 31 Mar 2026 14:05:13 -0400 Subject: [PATCH 02/11] Remove dependencies. 
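FiniteDiff, ForwardDiff, MatrixDepot, and Zygote leave the hard [deps] section:
ForwardDiff and Zygote return to [weakdeps], and the Zygote `accum` method for
sparse Hermitian/Symmetric moves from src/piracy.jl into the AutoDiff extension.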
--- Project.toml | 6 ++---- ext/GaussianMarkovRandomFieldsAutoDiff.jl | 19 ++++++++++++++++++- src/piracy.jl | 7 ------- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 3714c1c6..5b095b0b 100644 --- a/Project.toml +++ b/Project.toml @@ -12,8 +12,6 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Ferrite = "c061ca5d-56c9-439f-9c0e-210fe06d3992" FerriteGmsh = "4f95f4f8-b27c-4ae5-9a39-ea55e634e36b" -FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GeoInterface = "cf35fbd7-0cd7-5166-be24-54bfbe79505f" GeometryBasics = "5c1252a2-5f33-56bf-86c9-59e7332b4326" Gmsh = "705231aa-382f-11e9-3f0c-b7cb4346fdeb" @@ -22,7 +20,6 @@ LibGEOS = "a90b1aa1-3769-5649-ba7e-abc5a9d163eb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" -MatrixDepot = "b51810bb-c9f3-55da-ae3c-350fc1fbce05" NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SelectedInversion = "043bf095-3f01-458a-9f1c-8cf4448fe908" @@ -31,10 +28,10 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" Shapefile = "8e980c4a-a4fe-5da2-b3a7-4b4b0353a2f4" @@ -43,6 +40,7 @@ SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] GaussianMarkovRandomFieldsAutoDiff = ["ForwardDiff", "Zygote"] diff --git a/ext/GaussianMarkovRandomFieldsAutoDiff.jl b/ext/GaussianMarkovRandomFieldsAutoDiff.jl index 17de68a8..09f5ace6 100644 --- a/ext/GaussianMarkovRandomFieldsAutoDiff.jl +++ b/ext/GaussianMarkovRandomFieldsAutoDiff.jl @@ -1,9 +1,26 @@ module GaussianMarkovRandomFieldsAutoDiff using GaussianMarkovRandomFields -using ForwardDiff, Zygote, LinearAlgebra, LinearMaps +using ForwardDiff, Zygote, LinearAlgebra, LinearMaps, SparseArrays import LinearMaps: _unsafe_mul! 
+# | | | +# )_) )_) )_) +# )___))___))___) +# )____)____)_____) +# _____|____|____|____ +# ---------\ /--------- +# ^^^^^ ^^^^^^^^^^^^^^^^^^^^^ +# ^^^^ ^^^^ ^^^ ^^ +# ^^^^ ^^^ +# +# Zygote accum for sparse Hermitian/Symmetric (piracy until upstream PR is merged) +const HermOrSymSparse{T, I} = Union{ + Hermitian{T, SparseMatrixCSC{T, I}}, + Symmetric{T, SparseMatrixCSC{T, I}} +} +Zygote.accum(x::HermOrSymSparse, y::HermOrSymSparse) = x + y + function LinearMaps._unsafe_mul!(y, J::ADJacobianMap, x::AbstractVector) g(t) = J.f(J.x₀ + t * x) return y .= ForwardDiff.derivative(g, 0.0) diff --git a/src/piracy.jl b/src/piracy.jl index 9621147c..fa7b0232 100644 --- a/src/piracy.jl +++ b/src/piracy.jl @@ -18,7 +18,6 @@ using LinearAlgebra using LinearAlgebra: Hermitian, Symmetric, Adjoint, Transpose, AdjOrTrans, dot, rmul!, tril, triu using SparseArrays using SparseArrays: SparseMatrixCSC, nzrange, rowvals, getcolptr, nonzeros -using Zygote ##### ##### Type aliases @@ -31,12 +30,6 @@ const HermOrSymSparse{T, I} = Union{HermSparse{T, I}, SymSparse{T, I}} const DenseMat{T} = Union{StridedMatrix{T}, AdjOrTrans{T, <:StridedVecOrMat{T}}} const DenseVecOrMat{T} = Union{DenseMat{T}, StridedVector{T}} -##### -##### Zygote: accum for HermOrSymSparse -##### - -Zygote.accum(x::HermOrSymSparse, y::HermOrSymSparse) = x + y - ##### ##### ChainRulesCore: ProjectTo for HermOrSymSparse ##### From f39711a9029c44b6423dc50f6ac42c4c336d085c Mon Sep 17 00:00:00 2001 From: Richard Samuelson Date: Tue, 31 Mar 2026 14:11:01 -0400 Subject: [PATCH 03/11] Bump CliqueTrees compat. --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5b095b0b..76ebecdb 100644 --- a/Project.toml +++ b/Project.toml @@ -57,7 +57,7 @@ GaussianMarkovRandomFieldsSparseJacobian = ["Symbolics", "SparseDiffTools"] AMD = "0.5" Aqua = "0.8" ChainRulesCore = "1" -CliqueTrees = "1.18.0" +CliqueTrees = "1.19" DataStructures = "0.14 - 0.19" DifferentiationInterface = "0.6, 0.7" Distributions = "0.25" From 8e3c0ed52567121e76bbcff998b2f1d38925ddca Mon Sep 17 00:00:00 2001 From: Richard Samuelson Date: Tue, 31 Mar 2026 14:18:29 -0400 Subject: [PATCH 04/11] Clean Project.toml. --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 76ebecdb..fcacad58 100644 --- a/Project.toml +++ b/Project.toml @@ -75,7 +75,6 @@ LinearAlgebra = "<0.0.1, 1" LinearMaps = "3.11" LinearSolve = "2, 3" Makie = "0.19 - 0.22" -MatrixDepot = "1.0.15" NearestNeighbors = "0.4" Pardiso = "1" Random = "<0.0.1, 1" From 44f1cec69fdeafdb4247c054789a0db636a4551d Mon Sep 17 00:00:00 2001 From: Richard Samuelson Date: Tue, 31 Mar 2026 15:06:03 -0400 Subject: [PATCH 05/11] Fix bug. 
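The `_selinv_impl` methods for the CHOLMOD, Cholesky, and LDLt factorizations now
convert the selected-inversion result with `sparse` rather than calling the
`SparseMatrixCSC` constructor directly.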
--- src/solvers/selinv.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/solvers/selinv.jl b/src/solvers/selinv.jl index 0e56f411..521c12e2 100644 --- a/src/solvers/selinv.jl +++ b/src/solvers/selinv.jl @@ -104,12 +104,12 @@ _selinv_impl(linsolve, alg) = error("Full selected inversion not implemented for function _selinv_impl(linsolve, ::LinearSolve.CHOLMODFactorization) factorization = LinearSolve.@get_cacheval(linsolve, :CHOLMODFactorization) - return Symmetric(SparseMatrixCSC(SelectedInversion.selinv(factorization; depermute = true).Z)) + return Symmetric(sparse(SelectedInversion.selinv(factorization; depermute = true).Z)) end function _selinv_impl(linsolve, ::LinearSolve.CholeskyFactorization) factorization = LinearSolve.@get_cacheval(linsolve, :CholeskyFactorization) - return Symmetric(SparseMatrixCSC(SelectedInversion.selinv(factorization; depermute = true).Z)) + return Symmetric(sparse(SelectedInversion.selinv(factorization; depermute = true).Z)) end function _selinv_impl(linsolve, ::LinearSolve.DiagonalFactorization) @@ -119,7 +119,7 @@ end function _selinv_impl(linsolve, ::LinearSolve.LDLtFactorization) factorization = LinearSolve.@get_cacheval(linsolve, :LDLtFactorization) - return Symmetric(SparseMatrixCSC(SelectedInversion.selinv(factorization; depermute = true).Z)) + return Symmetric(sparse(SelectedInversion.selinv(factorization; depermute = true).Z)) end function _selinv_impl(linsolve, ::LinearSolve.PardisoJL) From be26f50ac50f17842342f9d38b828249c930cd95 Mon Sep 17 00:00:00 2001 From: Richard Samuelson Date: Tue, 31 Mar 2026 15:08:49 -0400 Subject: [PATCH 06/11] Larger test matrices. --- benchmarks/logpdf_comparison.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/logpdf_comparison.jl b/benchmarks/logpdf_comparison.jl index 332ee0f4..75f46e7f 100644 --- a/benchmarks/logpdf_comparison.jl +++ b/benchmarks/logpdf_comparison.jl @@ -37,12 +37,12 @@ end # Handle Symmetric wrapper from MatrixDepot make_posdef(A::Symmetric) = make_posdef(sparse(A)) -# Test matrices from SSMC (larger for ~100ms target) +# Test matrices from SSMC (larger for meaningful benchmarks) test_matrices = [ - ("HB/bcsstk14", "Structural, n=1806"), ("HB/bcsstk15", "Structural, n=3948"), ("HB/bcsstk16", "Structural, n=4884"), ("HB/bcsstk17", "Structural, n=10974"), + ("HB/bcsstk18", "Structural, n=11948"), ] println("\nTest matrices:") From 8f38d118966a400cb65cfbb26d554e94785cf182 Mon Sep 17 00:00:00 2001 From: Tim Weiland Date: Fri, 3 Apr 2026 18:09:01 +0200 Subject: [PATCH 07/11] Unify gaussian_approximation for GMRF and ChordalGMRF Extract solver abstraction helpers (_ga_init_solver, _ga_update_and_solve!, _ga_make_posterior) so the Newton iteration and IFT-based rrule each exist once, dispatching on GMRF type. Removes ~130 lines of duplicated code. 
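For illustration, a minimal sketch of the unified entry point (it mirrors the
"Consistency with GMRF" test and assumes only the exported API used there):

    using GaussianMarkovRandomFields, SparseArrays, Distributions

    n = 5
    Q = spdiagm(0 => 2.0 * ones(n), 1 => fill(-0.5, n - 1), -1 => fill(-0.5, n - 1))
    obs_lik = ExponentialFamily(Poisson)(PoissonObservations([2, 5, 1, 3, 4]))

    # Same call, two solver backends: LinearSolve-backed GMRF vs. chordal Cholesky.
    post_plain   = gaussian_approximation(GMRF(zeros(n), Q), obs_lik)
    post_chordal = gaussian_approximation(ChordalGMRF(zeros(n), Q), obs_lik)

Both calls run the same Fisher-scoring loop; only _ga_init_solver,
_ga_update_and_solve!, and _ga_make_posterior differ by dispatch.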
--- .../condition/gaussian_approximation.jl | 174 +++++------------ src/autodiff/gaussian_approximation.jl | 178 ++++++------------ 2 files changed, 103 insertions(+), 249 deletions(-) diff --git a/src/arithmetic/condition/gaussian_approximation.jl b/src/arithmetic/condition/gaussian_approximation.jl index 1574d335..bcf48e9c 100644 --- a/src/arithmetic/condition/gaussian_approximation.jl +++ b/src/arithmetic/condition/gaussian_approximation.jl @@ -24,6 +24,7 @@ end # Private dispatch: extract constraints if present _extract_constraints(::GMRF) = nothing _extract_constraints(constrained::ConstrainedGMRF) = (A = constrained.constraint_matrix, e = constrained.constraint_vector) +_extract_constraints(::ChordalGMRF) = nothing # Private dispatch: apply constraints to result _apply_constraints(gmrf::GMRF, ::Nothing) = gmrf @@ -32,6 +33,7 @@ _apply_constraints(gmrf::GMRF, constraints::NamedTuple) = ConstrainedGMRF(gmrf, # Private dispatch: extract base GMRF for optimization _base_gmrf(gmrf::GMRF) = gmrf _base_gmrf(constrained::ConstrainedGMRF) = constrained.base_gmrf +_base_gmrf(gmrf::ChordalGMRF) = gmrf # Compute the constrained Newton step via the KKT Schur complement. # Solves H⁻¹Aᵀ using the existing linsolve cache (m sparse solves, m = #constraints), @@ -74,6 +76,34 @@ function _update_linsolve_cache_inner!(cache, Q, alg::LinearSolve.LDLtFactorizat return cache.A = copy(Q) end +# Solver abstraction for gaussian_approximation Newton iteration. +# Allows shared iteration logic for both LinearSolve-backed GMRF and ChordalCholesky-backed ChordalGMRF. +_ga_init_solver(gmrf::GMRF) = deepcopy(linsolve_cache(gmrf)) +_ga_init_solver(gmrf::ChordalGMRF{T}) where {T} = ChordalCholesky{:L, T}(gmrf.P, gmrf.L.S) + +function _ga_update_and_solve!(solver, Q_base, H_k, b, ::GMRF) + Q_new = prepare_for_linsolve(Q_base - H_k, solver.alg) + _update_linsolve_cache!(solver, Q_new) + solver.b = b + return Q_new, copy(solve!(solver).u) +end + +function _ga_update_and_solve!(solver, Q_base, H_k, b, ::ChordalGMRF) + Q_new = hermdiff(Q_base, H_k) + copyto!(solver, Q_new) + cholesky!(solver) + return Q_new, solver \ b +end + +function _ga_make_posterior(x, Q, solver, prior::Union{GMRF, ConstrainedGMRF}, constraints) + new_gmrf = GMRF(x, Q; linsolve_cache = solver) + return _apply_constraints(new_gmrf, constraints) +end + +function _ga_make_posterior(x, Q, solver, prior::ChordalGMRF, ::Nothing) + return ChordalGMRF(x, Q, solver.L, prior.P) +end + """ gaussian_approximation(prior_gmrf, obs_lik; kwargs...) -> AbstractGMRF @@ -82,11 +112,11 @@ Find Gaussian approximation to the posterior using Fisher scoring. This function finds the mode of the posterior distribution and constructs a Gaussian approximation around it using Fisher scoring (Newton-Raphson with Fisher information matrix). -Works for both regular `GMRF` and `ConstrainedGMRF` priors, automatically handling +Works for `GMRF`, `ConstrainedGMRF`, and `ChordalGMRF` priors, automatically handling constraint projection when needed. 
# Arguments -- `prior_gmrf`: Prior GMRF distribution for the latent field (GMRF or ConstrainedGMRF) +- `prior_gmrf`: Prior GMRF distribution for the latent field (GMRF, ConstrainedGMRF, or ChordalGMRF) - `obs_lik`: Materialized observation likelihood (contains data and hyperparameters) # Keyword Arguments @@ -115,7 +145,7 @@ posterior_gmrf = gaussian_approximation(prior_gmrf, obs_lik; adaptive_stepsize=f ``` """ function gaussian_approximation( - prior_gmrf::Union{GMRF, ConstrainedGMRF}, + prior_gmrf::Union{GMRF, ConstrainedGMRF, ChordalGMRF}, obs_lik::ObservationLikelihood; x0::Union{Nothing, AbstractVector} = nothing, max_iter::Int = 50, @@ -125,14 +155,14 @@ function gaussian_approximation( max_linesearch_iter::Int = 10, verbose::Bool = false ) - # Extract base GMRF and constraints (nothing for regular GMRF) + # Extract base GMRF and constraints (nothing for GMRF/ChordalGMRF) base_gmrf = _base_gmrf(prior_gmrf) constraints = _extract_constraints(prior_gmrf) # Initialize with provided starting point or prior mean - x_k = x0 === nothing ? mean(prior_gmrf) : copy(x0) + x_k = x0 === nothing ? copy(mean(prior_gmrf)) : copy(x0) - cache = deepcopy(linsolve_cache(base_gmrf)) + solver = _ga_init_solver(base_gmrf) Q_base = precision_matrix(base_gmrf) # Adaptive stepsize state (persists across outer iterations) @@ -141,21 +171,13 @@ function gaussian_approximation( verbose && println("Starting Fisher scoring...") for iter in 1:max_iter - # Compute observation likelihood derivatives at current point - H_k = loghessian(x_k, obs_lik) # Hessian: ∇²ₓ log p(y|x) - - # Update precision: Q_new = Q_base - H_k (note: H_k contains negative of Hessian) - Q_new = prepare_for_linsolve(Q_base - H_k, cache.alg) - - _update_linsolve_cache!(cache, Q_new) + H_k = loghessian(x_k, obs_lik) neg_score_k = ∇ₓ_neg_log_posterior(base_gmrf, obs_lik, x_k) - cache.b = neg_score_k - step = copy(solve!(cache).u) + Q_new, step = _ga_update_and_solve!(solver, Q_base, H_k, neg_score_k, base_gmrf) # For constrained problems, project step onto constraint tangent space - # via the KKT Schur complement (m sparse solves, m = #constraints). - # This ensures A*step = 0, so x_k - α*step stays on the manifold. - step = _constrain_step(step, cache, constraints) + # via the KKT Schur complement. No-op when constraints are nothing. 
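+        # Concretely: with Newton step s solving H s = g and constraints A x = e,
+        # the projection computes
+        #     s ← s − H⁻¹Aᵀ (A H⁻¹ Aᵀ)⁻¹ A s,
+        # so A·s = 0 and the iterate stays on the constraint manifold.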
+ step = _constrain_step(step, solver, constraints) # Apply step with adaptive line search or full step if adaptive_stepsize @@ -167,19 +189,17 @@ function gaussian_approximation( obj_candidate = neg_log_posterior(base_gmrf, obs_lik, candidate) if obj_candidate <= obj_current - # Accept step, increase α toward 1 - μ_new = candidate + x_new = candidate α = sqrt(α) accept = true verbose && ls_iter > 1 && println(" Accepted at α=$(round(α^2, digits = 3)) after $ls_iter backtracks") break else - # Reject, reduce stepsize α *= 0.1 verbose && println(" Backtrack: α=$(round(α, digits = 4))") if α * norm(step, Inf) < newton_dec_tol / 1000 - μ_new = candidate + x_new = candidate accept = true break end @@ -187,38 +207,32 @@ function gaussian_approximation( end if !accept - μ_new = x_k - α * step + x_new = x_k - α * step end else - μ_new = x_k - step + x_new = x_k - step end - # Newton decrement: dot(g, constrained_step) = -g'Δx_nt = Δx_nt'HΔx_nt ≥ 0 newton_decrement = dot(neg_score_k, step) - - x_new = μ_new mean_change = norm(x_new - x_k) - mean_change_rel = mean_change / norm(x_k) + mean_change_rel = mean_change / max(norm(x_k), 1.0e-10) verbose && println(" Iter $iter: Newton dec = $(round(newton_decrement, sigdigits = 3)), α = $(round(α, digits = 3))") if (newton_decrement < newton_dec_tol) || (mean_change < mean_change_tol) || (mean_change_rel < mean_change_tol) verbose && println(" Converged after $iter iterations") - new_gmrf = GMRF(x_new, Q_new; linsolve_cache = cache) - return _apply_constraints(new_gmrf, constraints) + return _ga_make_posterior(x_new, Q_new, solver, prior_gmrf, constraints) end - # Update for next iteration x_k = x_new end verbose && println(" Reached max_iter = $max_iter without convergence") - # Return current best approximation + # Return current best approximation at final x_k H_k = loghessian(x_k, obs_lik) - Q_final = prepare_for_linsolve(Q_base - H_k, cache.alg) - cache.A = Q_final - final_gmrf = GMRF(x_k, Q_final; linsolve_cache = cache) - return _apply_constraints(final_gmrf, constraints) + neg_score_k = ∇ₓ_neg_log_posterior(base_gmrf, obs_lik, x_k) + Q_final, _ = _ga_update_and_solve!(solver, Q_base, H_k, neg_score_k, base_gmrf) + return _ga_make_posterior(x_k, Q_final, solver, prior_gmrf, constraints) end # Specialized dispatch for Normal observation likelihoods with identity link (conjugate prior case) @@ -316,91 +330,3 @@ function gaussian_approximation(prior_mgmrf::MetaGMRF, obs_lik::ObservationLikel posterior_gmrf = gaussian_approximation(prior_mgmrf.gmrf, obs_lik; kwargs...) return MetaGMRF(posterior_gmrf, prior_mgmrf.metadata) end - -# ChordalGMRF dispatch - uses chordal factorization for efficient solves -function gaussian_approximation( - prior_gmrf::ChordalGMRF{T}, - obs_lik::ObservationLikelihood; - x0::Union{Nothing, AbstractVector} = nothing, - max_iter::Int = 50, - mean_change_tol::Real = 1.0e-4, - newton_dec_tol::Real = 1.0e-5, - adaptive_stepsize::Bool = true, - max_linesearch_iter::Int = 10, - verbose::Bool = false - ) where {T} - S = prior_gmrf.L.S - P = prior_gmrf.P - Q_prior = prior_gmrf.Q - - x_k = isnothing(x0) ? 
copy(mean(prior_gmrf)) : copy(x0) - - α = 1.0 - Q_new = Q_prior - F_new = ChordalCholesky{:L, T}(P, S) - - verbose && println("Starting Fisher scoring (ChordalGMRF)...") - - for iter in 1:max_iter - H_k = loghessian(x_k, obs_lik) - Q_new = hermdiff(Q_prior, H_k) - copyto!(F_new, Q_new) - cholesky!(F_new) - neg_score_k = ∇ₓ_neg_log_posterior(prior_gmrf, obs_lik, x_k) - step = F_new \ neg_score_k - - if adaptive_stepsize - obj_current = neg_log_posterior(prior_gmrf, obs_lik, x_k) - step_accepted = false - - for ls_iter in 1:max_linesearch_iter - candidate = x_k - α * step - obj_candidate = neg_log_posterior(prior_gmrf, obs_lik, candidate) - - if obj_candidate <= obj_current - x_new = candidate - α = sqrt(α) - step_accepted = true - verbose && ls_iter > 1 && println(" Accepted at α=$(round(α^2, digits = 3)) after $ls_iter backtracks") - break - else - α *= 0.1 - verbose && println(" Backtrack: α=$(round(α, digits = 4))") - - if α * norm(step, Inf) < newton_dec_tol / 1000 - x_new = candidate - step_accepted = true - break - end - end - end - - if !step_accepted - x_new = x_k - α * step - end - else - x_new = x_k - step - end - - newton_decrement = dot(neg_score_k, step) - mean_change = norm(x_new - x_k) - mean_change_rel = mean_change / max(norm(x_k), 1e-10) - - verbose && println(" Iter $iter: Newton dec = $(round(newton_decrement, sigdigits = 3)), α = $(round(α, digits = 3))") - - if (newton_decrement < newton_dec_tol) || (mean_change < mean_change_tol) || (mean_change_rel < mean_change_tol) - verbose && println(" Converged after $iter iterations") - return ChordalGMRF(x_new, Q_new, F_new.L, P) - end - - x_k = x_new - end - - verbose && println(" Reached max_iter = $max_iter without convergence") - - H_k = loghessian(x_k, obs_lik) - Q_new = Q_prior - H_k - copyto!(F_new, Hermitian(Q_new, :L)) - cholesky!(F_new) - return ChordalGMRF(x_k, Q_new, F_new.L, P) -end diff --git a/src/autodiff/gaussian_approximation.jl b/src/autodiff/gaussian_approximation.jl index 34eccc32..dfae76a7 100644 --- a/src/autodiff/gaussian_approximation.jl +++ b/src/autodiff/gaussian_approximation.jl @@ -206,6 +206,7 @@ end # Wrap a base GMRF tangent in a ConstrainedGMRF tangent if the prior is constrained. _wrap_prior_tangent(base_tangent, ::GMRF) = base_tangent +_wrap_prior_tangent(base_tangent, ::ChordalGMRF) = base_tangent function _wrap_prior_tangent(base_tangent, prior::ConstrainedGMRF) return Tangent{typeof(prior)}(; base_gmrf = base_tangent, @@ -218,9 +219,38 @@ function _wrap_prior_tangent(base_tangent, prior::ConstrainedGMRF) ) end +# Combine x* tangent contributions from the mean and precision paths, handling zero tangents. +function _combine_x_tangents(x̄_from_μ, x̄_from_Q, x_star) + if _is_zero_tangent(x̄_from_μ) && _is_zero_tangent(x̄_from_Q) + return zeros(eltype(x_star), length(x_star)) + elseif _is_zero_tangent(x̄_from_μ) + return collect(x̄_from_Q) + elseif _is_zero_tangent(x̄_from_Q) + return collect(x̄_from_μ) + else + return collect(x̄_from_μ) .+ collect(x̄_from_Q) + end +end + +# IFT linear solve: solve Q_post * λ = x̄_total, with constraint projection for GMRF/ConstrainedGMRF. 
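+# Derivation sketch: the mode solves F(x*, θ) := ∇ₓ neg_log_posterior(x*, θ) = 0,
+# and ∂F/∂x* is the posterior precision Q_post = Q_prior − loghessian(x*) (in the
+# Fisher-scoring approximation). The IFT then gives ∂x*/∂θ = −Q_post⁻¹ ∂F/∂θ, so the
+# adjoint solve below computes λ = Q_post⁻¹ x̄_total, and the VJP of
+# ∇ₓ_neg_log_posterior is evaluated at −λ.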
+function _ift_solve(posterior::Union{GMRF, ConstrainedGMRF}, x̄_total, prior_gmrf) + cache = linsolve_cache(_base_gmrf(posterior)) + b_saved = copy(cache.b) + cache.b = x̄_total + λ = copy(solve!(cache).u) + λ = _constrain_step(λ, cache, _extract_constraints(prior_gmrf)) + cache.b .= b_saved + return λ +end + +function _ift_solve(posterior::ChordalGMRF, x̄_total, ::ChordalGMRF) + return ldivsym(precision_matrix(posterior), posterior.L, posterior.P, x̄_total) +end + """ ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(gaussian_approximation), - prior_gmrf::Union{GMRF, ConstrainedGMRF}, obs_lik::ObservationLikelihood; kwargs...) + prior_gmrf::Union{GMRF, ConstrainedGMRF, ChordalGMRF}, + obs_lik::ObservationLikelihood; kwargs...) Backend-agnostic automatic differentiation rule for `gaussian_approximation` using the Implicit Function Theorem. @@ -238,7 +268,7 @@ Uses the Implicit Function Theorem on the optimality condition ∇ₓ neg_log_po # Arguments - `config::RuleConfig{>:HasReverseMode}`: AD backend configuration -- `prior_gmrf`: Prior GMRF or ConstrainedGMRF +- `prior_gmrf`: Prior GMRF, ConstrainedGMRF, or ChordalGMRF - `obs_lik`: Observation likelihood - `kwargs...`: Convergence parameters (max_iter, mean_change_tol, etc.) - treated as non-differentiable @@ -249,7 +279,7 @@ Uses the Implicit Function Theorem on the optimality condition ∇ₓ neg_log_po function ChainRulesCore.rrule( config::RuleConfig{>:HasReverseMode}, ::typeof(gaussian_approximation), - prior_gmrf::Union{GMRF, ConstrainedGMRF}, + prior_gmrf::Union{GMRF, ConstrainedGMRF, ChordalGMRF}, obs_lik::ObservationLikelihood; kwargs... ) @@ -259,41 +289,33 @@ function ChainRulesCore.rrule( # === Pullback === function gaussian_approximation_pullback(ȳ) - # Extract tangent components — handle both GMRF and ConstrainedGMRF posteriors + if ȳ isa ZeroTangent + return (NoTangent(), ZeroTangent(), NoTangent()) + end + + # Extract tangent components — dispatches on posterior type μ̄, Q̄ = _extract_posterior_tangents(ȳ, posterior) - # Handle precision tangent through loghessian to get indirect x* dependence - # The Hessian appears in the posterior precision: Q_post = Q_prior - loghessian(x*) - # When differentiating w.r.t. θ: ∂Q_post/∂θ = ∂Q_prior/∂θ - ∂loghessian/∂x* · ∂x*/∂θ - ∂loghessian/∂obs_lik · ∂obs_lik/∂θ - # Compute indirect x* contribution from precision gradient via loghessian + # Q̄ path: backprop through Q_post = Q_prior - loghessian(x*, obs_lik) if _is_zero_tangent(Q̄) - # No precision tangent, only mean path contributes - x_tangent_from_hess = nothing - obs_lik_tangent_from_Q̄ = NoTangent() + x̄_from_Q = ZeroTangent() + obs_lik_tangent_from_Q = NoTangent() else _, hess_pullback = rrule_via_ad(config, loghessian, x_star, obs_lik) - _, x_tangent_from_hess, obs_lik_tangent_from_Q̄ = hess_pullback(-Q̄) # Pass -Q̄ because Q_post = Q_prior - H + _, x̄_from_Q, obs_lik_tangent_from_Q = hess_pullback(-Q̄) end - # Solve for λ combining BOTH mean tangent and indirect x* dependence from precision - # H · λ = μ̄ + x_tangent_from_hess - # This accounts for: ∂L/∂x* from direct (mean) path + indirect (precision→hessian→x*) path - cache = linsolve_cache(_base_gmrf(posterior)) - b_saved = copy(cache.b) - cache.b = _is_zero_tangent(x_tangent_from_hess) ? collect(μ̄) : collect(μ̄) .+ collect(x_tangent_from_hess) - λ = copy(solve!(cache).u) - - # For constrained problems, project λ onto the constraint tangent space via - # the KKT Schur complement (same as forward pass Newton step projection). 
- λ = _constrain_step(λ, cache, _extract_constraints(prior_gmrf)) - cache.b .= b_saved - - # VJP through ∇ₓ_neg_log_posterior at x* to get gradients w.r.t. prior and likelihood. - # Use the BASE GMRF (not ConstrainedGMRF) to match the forward pass, which operates - # on the unconstrained GMRF with constraint-projected steps. + # μ̄ path: μ_post = x*, so μ̄ flows directly to x* + x̄_from_μ = _is_zero_tangent(μ̄) ? ZeroTangent() : μ̄ + + # Combine x* tangents and solve Q_post * λ = x̄_total via IFT + x̄_total = _combine_x_tangents(x̄_from_μ, x̄_from_Q, x_star) + λ = _ift_solve(posterior, x̄_total, prior_gmrf) + + # VJP through ∇ₓ_neg_log_posterior at x* to get gradients w.r.t. prior and likelihood base_prior = _base_gmrf(prior_gmrf) _, ∇_pullback = rrule_via_ad(config, ∇ₓ_neg_log_posterior, base_prior, obs_lik, x_star) - _, base_prior_tangent, obs_lik_tangent, _ = ∇_pullback(-λ) # Note the minus sign from IFT + _, base_prior_tangent, obs_lik_tangent, _ = ∇_pullback(-λ) # Add contribution from ȳ.precision to base prior tangent if !_is_zero_tangent(Q̄) @@ -304,19 +326,16 @@ function ChainRulesCore.rrule( prior_gmrf_tangent = _wrap_prior_tangent(base_prior_tangent, prior_gmrf) # Combine tangents from mean path and precision path - obs_lik_combined = _add_namedtuples(obs_lik_tangent, obs_lik_tangent_from_Q̄) + obs_lik_combined = _add_namedtuples(obs_lik_tangent, obs_lik_tangent_from_Q) - # Return tangents: NoTangent for function and kwargs return (NoTangent(), prior_gmrf_tangent, obs_lik_combined) end return posterior, gaussian_approximation_pullback end -# Also handle case without RuleConfig for simpler usage - # ============================================================================= -# ChordalGMRF rrule +# ChordalGMRF tangent helpers (dispatched from unified rrule above) # ============================================================================= # Extract tangents from ChordalGMRF posterior @@ -338,94 +357,3 @@ function _add_precision_tangent(prior_tangent, prior::ChordalGMRF, Q̄) P = NoTangent(), ) end - -""" - ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(gaussian_approximation), - prior_gmrf::ChordalGMRF, obs_lik::ObservationLikelihood; kwargs...) - -Automatic differentiation rule for `gaussian_approximation` with `ChordalGMRF` using the Implicit Function Theorem. - -Uses the chordal structure from CliqueTrees for efficient operations. The posterior precision -is `Q_post = Q_prior - loghessian(x*, obs_lik)`. - -# Mathematical Approach -Uses IFT on the optimality condition ∇ₓ neg_log_posterior(x*) = 0: -- Forward: Solve via Fisher scoring -- Backward: - 1. Extract (μ̄, Q̄) from posterior tangent - 2. Backprop Q̄ through loghessian to get x* and obs_lik contributions - 3. Combine with μ̄ and solve Q_post λ = total_x̄ using chordal Cholesky - 4. VJP through ∇ₓ_neg_log_posterior with -λ to get prior and obs_lik tangents - 5. Add direct Q̄ contribution to prior -""" -function ChainRulesCore.rrule( - config::RuleConfig{>:HasReverseMode}, - ::typeof(gaussian_approximation), - prior_gmrf::ChordalGMRF, - obs_lik::ObservationLikelihood; - kwargs... - ) - # === Forward pass === - posterior = gaussian_approximation(prior_gmrf, obs_lik; kwargs...) 
- - # Mode in unpermuted coordinates - x_star = mean(posterior) - P = posterior.P - S = posterior.L.S - - # === Pullback === - function chordal_gaussian_approximation_pullback(ȳ) - # Handle zero tangent case - if ȳ isa ZeroTangent - return (NoTangent(), ZeroTangent(), NoTangent()) - end - - # Extract tangent components - μ̄, Q̄ = _extract_posterior_tangents(ȳ, posterior) - - # --- Q̄ path: backprop through Q_post = Q_prior - loghessian(x*, obs_lik) --- - if _is_zero_tangent(Q̄) - x̄_from_Q = ZeroTangent() - obs_lik_tangent_from_Q = NoTangent() - else - # Use rrule_via_ad on loghessian; Q_post = Q_prior - loghessian(...) - _, hess_pullback = rrule_via_ad(config, loghessian, x_star, obs_lik) - # -Q̄ because Q_post = Q_prior - loghessian(...) - _, x̄_from_Q, obs_lik_tangent_from_Q = hess_pullback(-Q̄) - end - - # --- μ̄ path: μ_post = x*, so μ̄ flows directly to x* --- - x̄_from_μ = _is_zero_tangent(μ̄) ? ZeroTangent() : μ̄ - - # --- Combine x* tangents --- - if _is_zero_tangent(x̄_from_μ) && _is_zero_tangent(x̄_from_Q) - x̄_total = zeros(eltype(x_star), length(x_star)) - elseif _is_zero_tangent(x̄_from_μ) - x̄_total = collect(x̄_from_Q) - elseif _is_zero_tangent(x̄_from_Q) - x̄_total = collect(x̄_from_μ) - else - x̄_total = collect(x̄_from_μ) .+ collect(x̄_from_Q) - end - - # --- IFT: solve Q_post * λ = x̄_total using chordal Cholesky --- - λ = ldivsym(precision_matrix(posterior), posterior.L, P, x̄_total) - - # --- VJP through ∇ₓ_neg_log_posterior at x* --- - _, ∇_pullback = rrule_via_ad(config, ∇ₓ_neg_log_posterior, prior_gmrf, obs_lik, x_star) - _, prior_tangent, obs_lik_tangent, _ = ∇_pullback(-λ) # Minus sign from IFT - - # --- Add direct Q̄ contribution to prior (Q_post = Q_prior - ...) --- - if !_is_zero_tangent(Q̄) - prior_tangent = _add_precision_tangent(prior_tangent, prior_gmrf, Q̄) - end - - # --- Combine obs_lik tangents from both paths --- - obs_lik_combined = _add_namedtuples(obs_lik_tangent, obs_lik_tangent_from_Q) - - return (NoTangent(), prior_tangent, obs_lik_combined) - end - - return posterior, chordal_gaussian_approximation_pullback -end - From 3ae505c4f2f095d759376cc1183f29da7a698082 Mon Sep 17 00:00:00 2001 From: Tim Weiland Date: Fri, 3 Apr 2026 18:09:25 +0200 Subject: [PATCH 08/11] Fix ChordalGMRF constructor and rrule dispatch Add 4-arg SparseMatrixCSC constructor that wraps Q in Hermitian, matching the 2-arg constructor's behavior. Fix rrule dispatch from ::typeof(ChordalGMRF) (resolves to ::UnionAll, matching GMRF too) to ::Type{ChordalGMRF}. --- src/chordal_gmrf.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/chordal_gmrf.jl b/src/chordal_gmrf.jl index a5986fc9..602f6777 100644 --- a/src/chordal_gmrf.jl +++ b/src/chordal_gmrf.jl @@ -12,6 +12,10 @@ struct ChordalGMRF{T <: Real, Herm <: HermSparse{T}, Tri <: ChordalTriangular{:N P::Prm end +function ChordalGMRF(μ::AbstractVector, Q::SparseMatrixCSC, L, P) + return ChordalGMRF(μ, Hermitian(Q, :L), L, P) +end + function ChordalGMRF(μ::AbstractVector, Q::SparseMatrixCSC; kw...) H = Hermitian(Q, :L) P, S = symbolic(H; kw...) @@ -67,10 +71,10 @@ function Base.show(io::IO, ::MIME"text/plain", d::ChordalGMRF{T}) where {T} μ = d.μ - if length(μ) <= 6 + return if length(μ) <= 6 print(io, " Mean: $μ") else - print(io, " Mean: [$(μ[1]), $(μ[2]), $(μ[3]), ..., $(μ[end-2]), $(μ[end-1]), $(μ[end])]") + print(io, " Mean: [$(μ[1]), $(μ[2]), $(μ[3]), ..., $(μ[end - 2]), $(μ[end - 1]), $(μ[end])]") end end @@ -78,9 +82,8 @@ end # ChordalGMRF is defined by (μ, Q). 
L and P are derived - gradients never flow through them. using ChainRulesCore: ChainRulesCore, NoTangent -function ChainRulesCore.rrule(::typeof(ChordalGMRF), μ::AbstractVector, Q::SparseMatrixCSC; kw...) +function ChainRulesCore.rrule(::Type{ChordalGMRF}, μ::AbstractVector, Q::SparseMatrixCSC; kw...) result = ChordalGMRF(μ, Q; kw...) ChordalGMRF_pullback(ȳ) = (NoTangent(), ȳ.μ, ȳ.Q) return result, ChordalGMRF_pullback end - From 71647ccee28f94de2273714ed9d60ad9bac521e7 Mon Sep 17 00:00:00 2001 From: Tim Weiland Date: Fri, 3 Apr 2026 18:09:44 +0200 Subject: [PATCH 09/11] Add sum(::SymTridiagonal) rrule to fix invalidation-exposed bug The rrules for HermSparse/SymSparse cause method invalidation that exposes an upstream ChainRulesCore bug: ProjectTo{SymTridiagonal} extracts only one triangle of the off-diagonal, losing the factor of 2 from symmetry. --- src/piracy.jl | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/piracy.jl b/src/piracy.jl index fa7b0232..65bd9db4 100644 --- a/src/piracy.jl +++ b/src/piracy.jl @@ -63,17 +63,17 @@ const SymSparseProjectTo{T, I} = ProjectTo{ }, } -function ChainRulesCore.ProjectTo(x::HermSparse{T}) where {T<:Number} +function ChainRulesCore.ProjectTo(x::HermSparse{T}) where {T <: Number} return ProjectTo{Hermitian}(; - uplo=Symbol(x.uplo), - parent=ProjectTo(parent(x)), + uplo = Symbol(x.uplo), + parent = ProjectTo(parent(x)), ) end -function ChainRulesCore.ProjectTo(x::SymSparse{T}) where {T<:Number} +function ChainRulesCore.ProjectTo(x::SymSparse{T}) where {T <: Number} return ProjectTo{Symmetric}(; - uplo=Symbol(x.uplo), - parent=ProjectTo(parent(x)), + uplo = Symbol(x.uplo), + parent = ProjectTo(parent(x)), ) end @@ -232,7 +232,7 @@ function unwrap(A) if B isa Transpose return (parent(B), Val(:N), Val(:C)) else - return (B, Val(:T), Val(:C)) + return (B, Val(:T), Val(:C)) end elseif A isa Transpose B = parent(A) @@ -240,7 +240,7 @@ function unwrap(A) if B isa Adjoint return (parent(B), Val(:N), Val(:C)) else - return (B, Val(:T), Val(:N)) + return (B, Val(:T), Val(:N)) end else return (A, Val(:N), Val(:N)) @@ -253,7 +253,7 @@ end # # The update is only applied to the structural nonzeros of C. function selupd!(C::HermSparse, A::AbstractVecOrMat, B::AbstractVecOrMat, α, β) - selupd!(parent(C), C.uplo, A, adjoint(B), α, β) + selupd!(parent(C), C.uplo, A, adjoint(B), α, β) selupd!(parent(C), C.uplo, B, adjoint(A), conj(α), 1) return C end @@ -264,7 +264,7 @@ end # # The update is only applied to the structural nonzeros of C. function selupd!(C::SymSparse, A::AbstractVecOrMat, B::AbstractVecOrMat, α, β) - selupd!(parent(C), C.uplo, A, adjoint(B), α, β) + selupd!(parent(C), C.uplo, A, adjoint(B), α, β) selupd!(parent(C), C.uplo, adjoint(transpose(B)), transpose(A), α, 1) return C end @@ -494,3 +494,14 @@ for T in (HermSparse, SymSparse) return dot_rrule(x, A, y) end end + +# The rrules above cause method invalidation that exposes an upstream bug in +# ChainRulesCore's ProjectTo{SymTridiagonal}: it extracts only one triangle of +# the off-diagonal, losing the factor of 2 from symmetry. 
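+# Worked example: for Q = SymTridiagonal(dv, ev) with n = length(dv),
+#     sum(Q) = sum(dv) + 2 * sum(ev),
+# since each off-diagonal entry appears in both triangles. The generic sum rrule
+# produces the dense cotangent fill(s̄, n, n); projecting it while keeping only one
+# triangle yields ev̄ = s̄ instead of the correct ev̄ = 2s̄, hence the explicit rule below.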
+function ChainRulesCore.rrule(::typeof(sum), Q::SymTridiagonal) + function sum_symtridiag_pullback(ȳ) + s = unthunk(ȳ) + return NoTangent(), Tangent{SymTridiagonal}(dv = fill(s, length(Q.dv)), ev = fill(2s, length(Q.ev))) + end + return sum(Q), sum_symtridiag_pullback +end From 8ace9944600413586ca4c097e16ecbe2075575fb Mon Sep 17 00:00:00 2001 From: Tim Weiland Date: Fri, 3 Apr 2026 18:10:05 +0200 Subject: [PATCH 10/11] Runic formatting --- benchmarks/autodiff_comparison.jl | 16 +++- .../gaussian_approximation_comparison.jl | 92 +++++++++++-------- benchmarks/logpdf_comparison.jl | 90 ++++++++++-------- ext/GaussianMarkovRandomFieldsAutoDiff.jl | 2 +- .../test_gaussian_approximation_chordal.jl | 4 +- .../test_gaussian_approximation_chordal.jl | 16 ++-- 6 files changed, 128 insertions(+), 92 deletions(-) diff --git a/benchmarks/autodiff_comparison.jl b/benchmarks/autodiff_comparison.jl index d3cf519d..bfab825e 100644 --- a/benchmarks/autodiff_comparison.jl +++ b/benchmarks/autodiff_comparison.jl @@ -283,13 +283,21 @@ if get(results, "ChordalGMRF+Zygote", nothing) !== nothing && get(results, "Zygo println(@sprintf(" %-20s %12s %12s %12s %12s", "Implementation", "Time (ms)", "Speedup", "Allocs", "Memory (MB)")) println(" " * "─"^76) - println(@sprintf(" %-20s %12.2f %12s %12d %12.2f", - "GMRF", r_gmrf.time, "1.0×", r_gmrf.bench.allocs, r_gmrf.bench.memory / 1.0e6)) + println( + @sprintf( + " %-20s %12.2f %12s %12d %12.2f", + "GMRF", r_gmrf.time, "1.0×", r_gmrf.bench.allocs, r_gmrf.bench.memory / 1.0e6 + ) + ) chordal_speedup = r_gmrf.time / r_chordal.time - println(@sprintf(" %-20s %12.2f %12s %12d %12.2f", + println( + @sprintf( + " %-20s %12.2f %12s %12d %12.2f", "ChordalGMRF", r_chordal.time, @sprintf("%.1f×", chordal_speedup), - r_chordal.bench.allocs, r_chordal.bench.memory / 1.0e6)) + r_chordal.bench.allocs, r_chordal.bench.memory / 1.0e6 + ) + ) println(" " * "─"^76) diff --git a/benchmarks/gaussian_approximation_comparison.jl b/benchmarks/gaussian_approximation_comparison.jl index 6f3a12b3..976163d6 100644 --- a/benchmarks/gaussian_approximation_comparison.jl +++ b/benchmarks/gaussian_approximation_comparison.jl @@ -31,7 +31,7 @@ println("="^80) function make_posdef(A::SparseMatrixCSC) # Symmetrize and add diagonal dominance S = (A + A') / 2 - d = vec(sum(abs, S; dims=2)) + d = vec(sum(abs, S; dims = 2)) return S + spdiagm(0 => d .+ 1.0) end @@ -101,16 +101,16 @@ for (matrix_name, desc) in test_matrices Q_chordal = precision_matrix(posterior_chordal) mean_diff = norm(mean_gmrf - mean_chordal) - mean_rel_diff = mean_diff / (norm(mean_gmrf) + 1e-10) + mean_rel_diff = mean_diff / (norm(mean_gmrf) + 1.0e-10) Q_diff = norm(Q_gmrf - Q_chordal) - Q_rel_diff = Q_diff / (norm(Q_gmrf) + 1e-10) + Q_rel_diff = Q_diff / (norm(Q_gmrf) + 1.0e-10) println(" Mean abs diff: $(@sprintf("%.2e", mean_diff))") println(" Mean rel diff: $(@sprintf("%.2e", mean_rel_diff))") println(" Precision abs diff: $(@sprintf("%.2e", Q_diff))") println(" Precision rel diff: $(@sprintf("%.2e", Q_rel_diff))") - correct = mean_rel_diff < 1e-6 && Q_rel_diff < 1e-6 + correct = mean_rel_diff < 1.0e-6 && Q_rel_diff < 1.0e-6 println(" Match: $(correct ? "✓ YES" : "✗ NO")") # Performance benchmark @@ -118,14 +118,14 @@ for (matrix_name, desc) in test_matrices # Benchmark GMRF print(" GMRF... 
") - bench_gmrf = @benchmark gaussian_approximation($gmrf_prior, $obs_lik) samples=10 seconds=10 - time_gmrf = minimum(bench_gmrf.times) / 1e6 + bench_gmrf = @benchmark gaussian_approximation($gmrf_prior, $obs_lik) samples = 10 seconds = 10 + time_gmrf = minimum(bench_gmrf.times) / 1.0e6 println("$(@sprintf("%.3f", time_gmrf)) ms") # Benchmark ChordalGMRF print(" ChordalGMRF... ") - bench_chordal = @benchmark gaussian_approximation($chordal_prior, $obs_lik) samples=10 seconds=10 - time_chordal = minimum(bench_chordal.times) / 1e6 + bench_chordal = @benchmark gaussian_approximation($chordal_prior, $obs_lik) samples = 10 seconds = 10 + time_chordal = minimum(bench_chordal.times) / 1.0e6 println("$(@sprintf("%.3f", time_chordal)) ms") speedup = time_gmrf / time_chordal @@ -149,49 +149,55 @@ for (matrix_name, desc) in test_matrices grad_gmrf = Zygote.gradient(loss_gmrf, μ)[1] grad_chordal = Zygote.gradient(loss_chordal, μ)[1] grad_abs_diff = norm(grad_gmrf - grad_chordal) - grad_rel_diff = grad_abs_diff / (norm(grad_gmrf) + 1e-10) + grad_rel_diff = grad_abs_diff / (norm(grad_gmrf) + 1.0e-10) println(" Absolute diff: $(@sprintf("%.2e", grad_abs_diff))") println(" Relative diff: $(@sprintf("%.2e", grad_rel_diff))") - grad_correct = grad_rel_diff < 1e-6 + grad_correct = grad_rel_diff < 1.0e-6 println(" Match: $(grad_correct ? "✓ YES" : "✗ NO")") # Gradient performance benchmark println("\n Gradient performance benchmark:") print(" GMRF... ") - bench_grad_gmrf = @benchmark Zygote.gradient($loss_gmrf, $μ) samples=10 seconds=10 - time_grad_gmrf = minimum(bench_grad_gmrf.times) / 1e6 + bench_grad_gmrf = @benchmark Zygote.gradient($loss_gmrf, $μ) samples = 10 seconds = 10 + time_grad_gmrf = minimum(bench_grad_gmrf.times) / 1.0e6 println("$(@sprintf("%.3f", time_grad_gmrf)) ms") print(" ChordalGMRF... 
") - bench_grad_chordal = @benchmark Zygote.gradient($loss_chordal, $μ) samples=10 seconds=10 - time_grad_chordal = minimum(bench_grad_chordal.times) / 1e6 + bench_grad_chordal = @benchmark Zygote.gradient($loss_chordal, $μ) samples = 10 seconds = 10 + time_grad_chordal = minimum(bench_grad_chordal.times) / 1.0e6 println("$(@sprintf("%.3f", time_grad_chordal)) ms") grad_speedup = time_grad_gmrf / time_grad_chordal println(" Speedup: $(@sprintf("%.2f", grad_speedup))×") - push!(results, ( - name=matrix_name, - n=n, - nnz=nnz(Q), - correct=correct, - grad_correct=grad_correct, - time_gmrf=time_gmrf, - time_chordal=time_chordal, - speedup=speedup, - time_grad_gmrf=time_grad_gmrf, - time_grad_chordal=time_grad_chordal, - grad_speedup=grad_speedup, - )) + push!( + results, ( + name = matrix_name, + n = n, + nnz = nnz(Q), + correct = correct, + grad_correct = grad_correct, + time_gmrf = time_gmrf, + time_chordal = time_chordal, + speedup = speedup, + time_grad_gmrf = time_grad_gmrf, + time_grad_chordal = time_grad_chordal, + grad_speedup = grad_speedup, + ) + ) catch e - println(" ✗ Failed: $(typeof(e).name.name): $(sprint(showerror, e; context=:limit=>true))") - push!(results, (name=matrix_name, n=0, nnz=0, correct=false, grad_correct=false, - time_gmrf=NaN, time_chordal=NaN, speedup=NaN, - time_grad_gmrf=NaN, time_grad_chordal=NaN, grad_speedup=NaN)) + println(" ✗ Failed: $(typeof(e).name.name): $(sprint(showerror, e; context = :limit => true))") + push!( + results, ( + name = matrix_name, n = 0, nnz = 0, correct = false, grad_correct = false, + time_gmrf = NaN, time_chordal = NaN, speedup = NaN, + time_grad_gmrf = NaN, time_grad_chordal = NaN, grad_speedup = NaN, + ) + ) end end @@ -201,14 +207,18 @@ println("SUMMARY: FORWARD PASS") println("="^80) println("\n" * "-"^95) -@printf("%-20s %8s %10s %8s %12s %12s %10s\n", - "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup") +@printf( + "%-20s %8s %10s %8s %12s %12s %10s\n", + "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup" +) println("-"^95) for r in results correct_str = r.correct ? "✓" : "✗" - @printf("%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", - r.name, r.n, r.nnz, correct_str, r.time_gmrf, r.time_chordal, r.speedup) + @printf( + "%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", + r.name, r.n, r.nnz, correct_str, r.time_gmrf, r.time_chordal, r.speedup + ) end println("-"^95) @@ -218,14 +228,18 @@ println("SUMMARY: GRADIENT (Zygote)") println("="^80) println("\n" * "-"^95) -@printf("%-20s %8s %10s %8s %12s %12s %10s\n", - "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup") +@printf( + "%-20s %8s %10s %8s %12s %12s %10s\n", + "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup" +) println("-"^95) for r in results correct_str = r.grad_correct ? 
"✓" : "✗" - @printf("%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", - r.name, r.n, r.nnz, correct_str, r.time_grad_gmrf, r.time_grad_chordal, r.grad_speedup) + @printf( + "%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", + r.name, r.n, r.nnz, correct_str, r.time_grad_gmrf, r.time_grad_chordal, r.grad_speedup + ) end println("-"^95) diff --git a/benchmarks/logpdf_comparison.jl b/benchmarks/logpdf_comparison.jl index 75f46e7f..499198f3 100644 --- a/benchmarks/logpdf_comparison.jl +++ b/benchmarks/logpdf_comparison.jl @@ -30,7 +30,7 @@ println("="^80) function make_posdef(A::SparseMatrixCSC) # Symmetrize and add diagonal dominance S = (A + A') / 2 - d = vec(sum(abs, S; dims=2)) + d = vec(sum(abs, S; dims = 2)) return S + spdiagm(0 => d .+ 1.0) end @@ -84,14 +84,14 @@ for (matrix_name, desc) in test_matrices lpdf_gmrf = logpdf(gmrf, z) lpdf_chordal = logpdf(chordal_gmrf, z) abs_diff = abs(lpdf_gmrf - lpdf_chordal) - rel_diff = abs_diff / (abs(lpdf_gmrf) + 1e-10) + rel_diff = abs_diff / (abs(lpdf_gmrf) + 1.0e-10) println(" GMRF logpdf: $(@sprintf("%.8f", lpdf_gmrf))") println(" ChordalGMRF logpdf: $(@sprintf("%.8f", lpdf_chordal))") println(" Absolute diff: $(@sprintf("%.2e", abs_diff))") println(" Relative diff: $(@sprintf("%.2e", rel_diff))") - correct = rel_diff < 1e-8 + correct = rel_diff < 1.0e-8 println(" Match: $(correct ? "✓ YES" : "✗ NO")") # Performance benchmark @@ -99,14 +99,14 @@ for (matrix_name, desc) in test_matrices # Benchmark GMRF print(" GMRF... ") - bench_gmrf = @benchmark logpdf($gmrf, $z) samples=20 seconds=5 - time_gmrf = minimum(bench_gmrf.times) / 1e6 + bench_gmrf = @benchmark logpdf($gmrf, $z) samples = 20 seconds = 5 + time_gmrf = minimum(bench_gmrf.times) / 1.0e6 println("$(@sprintf("%.3f", time_gmrf)) ms") # Benchmark ChordalGMRF print(" ChordalGMRF... ") - bench_chordal = @benchmark logpdf($chordal_gmrf, $z) samples=20 seconds=5 - time_chordal = minimum(bench_chordal.times) / 1e6 + bench_chordal = @benchmark logpdf($chordal_gmrf, $z) samples = 20 seconds = 5 + time_chordal = minimum(bench_chordal.times) / 1.0e6 println("$(@sprintf("%.3f", time_chordal)) ms") speedup = time_gmrf / time_chordal @@ -117,49 +117,55 @@ for (matrix_name, desc) in test_matrices grad_gmrf = Zygote.gradient(x -> logpdf(gmrf, x), z)[1] grad_chordal = Zygote.gradient(x -> logpdf(chordal_gmrf, x), z)[1] grad_abs_diff = norm(grad_gmrf - grad_chordal) - grad_rel_diff = grad_abs_diff / (norm(grad_gmrf) + 1e-10) + grad_rel_diff = grad_abs_diff / (norm(grad_gmrf) + 1.0e-10) println(" Absolute diff: $(@sprintf("%.2e", grad_abs_diff))") println(" Relative diff: $(@sprintf("%.2e", grad_rel_diff))") - grad_correct = grad_rel_diff < 1e-8 + grad_correct = grad_rel_diff < 1.0e-8 println(" Match: $(grad_correct ? "✓ YES" : "✗ NO")") # Gradient performance benchmark println("\n Gradient performance benchmark:") print(" GMRF... ") - bench_grad_gmrf = @benchmark Zygote.gradient(x -> logpdf($gmrf, x), $z) samples=20 seconds=5 - time_grad_gmrf = minimum(bench_grad_gmrf.times) / 1e6 + bench_grad_gmrf = @benchmark Zygote.gradient(x -> logpdf($gmrf, x), $z) samples = 20 seconds = 5 + time_grad_gmrf = minimum(bench_grad_gmrf.times) / 1.0e6 println("$(@sprintf("%.3f", time_grad_gmrf)) ms") print(" ChordalGMRF... 
") - bench_grad_chordal = @benchmark Zygote.gradient(x -> logpdf($chordal_gmrf, x), $z) samples=20 seconds=5 - time_grad_chordal = minimum(bench_grad_chordal.times) / 1e6 + bench_grad_chordal = @benchmark Zygote.gradient(x -> logpdf($chordal_gmrf, x), $z) samples = 20 seconds = 5 + time_grad_chordal = minimum(bench_grad_chordal.times) / 1.0e6 println("$(@sprintf("%.3f", time_grad_chordal)) ms") grad_speedup = time_grad_gmrf / time_grad_chordal println(" Speedup: $(@sprintf("%.2f", grad_speedup))×") - push!(results, ( - name=matrix_name, - n=n, - nnz=nnz(Q), - correct=correct, - grad_correct=grad_correct, - time_gmrf=time_gmrf, - time_chordal=time_chordal, - speedup=speedup, - time_grad_gmrf=time_grad_gmrf, - time_grad_chordal=time_grad_chordal, - grad_speedup=grad_speedup, - )) + push!( + results, ( + name = matrix_name, + n = n, + nnz = nnz(Q), + correct = correct, + grad_correct = grad_correct, + time_gmrf = time_gmrf, + time_chordal = time_chordal, + speedup = speedup, + time_grad_gmrf = time_grad_gmrf, + time_grad_chordal = time_grad_chordal, + grad_speedup = grad_speedup, + ) + ) catch e - println(" ✗ Failed: $(typeof(e).name.name): $(sprint(showerror, e; context=:limit=>true))") - push!(results, (name=matrix_name, n=0, nnz=0, correct=false, grad_correct=false, - time_gmrf=NaN, time_chordal=NaN, speedup=NaN, - time_grad_gmrf=NaN, time_grad_chordal=NaN, grad_speedup=NaN)) + println(" ✗ Failed: $(typeof(e).name.name): $(sprint(showerror, e; context = :limit => true))") + push!( + results, ( + name = matrix_name, n = 0, nnz = 0, correct = false, grad_correct = false, + time_gmrf = NaN, time_chordal = NaN, speedup = NaN, + time_grad_gmrf = NaN, time_grad_chordal = NaN, grad_speedup = NaN, + ) + ) end end @@ -169,14 +175,18 @@ println("SUMMARY: FORWARD PASS") println("="^80) println("\n" * "-"^95) -@printf("%-20s %8s %10s %8s %12s %12s %10s\n", - "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup") +@printf( + "%-20s %8s %10s %8s %12s %12s %10s\n", + "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup" +) println("-"^95) for r in results correct_str = r.correct ? "✓" : "✗" - @printf("%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", - r.name, r.n, r.nnz, correct_str, r.time_gmrf, r.time_chordal, r.speedup) + @printf( + "%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", + r.name, r.n, r.nnz, correct_str, r.time_gmrf, r.time_chordal, r.speedup + ) end println("-"^95) @@ -186,14 +196,18 @@ println("SUMMARY: GRADIENT (Zygote)") println("="^80) println("\n" * "-"^95) -@printf("%-20s %8s %10s %8s %12s %12s %10s\n", - "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup") +@printf( + "%-20s %8s %10s %8s %12s %12s %10s\n", + "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup" +) println("-"^95) for r in results correct_str = r.grad_correct ? "✓" : "✗" - @printf("%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", - r.name, r.n, r.nnz, correct_str, r.time_grad_gmrf, r.time_grad_chordal, r.grad_speedup) + @printf( + "%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", + r.name, r.n, r.nnz, correct_str, r.time_grad_gmrf, r.time_grad_chordal, r.grad_speedup + ) end println("-"^95) diff --git a/ext/GaussianMarkovRandomFieldsAutoDiff.jl b/ext/GaussianMarkovRandomFieldsAutoDiff.jl index 09f5ace6..2a747a12 100644 --- a/ext/GaussianMarkovRandomFieldsAutoDiff.jl +++ b/ext/GaussianMarkovRandomFieldsAutoDiff.jl @@ -17,7 +17,7 @@ import LinearMaps: _unsafe_mul! 
# Zygote accum for sparse Hermitian/Symmetric (piracy until upstream PR is merged) const HermOrSymSparse{T, I} = Union{ Hermitian{T, SparseMatrixCSC{T, I}}, - Symmetric{T, SparseMatrixCSC{T, I}} + Symmetric{T, SparseMatrixCSC{T, I}}, } Zygote.accum(x::HermOrSymSparse, y::HermOrSymSparse) = x + y diff --git a/test/autodiff/test_gaussian_approximation_chordal.jl b/test/autodiff/test_gaussian_approximation_chordal.jl index 4b018d32..ffbc74c3 100644 --- a/test/autodiff/test_gaussian_approximation_chordal.jl +++ b/test/autodiff/test_gaussian_approximation_chordal.jl @@ -249,7 +249,7 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] diag_main = fill(4.0 + α, n) # Horizontal neighbors (±1 diagonals), but skip row boundaries - horiz = [-1.0 * (mod(i, grid_size) != 0) for i in 1:(n-1)] + horiz = [-1.0 * (mod(i, grid_size) != 0) for i in 1:(n - 1)] # Vertical neighbors (±grid_size diagonals) vert = fill(-1.0, n - grid_size) @@ -304,7 +304,7 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] function grid_precision(α, grid_size) n = grid_size^2 diag_main = fill(4.0 + α, n) - horiz = [-1.0 * (mod(i, grid_size) != 0) for i in 1:(n-1)] + horiz = [-1.0 * (mod(i, grid_size) != 0) for i in 1:(n - 1)] vert = fill(-1.0, n - grid_size) return spdiagm(-grid_size => vert, -1 => horiz, 0 => diag_main, 1 => horiz, grid_size => vert) end diff --git a/test/gaussian_approximation/test_gaussian_approximation_chordal.jl b/test/gaussian_approximation/test_gaussian_approximation_chordal.jl index 32cf7a7f..658fadd9 100644 --- a/test/gaussian_approximation/test_gaussian_approximation_chordal.jl +++ b/test/gaussian_approximation/test_gaussian_approximation_chordal.jl @@ -29,14 +29,14 @@ using Distributions Q_analytical = Q_prior + Q_obs μ_analytical = Q_analytical \ (Q_prior * μ_prior + Q_obs * y) - @test precision_matrix(result) ≈ Q_analytical atol = 1e-8 - @test mean(result) ≈ μ_analytical atol = 1e-8 + @test precision_matrix(result) ≈ Q_analytical atol = 1.0e-8 + @test mean(result) ≈ μ_analytical atol = 1.0e-8 end @testset "Bernoulli Likelihood - Mathematical Properties" begin # Test with Bernoulli observation model (non-linear) n = 8 - Q_prior = spdiagm(0 => ones(n), 1 => fill(-0.3, n-1), -1 => fill(-0.3, n-1)) + Q_prior = spdiagm(0 => ones(n), 1 => fill(-0.3, n - 1), -1 => fill(-0.3, n - 1)) μ_prior = zeros(n) prior_gmrf = ChordalGMRF(μ_prior, Q_prior) @@ -85,7 +85,7 @@ using Distributions @testset "Consistency with GMRF" begin # Results should match between GMRF and ChordalGMRF n = 5 - Q_prior = spdiagm(0 => 2.0 * ones(n), 1 => fill(-0.5, n-1), -1 => fill(-0.5, n-1)) + Q_prior = spdiagm(0 => 2.0 * ones(n), 1 => fill(-0.5, n - 1), -1 => fill(-0.5, n - 1)) μ_prior = zeros(n) gmrf_prior = GMRF(μ_prior, Q_prior) @@ -98,14 +98,14 @@ using Distributions result_gmrf = gaussian_approximation(gmrf_prior, obs_lik) result_chordal = gaussian_approximation(chordal_prior, obs_lik) - @test mean(result_gmrf) ≈ mean(result_chordal) atol = 1e-6 - @test precision_matrix(result_gmrf) ≈ precision_matrix(result_chordal) atol = 1e-6 + @test mean(result_gmrf) ≈ mean(result_chordal) atol = 1.0e-6 + @test precision_matrix(result_gmrf) ≈ precision_matrix(result_chordal) atol = 1.0e-6 end @testset "Sparse precision - tridiagonal" begin # Test with tridiagonal precision (common in GMRFs) n = 10 - Q_prior = spdiagm(0 => 2.0 * ones(n), 1 => -ones(n-1), -1 => -ones(n-1)) + Q_prior = spdiagm(0 => 2.0 * ones(n), 1 => -ones(n - 1), -1 => -ones(n - 1)) μ_prior = zeros(n) prior_gmrf = 
ChordalGMRF(μ_prior, Q_prior) @@ -135,7 +135,7 @@ using Distributions # Warm-start from converged mode result_warm = gaussian_approximation(prior_gmrf, obs_lik; x0 = x_star) - @test mean(result_warm) ≈ x_star atol = 1e-4 + @test mean(result_warm) ≈ x_star atol = 1.0e-4 end @testset "Adaptive stepsize - extreme Poisson" begin From 8e3150351bc75d31a76b9917c15d00d1a3510cd5 Mon Sep 17 00:00:00 2001 From: Richard Samuelson Date: Mon, 6 Apr 2026 23:31:42 -0400 Subject: [PATCH 11/11] Zygote -> Mooncake --- Project.toml | 6 +- benchmarks/autodiff_comparison.jl | 88 +-- .../gaussian_approximation_comparison.jl | 24 +- benchmarks/logpdf_comparison.jl | 24 +- deps/MooncakeSparse | 1 + ext/GaussianMarkovRandomFieldsAutoDiff.jl | 19 +- src/GaussianMarkovRandomFields.jl | 2 +- .../condition/gaussian_approximation.jl | 5 +- src/autodiff/autodiff.jl | 3 + src/autodiff/constructors.jl | 34 ++ src/autodiff/gaussian_approximation.jl | 6 +- src/autodiff/logpdf.jl | 1 - .../mooncake_gaussian_approximation.jl | 117 ++++ src/chordal_gmrf.jl | 40 +- src/piracy.jl | 507 ------------------ .../test_gaussian_approximation_chordal.jl | 102 +--- test/gaussian_approximation/runtests.jl | 1 + 17 files changed, 284 insertions(+), 696 deletions(-) create mode 160000 deps/MooncakeSparse create mode 100644 src/autodiff/mooncake_gaussian_approximation.jl delete mode 100644 src/piracy.jl diff --git a/Project.toml b/Project.toml index fcacad58..05ea47b5 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ LibGEOS = "a90b1aa1-3769-5649-ba7e-abc5a9d163eb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SelectedInversion = "043bf095-3f01-458a-9f1c-8cf4448fe908" @@ -30,6 +31,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4" [weakdeps] +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" @@ -40,7 +42,6 @@ SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] GaussianMarkovRandomFieldsAutoDiff = ["ForwardDiff", "Zygote"] @@ -57,7 +58,7 @@ GaussianMarkovRandomFieldsSparseJacobian = ["Symbolics", "SparseDiffTools"] AMD = "0.5" Aqua = "0.8" ChainRulesCore = "1" -CliqueTrees = "1.19" +CliqueTrees = "1.19.1" DataStructures = "0.14 - 0.19" DifferentiationInterface = "0.6, 0.7" Distributions = "0.25" @@ -75,6 +76,7 @@ LinearAlgebra = "<0.0.1, 1" LinearMaps = "3.11" LinearSolve = "2, 3" Makie = "0.19 - 0.22" +Mooncake = "0.5.25" NearestNeighbors = "0.4" Pardiso = "1" Random = "<0.0.1, 1" diff --git a/benchmarks/autodiff_comparison.jl b/benchmarks/autodiff_comparison.jl index bfab825e..34ce81e2 100644 --- a/benchmarks/autodiff_comparison.jl +++ b/benchmarks/autodiff_comparison.jl @@ -20,7 +20,7 @@ using LinearSolve using Printf using Random -using Zygote, Enzyme, FiniteDiff +using Zygote, Enzyme, FiniteDiff, Mooncake using CliqueTrees.Multifrontal: symbolic, chordal @@ -63,7 +63,7 @@ function benchmark_workflow(θ::Vector{Float64}, 
y::PoissonObservations, x_eval: return logpdf(posterior, x_eval) end -# ChordalGMRF workflow (only supports Zygote) +# ChordalGMRF workflow (only supports Mooncake) function benchmark_workflow_chordal(θ::Vector{Float64}, y::PoissonObservations, x_eval::Vector{Float64}) # Extract hyperparameters μ = θ[1:n] # Mean field (100 params) @@ -127,7 +127,7 @@ backends = [ ] println("\n" * "="^80) -println("BENCHMARKING GRADIENT COMPUTATION") +println("BENCHMARKING GRADIENT COMPUTATION (via DifferentiationInterface, prepared)") println("="^80) results = Dict() @@ -137,22 +137,22 @@ for (name, backend) in backends println("-"^40) try - # Warmup - print(" Warming up... ") - grad = DifferentiationInterface.gradient( - θ -> benchmark_workflow(θ, y_obs, x_eval), - backend, - θ_init - ) + # Define loss function + loss = θ -> benchmark_workflow(θ, y_obs, x_eval) + + # Prepare gradient (includes warmup/compilation) + print(" Preparing... ") + prep = DifferentiationInterface.prepare_gradient(loss, backend, θ_init) println("✓") - # Benchmark + # Compute gradient once + print(" Computing gradient... ") + grad = DifferentiationInterface.gradient(loss, prep, backend, θ_init) + println("✓") + + # Benchmark with prepared gradient print(" Benchmarking... ") - bench = @benchmark DifferentiationInterface.gradient( - θ -> benchmark_workflow(θ, y_obs, x_eval), - $backend, - $θ_init - ) samples = 10 seconds = 30 + bench = @benchmark DifferentiationInterface.gradient($loss, $prep, $backend, $θ_init) samples = 10 seconds = 30 results[name] = ( gradient = grad, @@ -167,52 +167,58 @@ for (name, backend) in backends println(" Memory: $(@sprintf("%.2f", bench.memory / 1.0e6)) MB") catch e - println(" ✗ Failed: $e") + println(" ✗ Failed: $(typeof(e).name.name)") + if e isa ErrorException + println(" $(first(split(e.msg, '\n')))") + end results[name] = nothing end end -# ChordalGMRF benchmark (Zygote only) +# ChordalGMRF benchmark (Mooncake only) println("\n" * "="^80) -println("BENCHMARKING ChordalGMRF (Zygote only)") +println("BENCHMARKING ChordalGMRF (Mooncake only)") println("="^80) -println("\nChordalGMRF + Zygote:") +println("\nChordalGMRF + Mooncake:") println("-"^40) try - # Warmup - print(" Warming up... ") - grad_chordal = DifferentiationInterface.gradient( - θ -> benchmark_workflow_chordal(θ, y_obs, x_eval), - AutoZygote(), - θ_init - ) + # Define loss function + loss_chordal = θ -> benchmark_workflow_chordal(θ, y_obs, x_eval) + + # Prepare gradient (includes warmup/compilation) + print(" Preparing... ") + prep_chordal = DifferentiationInterface.prepare_gradient(loss_chordal, AutoMooncake(; config=nothing), θ_init) println("✓") - # Benchmark + # Compute gradient once + print(" Computing gradient... ") + grad_chordal = DifferentiationInterface.gradient(loss_chordal, prep_chordal, AutoMooncake(; config=nothing), θ_init) + println("✓") + + # Benchmark with prepared gradient print(" Benchmarking... 
") - bench_chordal = @benchmark DifferentiationInterface.gradient( - θ -> benchmark_workflow_chordal(θ, y_obs, x_eval), - AutoZygote(), - $θ_init - ) samples = 10 seconds = 30 + bench_chordal = @benchmark DifferentiationInterface.gradient($loss_chordal, $prep_chordal, AutoMooncake(; config=nothing), $θ_init) samples = 10 seconds = 30 - results["ChordalGMRF+Zygote"] = ( + results["ChordalGMRF+Mooncake"] = ( gradient = grad_chordal, time = minimum(bench_chordal.times) / 1.0e6, bench = bench_chordal, ) println("✓") - println(" Time (min): $(@sprintf("%.2f", results["ChordalGMRF+Zygote"].time)) ms") + println(" Time (min): $(@sprintf("%.2f", results["ChordalGMRF+Mooncake"].time)) ms") println(" Time (median): $(@sprintf("%.2f", median(bench_chordal.times) / 1.0e6)) ms") println(" Allocations: $(bench_chordal.allocs)") println(" Memory: $(@sprintf("%.2f", bench_chordal.memory / 1.0e6)) MB") catch e - println(" ✗ Failed: $e") - results["ChordalGMRF+Zygote"] = nothing + println(" ✗ Failed: $(typeof(e).name.name)") + if e isa ErrorException + println(" $(first(split(e.msg, '\n')))") + end + results["ChordalGMRF+Mooncake"] = nothing end # Summary comparison @@ -225,7 +231,7 @@ if results["FiniteDiff"] !== nothing println("\nGradient verification (comparing to FiniteDiff):") fd_grad = results["FiniteDiff"].gradient - for name in ["Zygote", "Enzyme", "ChordalGMRF+Zygote"] + for name in ["Zygote", "Enzyme", "ChordalGMRF+Mooncake"] if get(results, name, nothing) !== nothing grad = results[name].gradient abs_error = abs.(grad - fd_grad) @@ -271,13 +277,13 @@ if results["FiniteDiff"] !== nothing end # ChordalGMRF vs GMRF comparison (Zygote only) -if get(results, "ChordalGMRF+Zygote", nothing) !== nothing && get(results, "Zygote", nothing) !== nothing +if get(results, "ChordalGMRF+Mooncake", nothing) !== nothing && get(results, "Zygote", nothing) !== nothing println("\n" * "="^80) - println("GMRF vs ChordalGMRF COMPARISON (Zygote)") + println("GMRF vs ChordalGMRF COMPARISON (Zygote vs Mooncake)") println("="^80) r_gmrf = results["Zygote"] - r_chordal = results["ChordalGMRF+Zygote"] + r_chordal = results["ChordalGMRF+Mooncake"] println("\n " * "─"^76) println(@sprintf(" %-20s %12s %12s %12s %12s", "Implementation", "Time (ms)", "Speedup", "Allocs", "Memory (MB)")) diff --git a/benchmarks/gaussian_approximation_comparison.jl b/benchmarks/gaussian_approximation_comparison.jl index 976163d6..26e015af 100644 --- a/benchmarks/gaussian_approximation_comparison.jl +++ b/benchmarks/gaussian_approximation_comparison.jl @@ -19,9 +19,8 @@ using LinearSolve using Printf using Random using MatrixDepot -using Zygote - -using CliqueTrees.Multifrontal: symbolic, chordal +using Zygote, Mooncake +using DifferentiationInterface: DifferentiationInterface, AutoZygote, AutoMooncake println("="^80) println("GAUSSIAN APPROXIMATION COMPARISON: GMRF vs ChordalGMRF") @@ -146,8 +145,11 @@ for (matrix_name, desc) in test_matrices return sum(mean(post)) end - grad_gmrf = Zygote.gradient(loss_gmrf, μ)[1] - grad_chordal = Zygote.gradient(loss_chordal, μ)[1] + # Use prepared gradients for both backends + prep_gmrf = DifferentiationInterface.prepare_gradient(loss_gmrf, AutoZygote(), μ) + grad_gmrf = DifferentiationInterface.gradient(loss_gmrf, prep_gmrf, AutoZygote(), μ) + prep_chordal = DifferentiationInterface.prepare_gradient(loss_chordal, AutoMooncake(; config=nothing), μ) + grad_chordal = DifferentiationInterface.gradient(loss_chordal, prep_chordal, AutoMooncake(; config=nothing), μ) grad_abs_diff = norm(grad_gmrf - grad_chordal) 
grad_rel_diff = grad_abs_diff / (norm(grad_gmrf) + 1.0e-10) @@ -158,15 +160,15 @@ for (matrix_name, desc) in test_matrices println(" Match: $(grad_correct ? "✓ YES" : "✗ NO")") # Gradient performance benchmark - println("\n Gradient performance benchmark:") + println("\n Gradient performance benchmark (via DifferentiationInterface, prepared):") - print(" GMRF... ") - bench_grad_gmrf = @benchmark Zygote.gradient($loss_gmrf, $μ) samples = 10 seconds = 10 + print(" GMRF (Zygote)... ") + bench_grad_gmrf = @benchmark DifferentiationInterface.gradient($loss_gmrf, $prep_gmrf, AutoZygote(), $μ) samples = 10 seconds = 10 time_grad_gmrf = minimum(bench_grad_gmrf.times) / 1.0e6 println("$(@sprintf("%.3f", time_grad_gmrf)) ms") - print(" ChordalGMRF... ") - bench_grad_chordal = @benchmark Zygote.gradient($loss_chordal, $μ) samples = 10 seconds = 10 + print(" ChordalGMRF (Mooncake)... ") + bench_grad_chordal = @benchmark DifferentiationInterface.gradient($loss_chordal, $prep_chordal, AutoMooncake(; config=nothing), $μ) samples = 10 seconds = 10 time_grad_chordal = minimum(bench_grad_chordal.times) / 1.0e6 println("$(@sprintf("%.3f", time_grad_chordal)) ms") @@ -224,7 +226,7 @@ println("-"^95) # Gradient summary table println("\n" * "="^80) -println("SUMMARY: GRADIENT (Zygote)") +println("SUMMARY: GRADIENT (via DifferentiationInterface, prepared)") println("="^80) println("\n" * "-"^95) diff --git a/benchmarks/logpdf_comparison.jl b/benchmarks/logpdf_comparison.jl index 499198f3..6bdea847 100644 --- a/benchmarks/logpdf_comparison.jl +++ b/benchmarks/logpdf_comparison.jl @@ -18,9 +18,8 @@ using LinearSolve using Printf using Random using MatrixDepot -using Zygote - -using CliqueTrees.Multifrontal: symbolic, chordal +using Zygote, Mooncake +using DifferentiationInterface: DifferentiationInterface, AutoZygote, AutoMooncake println("="^80) println("LOGPDF COMPARISON: GMRF vs ChordalGMRF") @@ -114,8 +113,15 @@ for (matrix_name, desc) in test_matrices # Gradient correctness check println("\n Gradient correctness check (w.r.t. z):") - grad_gmrf = Zygote.gradient(x -> logpdf(gmrf, x), z)[1] - grad_chordal = Zygote.gradient(x -> logpdf(chordal_gmrf, x), z)[1] + gmrf_logpdf = x -> logpdf(gmrf, x) + chordal_logpdf = x -> logpdf(chordal_gmrf, x) + + # Use prepared gradients for both backends + prep_gmrf = DifferentiationInterface.prepare_gradient(gmrf_logpdf, AutoZygote(), z) + prep_chordal = DifferentiationInterface.prepare_gradient(chordal_logpdf, AutoMooncake(; config=nothing), z) + + grad_gmrf = DifferentiationInterface.gradient(gmrf_logpdf, prep_gmrf, AutoZygote(), z) + grad_chordal = DifferentiationInterface.gradient(chordal_logpdf, prep_chordal, AutoMooncake(; config=nothing), z) grad_abs_diff = norm(grad_gmrf - grad_chordal) grad_rel_diff = grad_abs_diff / (norm(grad_gmrf) + 1.0e-10) @@ -126,15 +132,15 @@ for (matrix_name, desc) in test_matrices println(" Match: $(grad_correct ? "✓ YES" : "✗ NO")") # Gradient performance benchmark - println("\n Gradient performance benchmark:") + println("\n Gradient performance benchmark (via DifferentiationInterface, prepared):") print(" GMRF... ") - bench_grad_gmrf = @benchmark Zygote.gradient(x -> logpdf($gmrf, x), $z) samples = 20 seconds = 5 + bench_grad_gmrf = @benchmark DifferentiationInterface.gradient($gmrf_logpdf, $prep_gmrf, AutoZygote(), $z) samples = 20 seconds = 5 time_grad_gmrf = minimum(bench_grad_gmrf.times) / 1.0e6 println("$(@sprintf("%.3f", time_grad_gmrf)) ms") print(" ChordalGMRF... 
") - bench_grad_chordal = @benchmark Zygote.gradient(x -> logpdf($chordal_gmrf, x), $z) samples = 20 seconds = 5 + bench_grad_chordal = @benchmark DifferentiationInterface.gradient($chordal_logpdf, $prep_chordal, AutoMooncake(; config=nothing), $z) samples = 20 seconds = 5 time_grad_chordal = minimum(bench_grad_chordal.times) / 1.0e6 println("$(@sprintf("%.3f", time_grad_chordal)) ms") @@ -192,7 +198,7 @@ println("-"^95) # Gradient summary table println("\n" * "="^80) -println("SUMMARY: GRADIENT (Zygote)") +println("SUMMARY: GRADIENT (via DifferentiationInterface, prepared)") println("="^80) println("\n" * "-"^95) diff --git a/deps/MooncakeSparse b/deps/MooncakeSparse new file mode 160000 index 00000000..6b1b2efe --- /dev/null +++ b/deps/MooncakeSparse @@ -0,0 +1 @@ +Subproject commit 6b1b2efe83b19b5a90a8c5fcd6e760b1a6269e97 diff --git a/ext/GaussianMarkovRandomFieldsAutoDiff.jl b/ext/GaussianMarkovRandomFieldsAutoDiff.jl index 2a747a12..17de68a8 100644 --- a/ext/GaussianMarkovRandomFieldsAutoDiff.jl +++ b/ext/GaussianMarkovRandomFieldsAutoDiff.jl @@ -1,26 +1,9 @@ module GaussianMarkovRandomFieldsAutoDiff using GaussianMarkovRandomFields -using ForwardDiff, Zygote, LinearAlgebra, LinearMaps, SparseArrays +using ForwardDiff, Zygote, LinearAlgebra, LinearMaps import LinearMaps: _unsafe_mul! -# | | | -# )_) )_) )_) -# )___))___))___) -# )____)____)_____) -# _____|____|____|____ -# ---------\ /--------- -# ^^^^^ ^^^^^^^^^^^^^^^^^^^^^ -# ^^^^ ^^^^ ^^^ ^^ -# ^^^^ ^^^ -# -# Zygote accum for sparse Hermitian/Symmetric (piracy until upstream PR is merged) -const HermOrSymSparse{T, I} = Union{ - Hermitian{T, SparseMatrixCSC{T, I}}, - Symmetric{T, SparseMatrixCSC{T, I}}, -} -Zygote.accum(x::HermOrSymSparse, y::HermOrSymSparse) = x + y - function LinearMaps._unsafe_mul!(y, J::ADJacobianMap, x::AbstractVector) g(t) = J.f(J.x₀ + t * x) return y .= ForwardDiff.derivative(g, 0.0) diff --git a/src/GaussianMarkovRandomFields.jl b/src/GaussianMarkovRandomFields.jl index 8e20ffc4..ff225f3d 100644 --- a/src/GaussianMarkovRandomFields.jl +++ b/src/GaussianMarkovRandomFields.jl @@ -1,6 +1,6 @@ module GaussianMarkovRandomFields -include("piracy.jl") +include("../deps/MooncakeSparse/MooncakeSparse.jl") include("typedefs.jl") include("utils/utils.jl") include("linear_maps/linear_maps.jl") diff --git a/src/arithmetic/condition/gaussian_approximation.jl b/src/arithmetic/condition/gaussian_approximation.jl index bcf48e9c..37f53db8 100644 --- a/src/arithmetic/condition/gaussian_approximation.jl +++ b/src/arithmetic/condition/gaussian_approximation.jl @@ -2,7 +2,6 @@ using LinearAlgebra using SparseArrays using LinearMaps using CliqueTrees.Multifrontal: chordal, ChordalCholesky, triangular -using CliqueTrees.Multifrontal.Differential: ldivsym export gaussian_approximation @@ -79,7 +78,7 @@ end # Solver abstraction for gaussian_approximation Newton iteration. # Allows shared iteration logic for both LinearSolve-backed GMRF and ChordalCholesky-backed ChordalGMRF. 
_ga_init_solver(gmrf::GMRF) = deepcopy(linsolve_cache(gmrf)) -_ga_init_solver(gmrf::ChordalGMRF{T}) where {T} = ChordalCholesky{:L, T}(gmrf.P, gmrf.L.S) +_ga_init_solver(gmrf::ChordalGMRF) = copy(gmrf.F) function _ga_update_and_solve!(solver, Q_base, H_k, b, ::GMRF) Q_new = prepare_for_linsolve(Q_base - H_k, solver.alg) @@ -101,7 +100,7 @@ function _ga_make_posterior(x, Q, solver, prior::Union{GMRF, ConstrainedGMRF}, c end function _ga_make_posterior(x, Q, solver, prior::ChordalGMRF, ::Nothing) - return ChordalGMRF(x, Q, solver.L, prior.P) + return ChordalGMRF(x, Q, solver) end """ diff --git a/src/autodiff/autodiff.jl b/src/autodiff/autodiff.jl index 4234f8d3..bb78555e 100644 --- a/src/autodiff/autodiff.jl +++ b/src/autodiff/autodiff.jl @@ -21,3 +21,6 @@ include("constructors.jl") include("precision_gradient.jl") # Helper for computing precision gradients include("logpdf.jl") include("gaussian_approximation.jl") + +# Mooncake-specific rules for ChordalGMRF +include("mooncake_gaussian_approximation.jl") diff --git a/src/autodiff/constructors.jl b/src/autodiff/constructors.jl index e80a6ba3..13ccfe60 100644 --- a/src/autodiff/constructors.jl +++ b/src/autodiff/constructors.jl @@ -2,6 +2,7 @@ using ChainRulesCore using SparseArrays using LinearAlgebra using LinearMaps +using CliqueTrees.Multifrontal: ChordalCholesky """ ChainRulesCore.rrule(::Type{GMRF}, μ::AbstractVector, Q::Union{AbstractMatrix, LinearMaps.LinearMap}, algorithm) @@ -122,3 +123,36 @@ function ChainRulesCore.rrule( return result, ConstrainedGMRF_pullback end + +""" + ChainRulesCore.rrule(::Type{ChordalGMRF}, μ::AbstractVector, Q::SparseMatrixCSC) + +Automatic differentiation rule for ChordalGMRF constructor. + +This rrule enables differentiation through ChordalGMRF construction, allowing gradients +to flow back to the mean vector and precision matrix. The factorization F is treated +as non-differentiable. 
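+
+Concretely, the pullback maps an output tangent `x̄` to `(NoTangent(), x̄.μ, x̄.Q)`:
+the cotangents of `μ` and `Q` pass through unchanged, and `F` receives no tangent.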
+""" +function ChainRulesCore.rrule(::Type{ChordalGMRF}, μ::AbstractVector, Q::SparseMatrixCSC) + x = ChordalGMRF(μ, Q) + + function ChordalGMRF_pullback(x̄) + μ̄ = x̄.μ + Q̄ = x̄.Q + return NoTangent(), μ̄, Q̄ + end + + return x, ChordalGMRF_pullback +end + +function ChainRulesCore.rrule(::Type{ChordalGMRF}, μ::AbstractVector, Q::SparseMatrixCSC, F::ChordalCholesky) + x = ChordalGMRF(μ, Q, F) + + function ChordalGMRF_pullback(x̄) + μ̄ = x̄.μ + Q̄ = x̄.Q + return NoTangent(), μ̄, Q̄, NoTangent() + end + + return x, ChordalGMRF_pullback +end diff --git a/src/autodiff/gaussian_approximation.jl b/src/autodiff/gaussian_approximation.jl index dfae76a7..a8c6cc4e 100644 --- a/src/autodiff/gaussian_approximation.jl +++ b/src/autodiff/gaussian_approximation.jl @@ -1,6 +1,5 @@ using ChainRulesCore using LinearAlgebra -using CliqueTrees.Multifrontal.Differential: ldivsym """ _is_zero_tangent(x) -> Bool @@ -244,7 +243,7 @@ function _ift_solve(posterior::Union{GMRF, ConstrainedGMRF}, x̄_total, prior_gm end function _ift_solve(posterior::ChordalGMRF, x̄_total, ::ChordalGMRF) - return ldivsym(precision_matrix(posterior), posterior.L, posterior.P, x̄_total) + return posterior.F \ x̄_total end """ @@ -353,7 +352,6 @@ function _add_precision_tangent(prior_tangent, prior::ChordalGMRF, Q̄) return Tangent{typeof(prior)}(; μ = prior_μ̄, Q = combined_Q̄, - L = NoTangent(), - P = NoTangent(), + F = NoTangent(), ) end diff --git a/src/autodiff/logpdf.jl b/src/autodiff/logpdf.jl index 71f5311c..d7b75d9e 100644 --- a/src/autodiff/logpdf.jl +++ b/src/autodiff/logpdf.jl @@ -99,4 +99,3 @@ function ChainRulesCore.rrule(::typeof(logpdf), x::AbstractGMRF, z::AbstractVect end end -ChainRulesCore.@opt_out rrule(::typeof(logpdf), ::ChordalGMRF, ::AbstractVector) diff --git a/src/autodiff/mooncake_gaussian_approximation.jl b/src/autodiff/mooncake_gaussian_approximation.jl new file mode 100644 index 00000000..e756d58e --- /dev/null +++ b/src/autodiff/mooncake_gaussian_approximation.jl @@ -0,0 +1,117 @@ +using Mooncake +using Mooncake: @is_primitive, @mooncake_overlay, MinimalCtx, CoDual, NoRData, NoFData, primal, tangent, fdata, zero_tangent +using SparseArrays: nonzeros, SparseMatrixCSC +using LinearAlgebra: Hermitian +using CliqueTrees.Multifrontal: ChordalCholesky + +@is_primitive MinimalCtx Tuple{Type{ChordalGMRF}, AbstractVector, SparseMatrixCSC} + +function Mooncake.rrule!!( + ::CoDual{Type{ChordalGMRF}}, + cdμ::CoDual{<:AbstractVector}, + cdQ::CoDual{<:SparseMatrixCSC}, +) + μ, Σμ = MooncakeSparse.primaltangent(cdμ) + Q, ΣQ = MooncakeSparse.primaltangent(cdQ) + + gmrf = ChordalGMRF(μ, Q) + dy = fdata(zero_tangent(gmrf)) + + function pullback!!(::NoRData) + dμ = MooncakeSparse.toarray(gmrf.μ, dy.data.μ) + dQ = MooncakeSparse.toarray(gmrf.Q, dy.data.Q) + + Σμ .+= dμ + nonzeros(ΣQ) .+= nonzeros(parent(dQ)) + + return NoRData(), NoRData(), NoRData() + end + + return CoDual(gmrf, dy), pullback!! 
+end + +@is_primitive MinimalCtx Tuple{Type{ChordalGMRF}, AbstractVector, Hermitian, ChordalCholesky} + +function Mooncake.rrule!!( + ::CoDual{Type{ChordalGMRF}}, + cdμ::CoDual{<:AbstractVector}, + cdQ::CoDual{<:Hermitian}, + cdF::CoDual{<:ChordalCholesky}, +) + μ, Σμ = MooncakeSparse.primaltangent(cdμ) + Q, ΣQ = MooncakeSparse.primaltangent(cdQ) + F = primal(cdF) + + gmrf = ChordalGMRF(μ, Q, F) + dy = fdata(zero_tangent(gmrf)) + + function pullback!!(::NoRData) + dμ = MooncakeSparse.toarray(gmrf.μ, dy.data.μ) + dQ = MooncakeSparse.toarray(gmrf.Q, dy.data.Q) + + Σμ .+= dμ + nonzeros(parent(ΣQ)) .+= nonzeros(parent(dQ)) + + return NoRData(), NoRData(), NoRData(), NoRData() + end + + return CoDual(gmrf, dy), pullback!! +end + +function gaussian_approximation_notangent(prior::ChordalGMRF, obslik::ObservationLikelihood; kwargs...) + return gaussian_approximation(prior, obslik; kwargs...) +end + +@is_primitive MinimalCtx Tuple{typeof(gaussian_approximation_notangent), ChordalGMRF, ObservationLikelihood} +@is_primitive MinimalCtx Tuple{typeof(Core.kwcall), Any, typeof(gaussian_approximation_notangent), ChordalGMRF, ObservationLikelihood} + +function Mooncake.rrule!!( + ::CoDual{typeof(gaussian_approximation_notangent)}, + cdprior::CoDual{<:ChordalGMRF}, + cdobslik::CoDual{<:ObservationLikelihood}, +) + prior = primal(cdprior) + obslik = primal(cdobslik) + posterior = gaussian_approximation_notangent(prior, obslik) + + function pullback!!(::NoRData) + return NoRData(), Mooncake.zero_rdata(prior), Mooncake.zero_rdata(obslik) + end + + return CoDual(posterior, fdata(zero_tangent(posterior))), pullback!! +end + +function Mooncake.rrule!!( + ::CoDual{typeof(Core.kwcall)}, + cdkwargs::CoDual, + ::CoDual{typeof(gaussian_approximation_notangent)}, + cdprior::CoDual{<:ChordalGMRF}, + cdobslik::CoDual{<:ObservationLikelihood}, +) + prior = primal(cdprior) + obslik = primal(cdobslik) + kwargs = primal(cdkwargs) + posterior = gaussian_approximation_notangent(prior, obslik; kwargs...) + + function pullback!!(::NoRData) + return NoRData(), NoRData(), NoRData(), Mooncake.zero_rdata(prior), Mooncake.zero_rdata(obslik) + end + + return CoDual(posterior, fdata(zero_tangent(posterior))), pullback!! +end + +@mooncake_overlay function gaussian_approximation( + prior::ChordalGMRF, + obslik::ObservationLikelihood; + kwargs... +) + posterior = gaussian_approximation_notangent(prior, obslik; kwargs...) 
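+
+    # `gaussian_approximation_notangent` is a Mooncake primitive with a zero-gradient
+    # pullback, so the Newton solve above is invisible to the tape. The lines below
+    # re-derive the outputs differentiably: one Newton step from the converged mode
+    # (an implicit-function-theorem correction) plus a differentiable reassembly of
+    # the posterior precision, both of which Mooncake can pull back through.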
+ x_star = mean(posterior) + + grad = ∇ₓ_neg_log_posterior(prior, obslik, x_star) + x_corrected = x_star - posterior.F \ grad + + Q_post = hermdiff(precision_matrix(prior), loghessian(x_corrected, obslik)) + + return ChordalGMRF(x_corrected, Q_post, posterior.F) +end diff --git a/src/chordal_gmrf.jl b/src/chordal_gmrf.jl index 602f6777..43c8cbbf 100644 --- a/src/chordal_gmrf.jl +++ b/src/chordal_gmrf.jl @@ -1,26 +1,24 @@ -using CliqueTrees.Multifrontal: ChordalTriangular, Permutation, ChordalSymbolic, symbolic, chordal, selinv as mselinv, logdet -using LinearAlgebra: Hermitian, cholesky, diag, ldiv!, axpy!, dot +using CliqueTrees.Multifrontal: ChordalCholesky, selinv as mselinv, logdet +using LinearAlgebra: Hermitian, cholesky!, diag, ldiv!, axpy!, dot using SparseArrays: SparseMatrixCSC using Random: AbstractRNG, randn export ChordalGMRF -struct ChordalGMRF{T <: Real, Herm <: HermSparse{T}, Tri <: ChordalTriangular{:N, :L, T}, Prm <: Permutation, Mea <: AbstractVector{T}} <: AbstractGMRF{T, Herm} +struct ChordalGMRF{T <: Real, Hrm <: Hermitian, Fac <: ChordalCholesky, Mea <: AbstractVector{T}} <: AbstractGMRF{T, Hrm} μ::Mea - Q::Herm - L::Tri - P::Prm + Q::Hrm + F::Fac end -function ChordalGMRF(μ::AbstractVector, Q::SparseMatrixCSC, L, P) - return ChordalGMRF(μ, Hermitian(Q, :L), L, P) +function ChordalGMRF(μ::AbstractVector, Q::SparseMatrixCSC, F::ChordalCholesky) + return ChordalGMRF(μ, Hermitian(Q, :L), F) end function ChordalGMRF(μ::AbstractVector, Q::SparseMatrixCSC; kw...) H = Hermitian(Q, :L) - P, S = symbolic(H; kw...) - L = cholesky(chordal(H, P, S)) - return ChordalGMRF(μ, H, L, P) + F = cholesky!(ChordalCholesky(H; kw...)) + return ChordalGMRF(μ, H, F) end function Base.length(d::ChordalGMRF) @@ -40,26 +38,26 @@ function precision_matrix(d::ChordalGMRF) end function logdetcov(d::ChordalGMRF) - return -logdet(precision_matrix(d), d.L, d.P) + return -logdet(d.Q, d.F) end function sqmahal(d::ChordalGMRF, x::AbstractVector) r = x - d.μ - return dot(r, precision_matrix(d), r) + return dot(r, d.Q, r) end function gradlogpdf(d::ChordalGMRF, x::AbstractVector) - return precision_matrix(d) * (d.μ - x) + return d.Q * (d.μ - x) end function var(d::ChordalGMRF) - Σ = mselinv(precision_matrix(d), d.L, d.P) + Σ = mselinv(d.Q, d.F) return diag(Σ) end function _rand!(rng::AbstractRNG, d::ChordalGMRF{T}, x::AbstractVector) where {T} z = randn(rng, T, length(x)) - return axpy!(1, d.μ, d.P \ ldiv!(d.L', d.P * z)) + return axpy!(true, d.μ, d.F.P \ ldiv!(d.F.U, z)) end function Base.show(io::IO, d::ChordalGMRF{T}) where {T} @@ -77,13 +75,3 @@ function Base.show(io::IO, ::MIME"text/plain", d::ChordalGMRF{T}) where {T} print(io, " Mean: [$(μ[1]), $(μ[2]), $(μ[3]), ..., $(μ[end - 2]), $(μ[end - 1]), $(μ[end])]") end end - -# ChainRulesCore rrule for ChordalGMRF constructor -# ChordalGMRF is defined by (μ, Q). L and P are derived - gradients never flow through them. -using ChainRulesCore: ChainRulesCore, NoTangent - -function ChainRulesCore.rrule(::Type{ChordalGMRF}, μ::AbstractVector, Q::SparseMatrixCSC; kw...) - result = ChordalGMRF(μ, Q; kw...) 
- ChordalGMRF_pullback(ȳ) = (NoTangent(), ȳ.μ, ȳ.Q) - return result, ChordalGMRF_pullback -end diff --git a/src/piracy.jl b/src/piracy.jl deleted file mode 100644 index 65bd9db4..00000000 --- a/src/piracy.jl +++ /dev/null @@ -1,507 +0,0 @@ -# | | | -# )_) )_) )_) -# )___))___))___) -# )____)____)_____) -# _____|____|____|____ -# ---------\ /--------- -# ^^^^^ ^^^^^^^^^^^^^^^^^^^^^ -# ^^^^ ^^^^ ^^^ ^^ -# ^^^^ ^^^ -# -# Type piracy to enable autodiff for Hermitian/Symmetric sparse matrices. -# These changes have been submitted as PRs to ChainRulesCore, ChainRules, and Zygote. -# This file can be removed once those PRs are merged and released. - -using ChainRulesCore -using ChainRulesCore: ProjectTo, project_type, _projection_mismatch, NoTangent, ZeroTangent, AbstractZero, @thunk, unthunk -using LinearAlgebra -using LinearAlgebra: Hermitian, Symmetric, Adjoint, Transpose, AdjOrTrans, dot, rmul!, tril, triu -using SparseArrays -using SparseArrays: SparseMatrixCSC, nzrange, rowvals, getcolptr, nonzeros - -##### -##### Type aliases -##### - -const HermSparse{T, I} = Hermitian{T, SparseMatrixCSC{T, I}} -const SymSparse{T, I} = Symmetric{T, SparseMatrixCSC{T, I}} -const HermOrSymSparse{T, I} = Union{HermSparse{T, I}, SymSparse{T, I}} - -const DenseMat{T} = Union{StridedMatrix{T}, AdjOrTrans{T, <:StridedVecOrMat{T}}} -const DenseVecOrMat{T} = Union{DenseMat{T}, StridedVector{T}} - -##### -##### ChainRulesCore: ProjectTo for HermOrSymSparse -##### - -const SparseProjectToData{T, I} = NamedTuple{ - (:element, :axes, :rowval, :nzranges, :colptr), - Tuple{ - ProjectTo{T, NamedTuple{(), Tuple{}}}, - Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, - Vector{I}, - Vector{UnitRange{Int64}}, - Vector{I}, - }, -} - -const SparseProjectTo{T, I} = ProjectTo{SparseMatrixCSC, SparseProjectToData{T, I}} - -const HermSparseProjectTo{T, I} = ProjectTo{ - Hermitian, - NamedTuple{ - (:uplo, :parent), - Tuple{Symbol, SparseProjectTo{T, I}}, - }, -} - -const SymSparseProjectTo{T, I} = ProjectTo{ - Symmetric, - NamedTuple{ - (:uplo, :parent), - Tuple{Symbol, SparseProjectTo{T, I}}, - }, -} - -function ChainRulesCore.ProjectTo(x::HermSparse{T}) where {T <: Number} - return ProjectTo{Hermitian}(; - uplo = Symbol(x.uplo), - parent = ProjectTo(parent(x)), - ) -end - -function ChainRulesCore.ProjectTo(x::SymSparse{T}) where {T <: Number} - return ProjectTo{Symmetric}(; - uplo = Symbol(x.uplo), - parent = ProjectTo(parent(x)), - ) -end - -function project!(A::SparseMatrixCSC{T, I}, B::SparseMatrixCSC{<:Any, J}, uplo::Char) where {T, I, J} - @assert size(A) == size(B) - - @inbounds for j in axes(A, 2) - p = getcolptr(A)[j] - pstop = getcolptr(A)[j + 1] - q = getcolptr(B)[j] - qstop = getcolptr(B)[j + 1] - - while p < pstop - i = rowvals(A)[p] - - if (uplo == 'L' && i >= j) || (uplo == 'U' && i <= j) - while q < qstop && rowvals(B)[q] < i - q += one(J) - end - - if q < qstop && rowvals(B)[q] == i - nonzeros(A)[p] = nonzeros(B)[q] - else - nonzeros(A)[p] = zero(T) - end - end - - p += one(I) - end - end - - return A -end - -function project!(A::HermOrSymSparse, B::HermOrSymSparse) - if A.uplo == B.uplo - project!(parent(A), parent(B), A.uplo) - elseif A.uplo == 'L' - project!(parent(A), tril(B), A.uplo) - else - project!(parent(A), triu(B), A.uplo) - end - - return A -end - -function sparse_from_project(P::SparseProjectTo{T, I}) where {T, I} - m, n = map(length, P.axes) - return SparseMatrixCSC(m, n, P.colptr, P.rowval, zeros(T, length(P.rowval))) -end - -function sparse_from_project(P::HermSparseProjectTo) - return 
Hermitian(sparse_from_project(P.parent), P.uplo) -end - -function sparse_from_project(P::SymSparseProjectTo) - return Symmetric(sparse_from_project(P.parent), P.uplo) -end - -function checkpatternsym(n, Acolptr::Vector{IA}, Bcolptr::Vector{IB}, Arowval::AbstractVector, Browval::AbstractVector, uplo::Char) where {IA, IB} - for j in 1:n - pa = Acolptr[j] - pb = Bcolptr[j] - pastop = Acolptr[j + 1] - pbstop = Bcolptr[j + 1] - - while pa < pastop && pb < pbstop - ia = Arowval[pa] - ib = Browval[pb] - - if (uplo == 'L' && ia < j) || (uplo == 'U' && ia > j) - pa += one(IA) - elseif (uplo == 'L' && ib < j) || (uplo == 'U' && ib > j) - pb += one(IB) - elseif ia == ib - pa += one(IA) - pb += one(IB) - else - return false - end - end - - while pa < pastop - ia = Arowval[pa] - - if (uplo == 'L' && ia >= j) || (uplo == 'U' && ia <= j) - return false - end - - pa += one(IA) - end - - while pb < pbstop - ib = Browval[pb] - - if (uplo == 'L' && ib >= j) || (uplo == 'U' && ib <= j) - return false - end - - pb += one(IB) - end - end - - return true -end - -function checkpatternsym(P, dX) - return false -end - -function checkpatternsym(P::Union{HermSparseProjectTo{T, I}, SymSparseProjectTo{T, I}}, dX::HermOrSymSparse{T, I}) where {T, I} - dXP = parent(dX) - return Symbol(dX.uplo) == P.uplo && checkpatternsym(size(dXP, 2), P.parent.colptr, dXP.colptr, P.parent.rowval, dXP.rowval, dX.uplo) -end - -function (P::HermSparseProjectTo{T, I})(dX::HermSparse) where {T, I} - if checkpatternsym(P, dX) - return dX - else - return project!(sparse_from_project(P), dX) - end -end - -function (P::SymSparseProjectTo{T, I})(dX::SymSparse) where {T, I} - if checkpatternsym(P, dX) - return dX - else - return project!(sparse_from_project(P), dX) - end -end - -function (P::HermSparseProjectTo{T, I})(dX::SymSparse{T, I}) where {T <: Real, I} - if checkpatternsym(P, dX) - return Hermitian(parent(dX), P.uplo) - else - return project!(sparse_from_project(P), dX) - end -end - -function (P::SymSparseProjectTo{T, I})(dX::HermSparse{T, I}) where {T <: Real, I} - if checkpatternsym(P, dX) - return Symmetric(parent(dX), P.uplo) - else - return project!(sparse_from_project(P), dX) - end -end - -##### -##### ChainRules: selupd! for computing sparse gradients -##### - -function unwrap(A) - if A isa Adjoint - B = parent(A) - - if B isa Transpose - return (parent(B), Val(:N), Val(:C)) - else - return (B, Val(:T), Val(:C)) - end - elseif A isa Transpose - B = parent(A) - - if B isa Adjoint - return (parent(B), Val(:N), Val(:C)) - else - return (B, Val(:T), Val(:N)) - end - else - return (A, Val(:N), Val(:N)) - end -end - -# SELected UPDate: compute the selected low-rank update -# -# C ← α A Bᴴ + conj(α) B Aᴴ + β C -# -# The update is only applied to the structural nonzeros of C. -function selupd!(C::HermSparse, A::AbstractVecOrMat, B::AbstractVecOrMat, α, β) - selupd!(parent(C), C.uplo, A, adjoint(B), α, β) - selupd!(parent(C), C.uplo, B, adjoint(A), conj(α), 1) - return C -end - -# SELected UPDate: compute the selected low-rank update -# -# C ← α A Bᴴ + α conj(B) Aᵀ + β C -# -# The update is only applied to the structural nonzeros of C. -function selupd!(C::SymSparse, A::AbstractVecOrMat, B::AbstractVecOrMat, α, β) - selupd!(parent(C), C.uplo, A, adjoint(B), α, β) - selupd!(parent(C), C.uplo, adjoint(transpose(B)), transpose(A), α, 1) - return C -end - -# SELected UPDate: compute the selected low-rank update -# -# C ← α A B + β C -# -# The update is only applied to the structural nonzeros of C. 
-function selupd!(C::SparseMatrixCSC, uplo::Char, A::AbstractVecOrMat, B::AbstractVecOrMat, α, β) - AP, tA, cA = unwrap(A) - BP, tB, cB = unwrap(B) - return selupd_impl!(C, uplo, AP, BP, α, β, tA, cA, tB, cB) -end - -function selupd_impl!(C::SparseMatrixCSC, uplo::Char, A::AbstractVector, B::AbstractVector, α, β, ::Val{tA}, ::Val{cA}, ::Val{tB}, ::Val{cB}) where {tA, cA, tB, cB} - @assert size(C, 1) == size(C, 2) == length(A) == length(B) - - @inbounds for j in axes(C, 2) - Bj = cB === :C ? conj(B[j]) : B[j] - - for p in nzrange(C, j) - i = rowvals(C)[p] - - if (uplo == 'L' && i >= j) || (uplo == 'U' && i <= j) - Ai = cA === :C ? conj(A[i]) : A[i] - - if iszero(β) - nonzeros(C)[p] = α * Ai * Bj - else - nonzeros(C)[p] = β * nonzeros(C)[p] + α * Ai * Bj - end - end - end - end - - return C -end - -function selupd_impl!(C::SparseMatrixCSC, uplo::Char, A::AbstractMatrix, B::AbstractMatrix, α, β, tA::Val{TA}, cA::Val{CA}, tB::Val{TB}, cB::Val{CB}) where {TA, CA, TB, CB} - @assert size(C, 1) == size(C, 2) - - if TA === :N && TB === :N - @assert size(A, 1) == size(C, 1) - @assert size(B, 2) == size(C, 1) - @assert size(A, 2) == size(B, 1) - elseif TA === :N && TB !== :N - @assert size(A, 1) == size(C, 1) - @assert size(B, 1) == size(C, 1) - @assert size(A, 2) == size(B, 2) - elseif TA !== :N && TB === :N - @assert size(A, 2) == size(C, 1) - @assert size(B, 2) == size(C, 1) - @assert size(A, 1) == size(B, 1) - else - @assert size(A, 2) == size(C, 1) - @assert size(B, 1) == size(C, 1) - @assert size(A, 1) == size(B, 2) - end - - if TA === :N - rng = axes(A, 2) - else - rng = axes(A, 1) - end - - if iszero(β) - fill!(nonzeros(C), β) - else - rmul!(nonzeros(C), β) - end - - for k in rng - if TA === :N - Ak = view(A, :, k) - else - Ak = view(A, k, :) - end - - if TB === :N - Bk = view(B, k, :) - else - Bk = view(B, :, k) - end - - selupd_impl!(C, uplo, Ak, Bk, α, 1, tA, cA, tB, cB) - end - - return C -end - -##### -##### ChainRules: rrule/frule implementations -##### - -function mul_rrule_impl(A::HermOrSymSparse, B::DenseVecOrMat, ΔC) - ΔB = A * ΔC - ΔA = if ΔC isa AbstractZero - ZeroTangent() - else - @thunk begin - ΔA = similar(A) - selupd!(ΔA, ΔC, B, 1 / 2, 0) - ΔA - end - end - return ΔA, ΔB -end - -function mul_rrule_impl(A::DenseMat, B::HermSparse, ΔC) - ΔA = ΔC * B - ΔB = if ΔC isa AbstractZero - ZeroTangent() - else - @thunk begin - ΔB = similar(B) - selupd!(ΔB, A', ΔC', 1 / 2, 0) - ΔB - end - end - return ΔA, ΔB -end - -function mul_rrule_impl(A::DenseMat, B::SymSparse, ΔC) - ΔA = ΔC * B - ΔB = if ΔC isa AbstractZero - ZeroTangent() - else - @thunk begin - ΔB = similar(B) - selupd!(ΔB, transpose(ΔC), transpose(A), 1 / 2, 0) - ΔB - end - end - return ΔA, ΔB -end - -function dot_rrule_impl(x::StridedVector, A::HermOrSymSparse, y::StridedVector, Ax::StridedVector, Ay::StridedVector, Δz) - Δx = @thunk Δz * Ay - Δy = @thunk Δz * Ax - - ΔA = if Δz isa AbstractZero - ZeroTangent() - else - @thunk begin - ΔA = similar(A) - selupd!(ΔA, x, y, Δz / 2, 0) - ΔA - end - end - - return Δx, ΔA, Δy -end - -function mul_rrule(A::HermOrSymSparse, B::DenseVecOrMat) - C = A * B - - function pullback(ΔC) - ΔA, ΔB = mul_rrule_impl(A, B, ΔC) - return NoTangent(), ΔA, ΔB - end - - return C, pullback ∘ unthunk -end - -function mul_rrule(A::DenseMat, B::HermOrSymSparse) - C = A * B - - function pullback(ΔC) - ΔA, ΔB = mul_rrule_impl(A, B, ΔC) - return NoTangent(), ΔA, ΔB - end - - return C, pullback ∘ unthunk -end - -function dot_rrule(x::StridedVector, A::HermOrSymSparse, y::StridedVector) - Ax = A * x - Ay = A * y - 
z = dot(x, Ay) - - function pullback(Δz) - Δx, ΔA, Δy = dot_rrule_impl(x, A, y, Ax, Ay, Δz) - return NoTangent(), Δx, ΔA, Δy - end - - return z, pullback ∘ unthunk -end - -function mul_frule_impl(A, B, dA, dB) - return A * B, dA * B + A * dB -end - -function dot_frule_impl(x::StridedVector, A::HermOrSymSparse, y::StridedVector, dx, dA, dy) - return dot(x, A, y), dot(dx, A, y) + dot(x, A, dy) + dot(x, dA, y) -end - -##### -##### ChainRules: frule / rrule dispatches -##### - -for T in (HermSparse, SymSparse) - # A * X - @eval function ChainRulesCore.frule((_, dA, dX)::Tuple, ::typeof(*), A::$T, X::DenseVecOrMat) - return mul_frule_impl(A, X, dA, dX) - end - - @eval function ChainRulesCore.rrule(::typeof(*), A::$T, X::DenseVecOrMat) - return mul_rrule(A, X) - end - - # X * A - @eval function ChainRulesCore.frule((_, dX, dA)::Tuple, ::typeof(*), X::DenseMat, A::$T) - return mul_frule_impl(X, A, dX, dA) - end - - @eval function ChainRulesCore.rrule(::typeof(*), X::DenseMat, A::$T) - return mul_rrule(X, A) - end - - # dot(x, A, y) - vectors only, matching upstream ChainRules - @eval function ChainRulesCore.frule((_, dx, dA, dy)::Tuple, ::typeof(dot), x::StridedVector, A::$T, y::StridedVector) - return dot_frule_impl(x, A, y, dx, dA, dy) - end - - @eval function ChainRulesCore.rrule(::typeof(dot), x::StridedVector, A::$T, y::StridedVector) - return dot_rrule(x, A, y) - end -end - -# The rrules above cause method invalidation that exposes an upstream bug in -# ChainRulesCore's ProjectTo{SymTridiagonal}: it extracts only one triangle of -# the off-diagonal, losing the factor of 2 from symmetry. -function ChainRulesCore.rrule(::typeof(sum), Q::SymTridiagonal) - function sum_symtridiag_pullback(ȳ) - s = unthunk(ȳ) - return NoTangent(), Tangent{SymTridiagonal}(dv = fill(s, length(Q.dv)), ev = fill(2s, length(Q.ev))) - end - return sum(Q), sum_symtridiag_pullback -end diff --git a/test/autodiff/test_gaussian_approximation_chordal.jl b/test/autodiff/test_gaussian_approximation_chordal.jl index ffbc74c3..f9ee263e 100644 --- a/test/autodiff/test_gaussian_approximation_chordal.jl +++ b/test/autodiff/test_gaussian_approximation_chordal.jl @@ -5,12 +5,10 @@ using SparseArrays using LinearAlgebra using Random -using CliqueTrees.Multifrontal: symbolic, chordal, HermTri, Permutation - using DifferentiationInterface -using FiniteDiff, ForwardDiff, Zygote +using FiniteDiff, Mooncake -backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] +backends = Any[("Mooncake", AutoMooncake())] @testset "$backend_name ChordalGMRF autodiff tests" for (backend_name, backend) in backends # Set seed for reproducibility @@ -23,7 +21,7 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] end # Test pipeline: hyperparameters → ChordalGMRF → gaussian_approximation → logpdf - function test_gauss_approx_pipeline(θ::Vector, y::Vector, x::Vector, P, S, k::Int) + function test_gauss_approx_pipeline(θ::Vector, y::Vector, x::Vector, k::Int) # Extract hyperparameters ρ = θ[1] # AR parameter μ_const = θ[2] # constant mean @@ -34,12 +32,8 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] # Create constant mean vector μ = μ_const * ones(k) - # Use chordal (now differentiable!) 
- J = chordal(Hermitian(Q, :L), P, S) - L = cholesky(J) - - # Create prior ChordalGMRF (pass original Q, not chordal J) - prior_gmrf = ChordalGMRF(μ, Q, L, P) + # Create prior ChordalGMRF + prior_gmrf = ChordalGMRF(μ, Q) # Create Poisson observation likelihood obs_model = ExponentialFamily(Poisson) @@ -59,18 +53,14 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] y = [2, 1, 3, 2, 1, 4, 2, 1] # Poisson count data x = randn(k) .+ 0.5 # Evaluation point - # Pre-compute sparsity structure - Q_ref = ar_precision(0.5, k) - P, S = symbolic(Q_ref) - grad_test = DifferentiationInterface.gradient( - θ -> test_gauss_approx_pipeline(θ, y, x, P, S, k), + θ -> test_gauss_approx_pipeline(θ, y, x, k), backend, θ ) grad_fd = DifferentiationInterface.gradient( - θ -> test_gauss_approx_pipeline(θ, y, x, P, S, k), + θ -> test_gauss_approx_pipeline(θ, y, x, k), fd_backend, θ ) @@ -87,23 +77,19 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] y = [1, 2, 1, 3, 1, 2] x = randn(k) .+ 0.3 - # Pre-compute sparsity structure - Q_ref = ar_precision(0.5, k) - P, S = symbolic(Q_ref) - # Test different ρ and μ values for ρ in [0.2, 0.5] for μ_const in [0.3, 0.8] θ = [ρ, μ_const] grad_test = DifferentiationInterface.gradient( - θ -> test_gauss_approx_pipeline(θ, y, x, P, S, k), + θ -> test_gauss_approx_pipeline(θ, y, x, k), backend, θ ) grad_fd = DifferentiationInterface.gradient( - θ -> test_gauss_approx_pipeline(θ, y, x, P, S, k), + θ -> test_gauss_approx_pipeline(θ, y, x, k), fd_backend, θ ) @@ -124,18 +110,12 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] y = randn(k) .* 0.3 .+ 0.2 x = randn(k) - # Pre-compute sparsity structure - Q_ref = ar_precision(0.5, k) - P, S = symbolic(Q_ref) - - function gaussian_lik_pipeline(θ, y, x, P, S, k) + function gaussian_lik_pipeline(θ, y, x, k) ρ, μ_const = θ Q = ar_precision(ρ, k) μ = μ_const * ones(k) - J = chordal(Hermitian(Q, :L), P, S) - L = cholesky(J) - prior_gmrf = ChordalGMRF(μ, Q, L, P) + prior_gmrf = ChordalGMRF(μ, Q) obs_model = ExponentialFamily(Normal) obs_lik = obs_model(y; σ = 0.5) @@ -144,13 +124,13 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] end grad_test = DifferentiationInterface.gradient( - θ -> gaussian_lik_pipeline(θ, y, x, P, S, k), + θ -> gaussian_lik_pipeline(θ, y, x, k), backend, θ ) grad_fd = DifferentiationInterface.gradient( - θ -> gaussian_lik_pipeline(θ, y, x, P, S, k), + θ -> gaussian_lik_pipeline(θ, y, x, k), fd_backend, θ ) @@ -169,18 +149,14 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] y = [1, 2, 1, 1] x = randn(k) .+ 0.4 - # Pre-compute sparsity structure - Q_ref = ar_precision(0.5, k) - P, S = symbolic(Q_ref) - grad_test = DifferentiationInterface.gradient( - θ -> test_gauss_approx_pipeline(θ, y, x, P, S, k), + θ -> test_gauss_approx_pipeline(θ, y, x, k), backend, θ ) grad_fd = DifferentiationInterface.gradient( - θ -> test_gauss_approx_pipeline(θ, y, x, P, S, k), + θ -> test_gauss_approx_pipeline(θ, y, x, k), fd_backend, θ ) @@ -192,27 +168,18 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] @test maximum(rel_error) < 5.0e-2 end - @testset "Basic logpdf autodiff with chordal" begin + @testset "Basic logpdf autodiff" begin k = 10 z = randn(k) - # Pre-compute sparsity structure (not differentiable) - Q_ref = ar_precision(0.5, k) - P, S = symbolic(Q_ref) - - # Test pipeline: use chordal which IS now differentiable - function 
test_chordal_pipeline(θ::AbstractVector, z::AbstractVector, P, S, k) + function test_logpdf_pipeline(θ::AbstractVector, z::AbstractVector, k) ρ = θ[1] μ_const = θ[2] Q = ar_precision(ρ, k) μ = μ_const * ones(k) - # Use chordal directly (now differentiable!) - J = chordal(Hermitian(Q, :L), P, S) - L = cholesky(J) - - gmrf = ChordalGMRF(μ, Q, L, P) + gmrf = ChordalGMRF(μ, Q) return logpdf(gmrf, z) end @@ -220,14 +187,14 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] # Compute gradients using AD backend grad_test = DifferentiationInterface.gradient( - θ -> test_chordal_pipeline(θ, z, P, S, k), + θ -> test_logpdf_pipeline(θ, z, k), backend, θ ) # Compute gradients using finite differences grad_fd = DifferentiationInterface.gradient( - θ -> test_chordal_pipeline(θ, z, P, S, k), + θ -> test_logpdf_pipeline(θ, z, k), fd_backend, θ ) @@ -240,7 +207,7 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] @test maximum(rel_error) < 1.0e-2 end - @testset "2D grid precision with chordal" begin + @testset "2D grid precision" begin # Build 2D grid precision matrix using spdiagm (Zygote-compatible) function grid_precision(α, grid_size) n = grid_size^2 @@ -261,10 +228,7 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] n = grid_size^2 z = randn(n) - Q_ref = grid_precision(0.5, grid_size) - P, S = symbolic(Q_ref) - - function test_grid_pipeline(θ::AbstractVector, z::AbstractVector, P, S, grid_size) + function test_grid_pipeline(θ::AbstractVector, z::AbstractVector, grid_size) α = θ[1] μ_const = θ[2] @@ -272,23 +236,20 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] n = grid_size^2 μ = μ_const * ones(n) - J = chordal(Hermitian(Q, :L), P, S) - L = cholesky(J) - - gmrf = ChordalGMRF(μ, Q, L, P) + gmrf = ChordalGMRF(μ, Q) return logpdf(gmrf, z) end θ = [0.5, 0.1] grad_test = DifferentiationInterface.gradient( - θ -> test_grid_pipeline(θ, z, P, S, grid_size), + θ -> test_grid_pipeline(θ, z, grid_size), backend, θ ) grad_fd = DifferentiationInterface.gradient( - θ -> test_grid_pipeline(θ, z, P, S, grid_size), + θ -> test_grid_pipeline(θ, z, grid_size), fd_backend, θ ) @@ -314,18 +275,13 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] y = [2, 1, 3, 2, 1, 4, 2, 1, 2] # Poisson count data x = randn(n) .+ 0.5 - Q_ref = grid_precision(0.5, grid_size) - P, S = symbolic(Q_ref) - - function test_grid_gauss_approx(θ, y, x, P, S, grid_size) + function test_grid_gauss_approx(θ, y, x, grid_size) α, μ_const = θ Q = grid_precision(α, grid_size) n = grid_size^2 μ = μ_const * ones(n) - J = chordal(Hermitian(Q, :L), P, S) - L = cholesky(J) - prior_gmrf = ChordalGMRF(μ, Q, L, P) + prior_gmrf = ChordalGMRF(μ, Q) obs_model = ExponentialFamily(Poisson) obs_lik = obs_model(PoissonObservations(y)) @@ -336,13 +292,13 @@ backends = Any[("Zygote", AutoZygote()), ("ForwardDiff", AutoForwardDiff())] θ = [0.5, 0.3] grad_test = DifferentiationInterface.gradient( - θ -> test_grid_gauss_approx(θ, y, x, P, S, grid_size), + θ -> test_grid_gauss_approx(θ, y, x, grid_size), backend, θ ) grad_fd = DifferentiationInterface.gradient( - θ -> test_grid_gauss_approx(θ, y, x, P, S, grid_size), + θ -> test_grid_gauss_approx(θ, y, x, grid_size), fd_backend, θ ) diff --git a/test/gaussian_approximation/runtests.jl b/test/gaussian_approximation/runtests.jl index 406b1d3d..2abcb515 100644 --- a/test/gaussian_approximation/runtests.jl +++ b/test/gaussian_approximation/runtests.jl @@ -1 +1,2 @@ 
include("test_gaussian_approximation.jl") +include("test_gaussian_approximation_chordal.jl")