diff --git a/Project.toml b/Project.toml index 7fdae095..05ea47b5 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ LibGEOS = "a90b1aa1-3769-5649-ba7e-abc5a9d163eb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SelectedInversion = "043bf095-3f01-458a-9f1c-8cf4448fe908" @@ -30,6 +31,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4" [weakdeps] +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" @@ -40,7 +42,6 @@ SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] GaussianMarkovRandomFieldsAutoDiff = ["ForwardDiff", "Zygote"] @@ -57,7 +58,7 @@ GaussianMarkovRandomFieldsSparseJacobian = ["Symbolics", "SparseDiffTools"] AMD = "0.5" Aqua = "0.8" ChainRulesCore = "1" -CliqueTrees = "1.18.0" +CliqueTrees = "1.19.1" DataStructures = "0.14 - 0.19" DifferentiationInterface = "0.6, 0.7" Distributions = "0.25" @@ -75,6 +76,7 @@ LinearAlgebra = "<0.0.1, 1" LinearMaps = "3.11" LinearSolve = "2, 3" Makie = "0.19 - 0.22" +Mooncake = "0.5.25" NearestNeighbors = "0.4" Pardiso = "1" Random = "<0.0.1, 1" @@ -92,7 +94,7 @@ StatsModels = "0.7" Symbolics = "5" Tensors = "1.16" Test = "<0.0.1, 1" -Zygote = "0.6" +Zygote = "0.6, 0.7" julia = "1.10" [extras] diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 07b751eb..cc7061d0 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -1,5 +1,6 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -7,6 +8,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" GaussianMarkovRandomFields = "d5f06795-35bb-4323-9f0b-405ef76cfc5b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" +MatrixDepot = "b51810bb-c9f3-55da-ae3c-350fc1fbce05" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/benchmarks/autodiff_comparison.jl b/benchmarks/autodiff_comparison.jl index a5b255ff..34ce81e2 100644 --- a/benchmarks/autodiff_comparison.jl +++ b/benchmarks/autodiff_comparison.jl @@ -20,7 +20,9 @@ using LinearSolve using Printf using Random -using Zygote, Enzyme, FiniteDiff +using Zygote, Enzyme, FiniteDiff, Mooncake + +using CliqueTrees.Multifrontal: symbolic, chordal println("="^80) println("AUTODIFF BACKEND COMPARISON: HIGH-DIMENSIONAL HYPERPARAMETER SPACE") @@ -37,7 +39,7 @@ println(" - $n mean parameters (one per time point)") println(" - 1 precision parameter (τ)") # Workflow: θ → GMRF → gaussian_approximation → logpdf -function benchmark_workflow(θ::Vector{Float64}, y::Vector{Int}, x_eval::Vector{Float64}) +function benchmark_workflow(θ::Vector{Float64}, y::PoissonObservations, x_eval::Vector{Float64}) # Extract hyperparameters μ = θ[1:n] # Mean field (100 params) log_τ = θ[n + 1] # Log precision @@ -61,6 +63,31 @@ function benchmark_workflow(θ::Vector{Float64}, y::Vector{Int}, x_eval::Vector{ return logpdf(posterior, x_eval) end +# ChordalGMRF workflow (only supports Mooncake) +function benchmark_workflow_chordal(θ::Vector{Float64}, y::PoissonObservations, x_eval::Vector{Float64}) + # Extract hyperparameters + μ = θ[1:n] # Mean field (100 params) + log_τ = θ[n + 1] # Log precision + τ = exp(log_τ) + + # Build precision matrix using RW1 latent model + rw1_model = RW1Model(n) + Q = sparse(precision_matrix(rw1_model; τ = τ)) + + # Prior ChordalGMRF with custom mean + prior = ChordalGMRF(μ, Q) + + # Poisson observation likelihood + obs_model = ExponentialFamily(Poisson) + obs_lik = obs_model(y) + + # Gaussian approximation + posterior = gaussian_approximation(prior, obs_lik) + + # Evaluate log-density + return logpdf(posterior, x_eval) +end + # Generate test data println("\nGenerating test data...") Random.seed!(123) @@ -73,19 +100,24 @@ Random.seed!(123) # Simulate observations (Poisson counts from smooth latent field) x_true = μ_true .+ cumsum(randn(n)) .* sqrt(1 / τ_true) .* 0.5 x_true .-= mean(x_true) # Center -y_obs = rand.(Poisson.(exp.(x_true .+ 0.5))) +y_counts = rand.(Poisson.(exp.(x_true .+ 0.5))) +y_obs = PoissonObservations(y_counts) x_eval = randn(n) .+ 0.3 # Initial parameter values (perturbed from truth) θ_init = θ_true .+ randn(n_params) .* 0.1 -println(" ✓ Generated $(length(y_obs)) Poisson observations") +println(" ✓ Generated $(length(y_counts)) Poisson observations") println(" ✓ Initial parameters: $(n_params)-dimensional vector") -# Verify workflow works -println("\nVerifying workflow...") +# Verify workflows work +println("\nVerifying workflows...") f_val = benchmark_workflow(θ_init, y_obs, x_eval) -println(" ✓ Function value: $(@sprintf("%.4f", f_val))") +println(" ✓ GMRF function value: $(@sprintf("%.4f", f_val))") + +f_val_chordal = benchmark_workflow_chordal(θ_init, y_obs, x_eval) +println(" ✓ ChordalGMRF function value: $(@sprintf("%.4f", f_val_chordal))") +println(" ✓ Difference: $(@sprintf("%.2e", abs(f_val - f_val_chordal)))") # Define backends backends = [ @@ -95,7 +127,7 @@ backends = [ ] println("\n" * "="^80) -println("BENCHMARKING GRADIENT COMPUTATION") +println("BENCHMARKING GRADIENT COMPUTATION (via DifferentiationInterface, prepared)") println("="^80) results = Dict() @@ -105,22 +137,22 @@ for (name, backend) in backends println("-"^40) try - # Warmup - print(" Warming up... ") - grad = DifferentiationInterface.gradient( - θ -> benchmark_workflow(θ, y_obs, x_eval), - backend, - θ_init - ) + # Define loss function + loss = θ -> benchmark_workflow(θ, y_obs, x_eval) + + # Prepare gradient (includes warmup/compilation) + print(" Preparing... ") + prep = DifferentiationInterface.prepare_gradient(loss, backend, θ_init) + println("✓") + + # Compute gradient once + print(" Computing gradient... ") + grad = DifferentiationInterface.gradient(loss, prep, backend, θ_init) println("✓") - # Benchmark + # Benchmark with prepared gradient print(" Benchmarking... ") - bench = @benchmark DifferentiationInterface.gradient( - θ -> benchmark_workflow(θ, y_obs, x_eval), - $backend, - $θ_init - ) samples = 10 seconds = 30 + bench = @benchmark DifferentiationInterface.gradient($loss, $prep, $backend, $θ_init) samples = 10 seconds = 30 results[name] = ( gradient = grad, @@ -135,23 +167,72 @@ for (name, backend) in backends println(" Memory: $(@sprintf("%.2f", bench.memory / 1.0e6)) MB") catch e - println(" ✗ Failed: $e") + println(" ✗ Failed: $(typeof(e).name.name)") + if e isa ErrorException + println(" $(first(split(e.msg, '\n')))") + end results[name] = nothing end end +# ChordalGMRF benchmark (Mooncake only) +println("\n" * "="^80) +println("BENCHMARKING ChordalGMRF (Mooncake only)") +println("="^80) + +println("\nChordalGMRF + Mooncake:") +println("-"^40) + +try + # Define loss function + loss_chordal = θ -> benchmark_workflow_chordal(θ, y_obs, x_eval) + + # Prepare gradient (includes warmup/compilation) + print(" Preparing... ") + prep_chordal = DifferentiationInterface.prepare_gradient(loss_chordal, AutoMooncake(; config=nothing), θ_init) + println("✓") + + # Compute gradient once + print(" Computing gradient... ") + grad_chordal = DifferentiationInterface.gradient(loss_chordal, prep_chordal, AutoMooncake(; config=nothing), θ_init) + println("✓") + + # Benchmark with prepared gradient + print(" Benchmarking... ") + bench_chordal = @benchmark DifferentiationInterface.gradient($loss_chordal, $prep_chordal, AutoMooncake(; config=nothing), $θ_init) samples = 10 seconds = 30 + + results["ChordalGMRF+Mooncake"] = ( + gradient = grad_chordal, + time = minimum(bench_chordal.times) / 1.0e6, + bench = bench_chordal, + ) + + println("✓") + println(" Time (min): $(@sprintf("%.2f", results["ChordalGMRF+Mooncake"].time)) ms") + println(" Time (median): $(@sprintf("%.2f", median(bench_chordal.times) / 1.0e6)) ms") + println(" Allocations: $(bench_chordal.allocs)") + println(" Memory: $(@sprintf("%.2f", bench_chordal.memory / 1.0e6)) MB") + +catch e + println(" ✗ Failed: $(typeof(e).name.name)") + if e isa ErrorException + println(" $(first(split(e.msg, '\n')))") + end + results["ChordalGMRF+Mooncake"] = nothing +end + # Summary comparison println("\n" * "="^80) println("SUMMARY") println("="^80) -if all(v !== nothing for v in values(results)) +if results["FiniteDiff"] !== nothing # Verify gradients match println("\nGradient verification (comparing to FiniteDiff):") fd_grad = results["FiniteDiff"].gradient - for name in ["Zygote", "Enzyme"] - if results[name] !== nothing + for name in ["Zygote", "Enzyme", "ChordalGMRF+Mooncake"] + if get(results, name, nothing) !== nothing grad = results[name].gradient abs_error = abs.(grad - fd_grad) max_error = maximum(abs_error) @@ -164,7 +245,7 @@ if all(v !== nothing for v in values(results)) end # Performance comparison table - println("\nPerformance comparison:") + println("\nPerformance comparison (GMRF backends):") println(" " * "─"^76) println(@sprintf(" %-20s %12s %12s %12s %12s", "Backend", "Time (ms)", "Speedup", "Allocs", "Memory (MB)")) println(" " * "─"^76) @@ -191,10 +272,46 @@ if all(v !== nothing for v in values(results)) enzyme_vs_zygote = results["Zygote"].time / results["Enzyme"].time winner = enzyme_vs_zygote > 1 ? "Enzyme" : "Zygote" ratio = max(enzyme_vs_zygote, 1.0 / enzyme_vs_zygote) - println("\n 🏆 $winner is fastest: $(@sprintf("%.1f", ratio))× faster than the other") + println("\n GMRF winner: $winner ($(@sprintf("%.1f", ratio))× faster)") end end +# ChordalGMRF vs GMRF comparison (Zygote only) +if get(results, "ChordalGMRF+Mooncake", nothing) !== nothing && get(results, "Zygote", nothing) !== nothing + println("\n" * "="^80) + println("GMRF vs ChordalGMRF COMPARISON (Zygote vs Mooncake)") + println("="^80) + + r_gmrf = results["Zygote"] + r_chordal = results["ChordalGMRF+Mooncake"] + + println("\n " * "─"^76) + println(@sprintf(" %-20s %12s %12s %12s %12s", "Implementation", "Time (ms)", "Speedup", "Allocs", "Memory (MB)")) + println(" " * "─"^76) + + println( + @sprintf( + " %-20s %12.2f %12s %12d %12.2f", + "GMRF", r_gmrf.time, "1.0×", r_gmrf.bench.allocs, r_gmrf.bench.memory / 1.0e6 + ) + ) + + chordal_speedup = r_gmrf.time / r_chordal.time + println( + @sprintf( + " %-20s %12.2f %12s %12d %12.2f", + "ChordalGMRF", r_chordal.time, @sprintf("%.1f×", chordal_speedup), + r_chordal.bench.allocs, r_chordal.bench.memory / 1.0e6 + ) + ) + + println(" " * "─"^76) + + winner = chordal_speedup > 1 ? "ChordalGMRF" : "GMRF" + ratio = max(chordal_speedup, 1.0 / chordal_speedup) + println("\n Winner: $winner ($(@sprintf("%.1f", ratio))× faster)") +end + println("\n" * "="^80) println("BENCHMARK COMPLETE") println("="^80) diff --git a/benchmarks/gaussian_approximation_comparison.jl b/benchmarks/gaussian_approximation_comparison.jl new file mode 100644 index 00000000..26e015af --- /dev/null +++ b/benchmarks/gaussian_approximation_comparison.jl @@ -0,0 +1,263 @@ +#!/usr/bin/env julia +""" +Benchmark: Compare GMRF vs ChordalGMRF for gaussian_approximation + +Tests both correctness (results match) and performance on SSMC matrices +with Poisson observation likelihoods. + +Usage: + cd benchmarks + julia --project=. gaussian_approximation_comparison.jl +""" + +using GaussianMarkovRandomFields +using BenchmarkTools +using Distributions: logpdf, Poisson +using SparseArrays +using LinearAlgebra +using LinearSolve +using Printf +using Random +using MatrixDepot +using Zygote, Mooncake +using DifferentiationInterface: DifferentiationInterface, AutoZygote, AutoMooncake + +println("="^80) +println("GAUSSIAN APPROXIMATION COMPARISON: GMRF vs ChordalGMRF") +println("="^80) + +# Helper to make a matrix positive definite +function make_posdef(A::SparseMatrixCSC) + # Symmetrize and add diagonal dominance + S = (A + A') / 2 + d = vec(sum(abs, S; dims = 2)) + return S + spdiagm(0 => d .+ 1.0) +end + +# Handle Symmetric wrapper from MatrixDepot +make_posdef(A::Symmetric) = make_posdef(sparse(A)) + +# Test matrices from SSMC (larger for meaningful benchmarks) +test_matrices = [ + ("HB/bcsstk15", "Structural, n=3948"), + ("HB/bcsstk16", "Structural, n=4884"), + ("HB/bcsstk17", "Structural, n=10974"), + ("HB/bcsstk18", "Structural, n=11948"), +] + +println("\nTest matrices:") +for (name, desc) in test_matrices + println(" - $name ($desc)") +end + +results = [] + +for (matrix_name, desc) in test_matrices + println("\n" * "="^80) + println("Matrix: $matrix_name ($desc)") + println("="^80) + + # Load and prepare matrix + try + A_raw = matrixdepot(matrix_name) + Q = make_posdef(A_raw) + n = size(Q, 1) + + println(" Size: $n × $n") + println(" Nonzeros: $(nnz(Q))") + + # Create mean vector and synthetic Poisson observations + Random.seed!(42) + μ = zeros(n) + + # Generate Poisson counts (moderate values to avoid numerical issues) + latent = randn(n) * 0.5 + y_counts = rand.(Poisson.(exp.(latent .+ 1.0))) + y = PoissonObservations(y_counts) + + # Create observation likelihood + obs_model = ExponentialFamily(Poisson) + obs_lik = obs_model(y) + + # Create GMRF prior (baseline) + println("\n Creating GMRF prior...") + gmrf_prior = GMRF(μ, Q, LinearSolve.CHOLMODFactorization()) + + # Create ChordalGMRF prior + println(" Creating ChordalGMRF prior...") + chordal_prior = ChordalGMRF(μ, Q) + + # Run gaussian_approximation + println("\n Running gaussian_approximation...") + posterior_gmrf = gaussian_approximation(gmrf_prior, obs_lik) + posterior_chordal = gaussian_approximation(chordal_prior, obs_lik) + + # Correctness check + println("\n Correctness check:") + mean_gmrf = mean(posterior_gmrf) + mean_chordal = mean(posterior_chordal) + Q_gmrf = precision_matrix(posterior_gmrf) + Q_chordal = precision_matrix(posterior_chordal) + + mean_diff = norm(mean_gmrf - mean_chordal) + mean_rel_diff = mean_diff / (norm(mean_gmrf) + 1.0e-10) + Q_diff = norm(Q_gmrf - Q_chordal) + Q_rel_diff = Q_diff / (norm(Q_gmrf) + 1.0e-10) + + println(" Mean abs diff: $(@sprintf("%.2e", mean_diff))") + println(" Mean rel diff: $(@sprintf("%.2e", mean_rel_diff))") + println(" Precision abs diff: $(@sprintf("%.2e", Q_diff))") + println(" Precision rel diff: $(@sprintf("%.2e", Q_rel_diff))") + + correct = mean_rel_diff < 1.0e-6 && Q_rel_diff < 1.0e-6 + println(" Match: $(correct ? "✓ YES" : "✗ NO")") + + # Performance benchmark + println("\n Performance benchmark:") + + # Benchmark GMRF + print(" GMRF... ") + bench_gmrf = @benchmark gaussian_approximation($gmrf_prior, $obs_lik) samples = 10 seconds = 10 + time_gmrf = minimum(bench_gmrf.times) / 1.0e6 + println("$(@sprintf("%.3f", time_gmrf)) ms") + + # Benchmark ChordalGMRF + print(" ChordalGMRF... ") + bench_chordal = @benchmark gaussian_approximation($chordal_prior, $obs_lik) samples = 10 seconds = 10 + time_chordal = minimum(bench_chordal.times) / 1.0e6 + println("$(@sprintf("%.3f", time_chordal)) ms") + + speedup = time_gmrf / time_chordal + println(" Speedup: $(@sprintf("%.2f", speedup))×") + + # Gradient correctness check (gradient of sum(posterior mean) w.r.t. prior mean) + println("\n Gradient correctness check (w.r.t. prior mean):") + + function loss_gmrf(μ_prior) + prior = GMRF(μ_prior, Q, LinearSolve.CHOLMODFactorization()) + post = gaussian_approximation(prior, obs_lik) + return sum(mean(post)) + end + + function loss_chordal(μ_prior) + prior = ChordalGMRF(μ_prior, Q) + post = gaussian_approximation(prior, obs_lik) + return sum(mean(post)) + end + + # Use prepared gradients for both backends + prep_gmrf = DifferentiationInterface.prepare_gradient(loss_gmrf, AutoZygote(), μ) + grad_gmrf = DifferentiationInterface.gradient(loss_gmrf, prep_gmrf, AutoZygote(), μ) + prep_chordal = DifferentiationInterface.prepare_gradient(loss_chordal, AutoMooncake(; config=nothing), μ) + grad_chordal = DifferentiationInterface.gradient(loss_chordal, prep_chordal, AutoMooncake(; config=nothing), μ) + grad_abs_diff = norm(grad_gmrf - grad_chordal) + grad_rel_diff = grad_abs_diff / (norm(grad_gmrf) + 1.0e-10) + + println(" Absolute diff: $(@sprintf("%.2e", grad_abs_diff))") + println(" Relative diff: $(@sprintf("%.2e", grad_rel_diff))") + + grad_correct = grad_rel_diff < 1.0e-6 + println(" Match: $(grad_correct ? "✓ YES" : "✗ NO")") + + # Gradient performance benchmark + println("\n Gradient performance benchmark (via DifferentiationInterface, prepared):") + + print(" GMRF (Zygote)... ") + bench_grad_gmrf = @benchmark DifferentiationInterface.gradient($loss_gmrf, $prep_gmrf, AutoZygote(), $μ) samples = 10 seconds = 10 + time_grad_gmrf = minimum(bench_grad_gmrf.times) / 1.0e6 + println("$(@sprintf("%.3f", time_grad_gmrf)) ms") + + print(" ChordalGMRF (Mooncake)... ") + bench_grad_chordal = @benchmark DifferentiationInterface.gradient($loss_chordal, $prep_chordal, AutoMooncake(; config=nothing), $μ) samples = 10 seconds = 10 + time_grad_chordal = minimum(bench_grad_chordal.times) / 1.0e6 + println("$(@sprintf("%.3f", time_grad_chordal)) ms") + + grad_speedup = time_grad_gmrf / time_grad_chordal + println(" Speedup: $(@sprintf("%.2f", grad_speedup))×") + + push!( + results, ( + name = matrix_name, + n = n, + nnz = nnz(Q), + correct = correct, + grad_correct = grad_correct, + time_gmrf = time_gmrf, + time_chordal = time_chordal, + speedup = speedup, + time_grad_gmrf = time_grad_gmrf, + time_grad_chordal = time_grad_chordal, + grad_speedup = grad_speedup, + ) + ) + + catch e + println(" ✗ Failed: $(typeof(e).name.name): $(sprint(showerror, e; context = :limit => true))") + push!( + results, ( + name = matrix_name, n = 0, nnz = 0, correct = false, grad_correct = false, + time_gmrf = NaN, time_chordal = NaN, speedup = NaN, + time_grad_gmrf = NaN, time_grad_chordal = NaN, grad_speedup = NaN, + ) + ) + end +end + +# Summary table +println("\n" * "="^80) +println("SUMMARY: FORWARD PASS") +println("="^80) + +println("\n" * "-"^95) +@printf( + "%-20s %8s %10s %8s %12s %12s %10s\n", + "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup" +) +println("-"^95) + +for r in results + correct_str = r.correct ? "✓" : "✗" + @printf( + "%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", + r.name, r.n, r.nnz, correct_str, r.time_gmrf, r.time_chordal, r.speedup + ) +end +println("-"^95) + +# Gradient summary table +println("\n" * "="^80) +println("SUMMARY: GRADIENT (via DifferentiationInterface, prepared)") +println("="^80) + +println("\n" * "-"^95) +@printf( + "%-20s %8s %10s %8s %12s %12s %10s\n", + "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup" +) +println("-"^95) + +for r in results + correct_str = r.grad_correct ? "✓" : "✗" + @printf( + "%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", + r.name, r.n, r.nnz, correct_str, r.time_grad_gmrf, r.time_grad_chordal, r.grad_speedup + ) +end +println("-"^95) + +# Overall stats +valid_results = filter(r -> !isnan(r.speedup), results) +if !isempty(valid_results) + avg_speedup = sum(r.speedup for r in valid_results) / length(valid_results) + avg_grad_speedup = sum(r.grad_speedup for r in valid_results) / length(valid_results) + all_correct = all(r.correct for r in valid_results) + all_grad_correct = all(r.grad_correct for r in valid_results) + + println("\nOverall:") + println(" Forward - All match: $(all_correct ? "✓ YES" : "✗ NO"), Avg speedup: $(@sprintf("%.2f", avg_speedup))×") + println(" Gradient - All match: $(all_grad_correct ? "✓ YES" : "✗ NO"), Avg speedup: $(@sprintf("%.2f", avg_grad_speedup))×") +end + +println("\n" * "="^80) +println("BENCHMARK COMPLETE") +println("="^80) diff --git a/benchmarks/logpdf_comparison.jl b/benchmarks/logpdf_comparison.jl new file mode 100644 index 00000000..6bdea847 --- /dev/null +++ b/benchmarks/logpdf_comparison.jl @@ -0,0 +1,235 @@ +#!/usr/bin/env julia +""" +Benchmark: Compare GMRF vs ChordalGMRF for logpdf computation + +Tests both correctness (results match) and performance on SSMC matrices. + +Usage: + cd benchmarks + julia --project=. logpdf_comparison.jl +""" + +using GaussianMarkovRandomFields +using BenchmarkTools +using Distributions: logpdf +using SparseArrays +using LinearAlgebra +using LinearSolve +using Printf +using Random +using MatrixDepot +using Zygote, Mooncake +using DifferentiationInterface: DifferentiationInterface, AutoZygote, AutoMooncake + +println("="^80) +println("LOGPDF COMPARISON: GMRF vs ChordalGMRF") +println("="^80) + +# Helper to make a matrix positive definite +function make_posdef(A::SparseMatrixCSC) + # Symmetrize and add diagonal dominance + S = (A + A') / 2 + d = vec(sum(abs, S; dims = 2)) + return S + spdiagm(0 => d .+ 1.0) +end + +# Handle Symmetric wrapper from MatrixDepot +make_posdef(A::Symmetric) = make_posdef(sparse(A)) + +# Test matrices from SSMC (larger for meaningful benchmarks) +test_matrices = [ + ("HB/bcsstk15", "Structural, n=3948"), + ("HB/bcsstk16", "Structural, n=4884"), + ("HB/bcsstk17", "Structural, n=10974"), + ("HB/bcsstk18", "Structural, n=11948"), +] + +println("\nTest matrices:") +for (name, desc) in test_matrices + println(" - $name ($desc)") +end + +results = [] + +for (matrix_name, desc) in test_matrices + println("\n" * "="^80) + println("Matrix: $matrix_name ($desc)") + println("="^80) + + # Load and prepare matrix + try + A_raw = matrixdepot(matrix_name) + Q = make_posdef(A_raw) + n = size(Q, 1) + + println(" Size: $n × $n") + println(" Nonzeros: $(nnz(Q))") + + # Create mean vector and evaluation point + Random.seed!(42) + μ = randn(n) + z = randn(n) + + # Create GMRF (baseline) + println("\n Creating GMRF...") + gmrf = GMRF(μ, Q, LinearSolve.CHOLMODFactorization()) + + # Create ChordalGMRF + println(" Creating ChordalGMRF...") + chordal_gmrf = ChordalGMRF(μ, Q) + + # Correctness check + println("\n Correctness check:") + lpdf_gmrf = logpdf(gmrf, z) + lpdf_chordal = logpdf(chordal_gmrf, z) + abs_diff = abs(lpdf_gmrf - lpdf_chordal) + rel_diff = abs_diff / (abs(lpdf_gmrf) + 1.0e-10) + + println(" GMRF logpdf: $(@sprintf("%.8f", lpdf_gmrf))") + println(" ChordalGMRF logpdf: $(@sprintf("%.8f", lpdf_chordal))") + println(" Absolute diff: $(@sprintf("%.2e", abs_diff))") + println(" Relative diff: $(@sprintf("%.2e", rel_diff))") + + correct = rel_diff < 1.0e-8 + println(" Match: $(correct ? "✓ YES" : "✗ NO")") + + # Performance benchmark + println("\n Performance benchmark:") + + # Benchmark GMRF + print(" GMRF... ") + bench_gmrf = @benchmark logpdf($gmrf, $z) samples = 20 seconds = 5 + time_gmrf = minimum(bench_gmrf.times) / 1.0e6 + println("$(@sprintf("%.3f", time_gmrf)) ms") + + # Benchmark ChordalGMRF + print(" ChordalGMRF... ") + bench_chordal = @benchmark logpdf($chordal_gmrf, $z) samples = 20 seconds = 5 + time_chordal = minimum(bench_chordal.times) / 1.0e6 + println("$(@sprintf("%.3f", time_chordal)) ms") + + speedup = time_gmrf / time_chordal + println(" Speedup: $(@sprintf("%.2f", speedup))×") + + # Gradient correctness check + println("\n Gradient correctness check (w.r.t. z):") + gmrf_logpdf = x -> logpdf(gmrf, x) + chordal_logpdf = x -> logpdf(chordal_gmrf, x) + + # Use prepared gradients for both backends + prep_gmrf = DifferentiationInterface.prepare_gradient(gmrf_logpdf, AutoZygote(), z) + prep_chordal = DifferentiationInterface.prepare_gradient(chordal_logpdf, AutoMooncake(; config=nothing), z) + + grad_gmrf = DifferentiationInterface.gradient(gmrf_logpdf, prep_gmrf, AutoZygote(), z) + grad_chordal = DifferentiationInterface.gradient(chordal_logpdf, prep_chordal, AutoMooncake(; config=nothing), z) + grad_abs_diff = norm(grad_gmrf - grad_chordal) + grad_rel_diff = grad_abs_diff / (norm(grad_gmrf) + 1.0e-10) + + println(" Absolute diff: $(@sprintf("%.2e", grad_abs_diff))") + println(" Relative diff: $(@sprintf("%.2e", grad_rel_diff))") + + grad_correct = grad_rel_diff < 1.0e-8 + println(" Match: $(grad_correct ? "✓ YES" : "✗ NO")") + + # Gradient performance benchmark + println("\n Gradient performance benchmark (via DifferentiationInterface, prepared):") + + print(" GMRF... ") + bench_grad_gmrf = @benchmark DifferentiationInterface.gradient($gmrf_logpdf, $prep_gmrf, AutoZygote(), $z) samples = 20 seconds = 5 + time_grad_gmrf = minimum(bench_grad_gmrf.times) / 1.0e6 + println("$(@sprintf("%.3f", time_grad_gmrf)) ms") + + print(" ChordalGMRF... ") + bench_grad_chordal = @benchmark DifferentiationInterface.gradient($chordal_logpdf, $prep_chordal, AutoMooncake(; config=nothing), $z) samples = 20 seconds = 5 + time_grad_chordal = minimum(bench_grad_chordal.times) / 1.0e6 + println("$(@sprintf("%.3f", time_grad_chordal)) ms") + + grad_speedup = time_grad_gmrf / time_grad_chordal + println(" Speedup: $(@sprintf("%.2f", grad_speedup))×") + + push!( + results, ( + name = matrix_name, + n = n, + nnz = nnz(Q), + correct = correct, + grad_correct = grad_correct, + time_gmrf = time_gmrf, + time_chordal = time_chordal, + speedup = speedup, + time_grad_gmrf = time_grad_gmrf, + time_grad_chordal = time_grad_chordal, + grad_speedup = grad_speedup, + ) + ) + + catch e + println(" ✗ Failed: $(typeof(e).name.name): $(sprint(showerror, e; context = :limit => true))") + push!( + results, ( + name = matrix_name, n = 0, nnz = 0, correct = false, grad_correct = false, + time_gmrf = NaN, time_chordal = NaN, speedup = NaN, + time_grad_gmrf = NaN, time_grad_chordal = NaN, grad_speedup = NaN, + ) + ) + end +end + +# Summary table +println("\n" * "="^80) +println("SUMMARY: FORWARD PASS") +println("="^80) + +println("\n" * "-"^95) +@printf( + "%-20s %8s %10s %8s %12s %12s %10s\n", + "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup" +) +println("-"^95) + +for r in results + correct_str = r.correct ? "✓" : "✗" + @printf( + "%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", + r.name, r.n, r.nnz, correct_str, r.time_gmrf, r.time_chordal, r.speedup + ) +end +println("-"^95) + +# Gradient summary table +println("\n" * "="^80) +println("SUMMARY: GRADIENT (via DifferentiationInterface, prepared)") +println("="^80) + +println("\n" * "-"^95) +@printf( + "%-20s %8s %10s %8s %12s %12s %10s\n", + "Matrix", "n", "nnz", "Correct", "GMRF (ms)", "Chordal (ms)", "Speedup" +) +println("-"^95) + +for r in results + correct_str = r.grad_correct ? "✓" : "✗" + @printf( + "%-20s %8d %10d %8s %12.3f %12.3f %10.2f×\n", + r.name, r.n, r.nnz, correct_str, r.time_grad_gmrf, r.time_grad_chordal, r.grad_speedup + ) +end +println("-"^95) + +# Overall stats +valid_results = filter(r -> !isnan(r.speedup), results) +if !isempty(valid_results) + avg_speedup = sum(r.speedup for r in valid_results) / length(valid_results) + avg_grad_speedup = sum(r.grad_speedup for r in valid_results) / length(valid_results) + all_correct = all(r.correct for r in valid_results) + all_grad_correct = all(r.grad_correct for r in valid_results) + + println("\nOverall:") + println(" Forward - All match: $(all_correct ? "✓ YES" : "✗ NO"), Avg speedup: $(@sprintf("%.2f", avg_speedup))×") + println(" Gradient - All match: $(all_grad_correct ? "✓ YES" : "✗ NO"), Avg speedup: $(@sprintf("%.2f", avg_grad_speedup))×") +end + +println("\n" * "="^80) +println("BENCHMARK COMPLETE") +println("="^80) diff --git a/deps/MooncakeSparse b/deps/MooncakeSparse new file mode 160000 index 00000000..6b1b2efe --- /dev/null +++ b/deps/MooncakeSparse @@ -0,0 +1 @@ +Subproject commit 6b1b2efe83b19b5a90a8c5fcd6e760b1a6269e97 diff --git a/src/GaussianMarkovRandomFields.jl b/src/GaussianMarkovRandomFields.jl index 03d68c05..ff225f3d 100644 --- a/src/GaussianMarkovRandomFields.jl +++ b/src/GaussianMarkovRandomFields.jl @@ -1,10 +1,12 @@ module GaussianMarkovRandomFields +include("../deps/MooncakeSparse/MooncakeSparse.jl") include("typedefs.jl") include("utils/utils.jl") include("linear_maps/linear_maps.jl") include("preconditioners/preconditioners.jl") include("gmrf.jl") +include("chordal_gmrf.jl") include("metagmrf.jl") include("solvers/solvers.jl") include("autoregressive/autoregressive.jl") diff --git a/src/arithmetic/condition/gaussian_approximation.jl b/src/arithmetic/condition/gaussian_approximation.jl index 4a9b32b9..37f53db8 100644 --- a/src/arithmetic/condition/gaussian_approximation.jl +++ b/src/arithmetic/condition/gaussian_approximation.jl @@ -1,9 +1,13 @@ using LinearAlgebra using SparseArrays using LinearMaps +using CliqueTrees.Multifrontal: chordal, ChordalCholesky, triangular export gaussian_approximation +# Sparse-preserving subtraction for Hermitian matrices +hermdiff(A::Hermitian, B) = Hermitian(parent(A) - B, Symbol(A.uplo)) + function neg_log_posterior(prior_gmrf::AbstractGMRF, obs_lik::ObservationLikelihood, x) return -logpdf(prior_gmrf, x) - loglik(x, obs_lik) end @@ -19,6 +23,7 @@ end # Private dispatch: extract constraints if present _extract_constraints(::GMRF) = nothing _extract_constraints(constrained::ConstrainedGMRF) = (A = constrained.constraint_matrix, e = constrained.constraint_vector) +_extract_constraints(::ChordalGMRF) = nothing # Private dispatch: apply constraints to result _apply_constraints(gmrf::GMRF, ::Nothing) = gmrf @@ -27,6 +32,7 @@ _apply_constraints(gmrf::GMRF, constraints::NamedTuple) = ConstrainedGMRF(gmrf, # Private dispatch: extract base GMRF for optimization _base_gmrf(gmrf::GMRF) = gmrf _base_gmrf(constrained::ConstrainedGMRF) = constrained.base_gmrf +_base_gmrf(gmrf::ChordalGMRF) = gmrf # Compute the constrained Newton step via the KKT Schur complement. # Solves H⁻¹Aᵀ using the existing linsolve cache (m sparse solves, m = #constraints), @@ -69,6 +75,34 @@ function _update_linsolve_cache_inner!(cache, Q, alg::LinearSolve.LDLtFactorizat return cache.A = copy(Q) end +# Solver abstraction for gaussian_approximation Newton iteration. +# Allows shared iteration logic for both LinearSolve-backed GMRF and ChordalCholesky-backed ChordalGMRF. +_ga_init_solver(gmrf::GMRF) = deepcopy(linsolve_cache(gmrf)) +_ga_init_solver(gmrf::ChordalGMRF) = copy(gmrf.F) + +function _ga_update_and_solve!(solver, Q_base, H_k, b, ::GMRF) + Q_new = prepare_for_linsolve(Q_base - H_k, solver.alg) + _update_linsolve_cache!(solver, Q_new) + solver.b = b + return Q_new, copy(solve!(solver).u) +end + +function _ga_update_and_solve!(solver, Q_base, H_k, b, ::ChordalGMRF) + Q_new = hermdiff(Q_base, H_k) + copyto!(solver, Q_new) + cholesky!(solver) + return Q_new, solver \ b +end + +function _ga_make_posterior(x, Q, solver, prior::Union{GMRF, ConstrainedGMRF}, constraints) + new_gmrf = GMRF(x, Q; linsolve_cache = solver) + return _apply_constraints(new_gmrf, constraints) +end + +function _ga_make_posterior(x, Q, solver, prior::ChordalGMRF, ::Nothing) + return ChordalGMRF(x, Q, solver) +end + """ gaussian_approximation(prior_gmrf, obs_lik; kwargs...) -> AbstractGMRF @@ -77,11 +111,11 @@ Find Gaussian approximation to the posterior using Fisher scoring. This function finds the mode of the posterior distribution and constructs a Gaussian approximation around it using Fisher scoring (Newton-Raphson with Fisher information matrix). -Works for both regular `GMRF` and `ConstrainedGMRF` priors, automatically handling +Works for `GMRF`, `ConstrainedGMRF`, and `ChordalGMRF` priors, automatically handling constraint projection when needed. # Arguments -- `prior_gmrf`: Prior GMRF distribution for the latent field (GMRF or ConstrainedGMRF) +- `prior_gmrf`: Prior GMRF distribution for the latent field (GMRF, ConstrainedGMRF, or ChordalGMRF) - `obs_lik`: Materialized observation likelihood (contains data and hyperparameters) # Keyword Arguments @@ -110,7 +144,7 @@ posterior_gmrf = gaussian_approximation(prior_gmrf, obs_lik; adaptive_stepsize=f ``` """ function gaussian_approximation( - prior_gmrf::Union{GMRF, ConstrainedGMRF}, + prior_gmrf::Union{GMRF, ConstrainedGMRF, ChordalGMRF}, obs_lik::ObservationLikelihood; x0::Union{Nothing, AbstractVector} = nothing, max_iter::Int = 50, @@ -120,14 +154,14 @@ function gaussian_approximation( max_linesearch_iter::Int = 10, verbose::Bool = false ) - # Extract base GMRF and constraints (nothing for regular GMRF) + # Extract base GMRF and constraints (nothing for GMRF/ChordalGMRF) base_gmrf = _base_gmrf(prior_gmrf) constraints = _extract_constraints(prior_gmrf) # Initialize with provided starting point or prior mean - x_k = x0 === nothing ? mean(prior_gmrf) : copy(x0) + x_k = x0 === nothing ? copy(mean(prior_gmrf)) : copy(x0) - cache = deepcopy(linsolve_cache(base_gmrf)) + solver = _ga_init_solver(base_gmrf) Q_base = precision_matrix(base_gmrf) # Adaptive stepsize state (persists across outer iterations) @@ -136,84 +170,68 @@ function gaussian_approximation( verbose && println("Starting Fisher scoring...") for iter in 1:max_iter - # Compute observation likelihood derivatives at current point - H_k = loghessian(x_k, obs_lik) # Hessian: ∇²ₓ log p(y|x) - - # Update precision: Q_new = Q_base - H_k (note: H_k contains negative of Hessian) - Q_new = prepare_for_linsolve(Q_base - H_k, cache.alg) - - _update_linsolve_cache!(cache, Q_new) + H_k = loghessian(x_k, obs_lik) neg_score_k = ∇ₓ_neg_log_posterior(base_gmrf, obs_lik, x_k) - cache.b = neg_score_k - step = copy(solve!(cache).u) + Q_new, step = _ga_update_and_solve!(solver, Q_base, H_k, neg_score_k, base_gmrf) # For constrained problems, project step onto constraint tangent space - # via the KKT Schur complement (m sparse solves, m = #constraints). - # This ensures A*step = 0, so x_k - α*step stays on the manifold. - step = _constrain_step(step, cache, constraints) + # via the KKT Schur complement. No-op when constraints are nothing. + step = _constrain_step(step, solver, constraints) # Apply step with adaptive line search or full step if adaptive_stepsize obj_current = neg_log_posterior(base_gmrf, obs_lik, x_k) - step_accepted = false + accept = false for ls_iter in 1:max_linesearch_iter candidate = x_k - α * step obj_candidate = neg_log_posterior(base_gmrf, obs_lik, candidate) if obj_candidate <= obj_current - # Accept step, increase α toward 1 - μ_new = candidate + x_new = candidate α = sqrt(α) - step_accepted = true + accept = true verbose && ls_iter > 1 && println(" Accepted at α=$(round(α^2, digits = 3)) after $ls_iter backtracks") break else - # Reject, reduce stepsize α *= 0.1 verbose && println(" Backtrack: α=$(round(α, digits = 4))") if α * norm(step, Inf) < newton_dec_tol / 1000 - μ_new = candidate - step_accepted = true + x_new = candidate + accept = true break end end end - if !step_accepted - μ_new = x_k - α * step + if !accept + x_new = x_k - α * step end else - μ_new = x_k - step + x_new = x_k - step end - # Newton decrement: dot(g, constrained_step) = -g'Δx_nt = Δx_nt'HΔx_nt ≥ 0 newton_decrement = dot(neg_score_k, step) - - x_new = μ_new mean_change = norm(x_new - x_k) - mean_change_rel = mean_change / norm(x_k) + mean_change_rel = mean_change / max(norm(x_k), 1.0e-10) verbose && println(" Iter $iter: Newton dec = $(round(newton_decrement, sigdigits = 3)), α = $(round(α, digits = 3))") if (newton_decrement < newton_dec_tol) || (mean_change < mean_change_tol) || (mean_change_rel < mean_change_tol) verbose && println(" Converged after $iter iterations") - new_gmrf = GMRF(x_new, Q_new; linsolve_cache = cache) - return _apply_constraints(new_gmrf, constraints) + return _ga_make_posterior(x_new, Q_new, solver, prior_gmrf, constraints) end - # Update for next iteration x_k = x_new end verbose && println(" Reached max_iter = $max_iter without convergence") - # Return current best approximation + # Return current best approximation at final x_k H_k = loghessian(x_k, obs_lik) - Q_final = prepare_for_linsolve(Q_base - H_k, cache.alg) - cache.A = Q_final - final_gmrf = GMRF(x_k, Q_final; linsolve_cache = cache) - return _apply_constraints(final_gmrf, constraints) + neg_score_k = ∇ₓ_neg_log_posterior(base_gmrf, obs_lik, x_k) + Q_final, _ = _ga_update_and_solve!(solver, Q_base, H_k, neg_score_k, base_gmrf) + return _ga_make_posterior(x_k, Q_final, solver, prior_gmrf, constraints) end # Specialized dispatch for Normal observation likelihoods with identity link (conjugate prior case) diff --git a/src/autodiff/autodiff.jl b/src/autodiff/autodiff.jl index 4234f8d3..bb78555e 100644 --- a/src/autodiff/autodiff.jl +++ b/src/autodiff/autodiff.jl @@ -21,3 +21,6 @@ include("constructors.jl") include("precision_gradient.jl") # Helper for computing precision gradients include("logpdf.jl") include("gaussian_approximation.jl") + +# Mooncake-specific rules for ChordalGMRF +include("mooncake_gaussian_approximation.jl") diff --git a/src/autodiff/constructors.jl b/src/autodiff/constructors.jl index e80a6ba3..13ccfe60 100644 --- a/src/autodiff/constructors.jl +++ b/src/autodiff/constructors.jl @@ -2,6 +2,7 @@ using ChainRulesCore using SparseArrays using LinearAlgebra using LinearMaps +using CliqueTrees.Multifrontal: ChordalCholesky """ ChainRulesCore.rrule(::Type{GMRF}, μ::AbstractVector, Q::Union{AbstractMatrix, LinearMaps.LinearMap}, algorithm) @@ -122,3 +123,36 @@ function ChainRulesCore.rrule( return result, ConstrainedGMRF_pullback end + +""" + ChainRulesCore.rrule(::Type{ChordalGMRF}, μ::AbstractVector, Q::SparseMatrixCSC) + +Automatic differentiation rule for ChordalGMRF constructor. + +This rrule enables differentiation through ChordalGMRF construction, allowing gradients +to flow back to the mean vector and precision matrix. The factorization F is treated +as non-differentiable. +""" +function ChainRulesCore.rrule(::Type{ChordalGMRF}, μ::AbstractVector, Q::SparseMatrixCSC) + x = ChordalGMRF(μ, Q) + + function ChordalGMRF_pullback(x̄) + μ̄ = x̄.μ + Q̄ = x̄.Q + return NoTangent(), μ̄, Q̄ + end + + return x, ChordalGMRF_pullback +end + +function ChainRulesCore.rrule(::Type{ChordalGMRF}, μ::AbstractVector, Q::SparseMatrixCSC, F::ChordalCholesky) + x = ChordalGMRF(μ, Q, F) + + function ChordalGMRF_pullback(x̄) + μ̄ = x̄.μ + Q̄ = x̄.Q + return NoTangent(), μ̄, Q̄, NoTangent() + end + + return x, ChordalGMRF_pullback +end diff --git a/src/autodiff/gaussian_approximation.jl b/src/autodiff/gaussian_approximation.jl index 611b28ff..a8c6cc4e 100644 --- a/src/autodiff/gaussian_approximation.jl +++ b/src/autodiff/gaussian_approximation.jl @@ -205,6 +205,7 @@ end # Wrap a base GMRF tangent in a ConstrainedGMRF tangent if the prior is constrained. _wrap_prior_tangent(base_tangent, ::GMRF) = base_tangent +_wrap_prior_tangent(base_tangent, ::ChordalGMRF) = base_tangent function _wrap_prior_tangent(base_tangent, prior::ConstrainedGMRF) return Tangent{typeof(prior)}(; base_gmrf = base_tangent, @@ -217,9 +218,38 @@ function _wrap_prior_tangent(base_tangent, prior::ConstrainedGMRF) ) end +# Combine x* tangent contributions from the mean and precision paths, handling zero tangents. +function _combine_x_tangents(x̄_from_μ, x̄_from_Q, x_star) + if _is_zero_tangent(x̄_from_μ) && _is_zero_tangent(x̄_from_Q) + return zeros(eltype(x_star), length(x_star)) + elseif _is_zero_tangent(x̄_from_μ) + return collect(x̄_from_Q) + elseif _is_zero_tangent(x̄_from_Q) + return collect(x̄_from_μ) + else + return collect(x̄_from_μ) .+ collect(x̄_from_Q) + end +end + +# IFT linear solve: solve Q_post * λ = x̄_total, with constraint projection for GMRF/ConstrainedGMRF. +function _ift_solve(posterior::Union{GMRF, ConstrainedGMRF}, x̄_total, prior_gmrf) + cache = linsolve_cache(_base_gmrf(posterior)) + b_saved = copy(cache.b) + cache.b = x̄_total + λ = copy(solve!(cache).u) + λ = _constrain_step(λ, cache, _extract_constraints(prior_gmrf)) + cache.b .= b_saved + return λ +end + +function _ift_solve(posterior::ChordalGMRF, x̄_total, ::ChordalGMRF) + return posterior.F \ x̄_total +end + """ ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(gaussian_approximation), - prior_gmrf::Union{GMRF, ConstrainedGMRF}, obs_lik::ObservationLikelihood; kwargs...) + prior_gmrf::Union{GMRF, ConstrainedGMRF, ChordalGMRF}, + obs_lik::ObservationLikelihood; kwargs...) Backend-agnostic automatic differentiation rule for `gaussian_approximation` using the Implicit Function Theorem. @@ -237,7 +267,7 @@ Uses the Implicit Function Theorem on the optimality condition ∇ₓ neg_log_po # Arguments - `config::RuleConfig{>:HasReverseMode}`: AD backend configuration -- `prior_gmrf`: Prior GMRF or ConstrainedGMRF +- `prior_gmrf`: Prior GMRF, ConstrainedGMRF, or ChordalGMRF - `obs_lik`: Observation likelihood - `kwargs...`: Convergence parameters (max_iter, mean_change_tol, etc.) - treated as non-differentiable @@ -248,7 +278,7 @@ Uses the Implicit Function Theorem on the optimality condition ∇ₓ neg_log_po function ChainRulesCore.rrule( config::RuleConfig{>:HasReverseMode}, ::typeof(gaussian_approximation), - prior_gmrf::Union{GMRF, ConstrainedGMRF}, + prior_gmrf::Union{GMRF, ConstrainedGMRF, ChordalGMRF}, obs_lik::ObservationLikelihood; kwargs... ) @@ -258,41 +288,33 @@ function ChainRulesCore.rrule( # === Pullback === function gaussian_approximation_pullback(ȳ) - # Extract tangent components — handle both GMRF and ConstrainedGMRF posteriors + if ȳ isa ZeroTangent + return (NoTangent(), ZeroTangent(), NoTangent()) + end + + # Extract tangent components — dispatches on posterior type μ̄, Q̄ = _extract_posterior_tangents(ȳ, posterior) - # Handle precision tangent through loghessian to get indirect x* dependence - # The Hessian appears in the posterior precision: Q_post = Q_prior - loghessian(x*) - # When differentiating w.r.t. θ: ∂Q_post/∂θ = ∂Q_prior/∂θ - ∂loghessian/∂x* · ∂x*/∂θ - ∂loghessian/∂obs_lik · ∂obs_lik/∂θ - # Compute indirect x* contribution from precision gradient via loghessian + # Q̄ path: backprop through Q_post = Q_prior - loghessian(x*, obs_lik) if _is_zero_tangent(Q̄) - # No precision tangent, only mean path contributes - x_tangent_from_hess = nothing - obs_lik_tangent_from_Q̄ = NoTangent() + x̄_from_Q = ZeroTangent() + obs_lik_tangent_from_Q = NoTangent() else _, hess_pullback = rrule_via_ad(config, loghessian, x_star, obs_lik) - _, x_tangent_from_hess, obs_lik_tangent_from_Q̄ = hess_pullback(-Q̄) # Pass -Q̄ because Q_post = Q_prior - H + _, x̄_from_Q, obs_lik_tangent_from_Q = hess_pullback(-Q̄) end - # Solve for λ combining BOTH mean tangent and indirect x* dependence from precision - # H · λ = μ̄ + x_tangent_from_hess - # This accounts for: ∂L/∂x* from direct (mean) path + indirect (precision→hessian→x*) path - cache = linsolve_cache(_base_gmrf(posterior)) - b_saved = copy(cache.b) - cache.b = _is_zero_tangent(x_tangent_from_hess) ? collect(μ̄) : collect(μ̄) .+ collect(x_tangent_from_hess) - λ = copy(solve!(cache).u) - - # For constrained problems, project λ onto the constraint tangent space via - # the KKT Schur complement (same as forward pass Newton step projection). - λ = _constrain_step(λ, cache, _extract_constraints(prior_gmrf)) - cache.b .= b_saved - - # VJP through ∇ₓ_neg_log_posterior at x* to get gradients w.r.t. prior and likelihood. - # Use the BASE GMRF (not ConstrainedGMRF) to match the forward pass, which operates - # on the unconstrained GMRF with constraint-projected steps. + # μ̄ path: μ_post = x*, so μ̄ flows directly to x* + x̄_from_μ = _is_zero_tangent(μ̄) ? ZeroTangent() : μ̄ + + # Combine x* tangents and solve Q_post * λ = x̄_total via IFT + x̄_total = _combine_x_tangents(x̄_from_μ, x̄_from_Q, x_star) + λ = _ift_solve(posterior, x̄_total, prior_gmrf) + + # VJP through ∇ₓ_neg_log_posterior at x* to get gradients w.r.t. prior and likelihood base_prior = _base_gmrf(prior_gmrf) _, ∇_pullback = rrule_via_ad(config, ∇ₓ_neg_log_posterior, base_prior, obs_lik, x_star) - _, base_prior_tangent, obs_lik_tangent, _ = ∇_pullback(-λ) # Note the minus sign from IFT + _, base_prior_tangent, obs_lik_tangent, _ = ∇_pullback(-λ) # Add contribution from ȳ.precision to base prior tangent if !_is_zero_tangent(Q̄) @@ -303,22 +325,33 @@ function ChainRulesCore.rrule( prior_gmrf_tangent = _wrap_prior_tangent(base_prior_tangent, prior_gmrf) # Combine tangents from mean path and precision path - obs_lik_combined = _add_namedtuples(obs_lik_tangent, obs_lik_tangent_from_Q̄) + obs_lik_combined = _add_namedtuples(obs_lik_tangent, obs_lik_tangent_from_Q) - # Return tangents: NoTangent for function and kwargs return (NoTangent(), prior_gmrf_tangent, obs_lik_combined) end return posterior, gaussian_approximation_pullback end -# Also handle case without RuleConfig for simpler usage -function ChainRulesCore.rrule( - ::typeof(gaussian_approximation), - prior_gmrf::Union{GMRF, ConstrainedGMRF}, - obs_lik::ObservationLikelihood; - kwargs... +# ============================================================================= +# ChordalGMRF tangent helpers (dispatched from unified rrule above) +# ============================================================================= + +# Extract tangents from ChordalGMRF posterior +function _extract_posterior_tangents(ȳ, ::ChordalGMRF) + μ̄ = _is_zero_tangent(ȳ.μ) ? ZeroTangent() : ȳ.μ + Q̄ = _is_zero_tangent(ȳ.Q) ? ZeroTangent() : ȳ.Q + return (μ̄, Q̄) +end + +# Add Q̄ contribution to prior tangent for ChordalGMRF +function _add_precision_tangent(prior_tangent, prior::ChordalGMRF, Q̄) + prior_μ̄ = (prior_tangent isa Tangent && hasproperty(prior_tangent, :μ)) ? prior_tangent.μ : NoTangent() + prior_Q̄_existing = (prior_tangent isa Tangent && hasproperty(prior_tangent, :Q)) ? prior_tangent.Q : NoTangent() + combined_Q̄ = _is_zero_tangent(prior_Q̄_existing) ? Q̄ : prior_Q̄_existing + Q̄ + return Tangent{typeof(prior)}(; + μ = prior_μ̄, + Q = combined_Q̄, + F = NoTangent(), ) - # Delegate to the RuleConfig version with default config - return rrule(NoRuleConfig(), gaussian_approximation, prior_gmrf, obs_lik; kwargs...) end diff --git a/src/autodiff/logpdf.jl b/src/autodiff/logpdf.jl index e266abe8..d7b75d9e 100644 --- a/src/autodiff/logpdf.jl +++ b/src/autodiff/logpdf.jl @@ -98,3 +98,4 @@ function ChainRulesCore.rrule(::typeof(logpdf), x::AbstractGMRF, z::AbstractVect ) end end + diff --git a/src/autodiff/mooncake_gaussian_approximation.jl b/src/autodiff/mooncake_gaussian_approximation.jl new file mode 100644 index 00000000..e756d58e --- /dev/null +++ b/src/autodiff/mooncake_gaussian_approximation.jl @@ -0,0 +1,117 @@ +using Mooncake +using Mooncake: @is_primitive, @mooncake_overlay, MinimalCtx, CoDual, NoRData, NoFData, primal, tangent, fdata, zero_tangent +using SparseArrays: nonzeros, SparseMatrixCSC +using LinearAlgebra: Hermitian +using CliqueTrees.Multifrontal: ChordalCholesky + +@is_primitive MinimalCtx Tuple{Type{ChordalGMRF}, AbstractVector, SparseMatrixCSC} + +function Mooncake.rrule!!( + ::CoDual{Type{ChordalGMRF}}, + cdμ::CoDual{<:AbstractVector}, + cdQ::CoDual{<:SparseMatrixCSC}, +) + μ, Σμ = MooncakeSparse.primaltangent(cdμ) + Q, ΣQ = MooncakeSparse.primaltangent(cdQ) + + gmrf = ChordalGMRF(μ, Q) + dy = fdata(zero_tangent(gmrf)) + + function pullback!!(::NoRData) + dμ = MooncakeSparse.toarray(gmrf.μ, dy.data.μ) + dQ = MooncakeSparse.toarray(gmrf.Q, dy.data.Q) + + Σμ .+= dμ + nonzeros(ΣQ) .+= nonzeros(parent(dQ)) + + return NoRData(), NoRData(), NoRData() + end + + return CoDual(gmrf, dy), pullback!! +end + +@is_primitive MinimalCtx Tuple{Type{ChordalGMRF}, AbstractVector, Hermitian, ChordalCholesky} + +function Mooncake.rrule!!( + ::CoDual{Type{ChordalGMRF}}, + cdμ::CoDual{<:AbstractVector}, + cdQ::CoDual{<:Hermitian}, + cdF::CoDual{<:ChordalCholesky}, +) + μ, Σμ = MooncakeSparse.primaltangent(cdμ) + Q, ΣQ = MooncakeSparse.primaltangent(cdQ) + F = primal(cdF) + + gmrf = ChordalGMRF(μ, Q, F) + dy = fdata(zero_tangent(gmrf)) + + function pullback!!(::NoRData) + dμ = MooncakeSparse.toarray(gmrf.μ, dy.data.μ) + dQ = MooncakeSparse.toarray(gmrf.Q, dy.data.Q) + + Σμ .+= dμ + nonzeros(parent(ΣQ)) .+= nonzeros(parent(dQ)) + + return NoRData(), NoRData(), NoRData(), NoRData() + end + + return CoDual(gmrf, dy), pullback!! +end + +function gaussian_approximation_notangent(prior::ChordalGMRF, obslik::ObservationLikelihood; kwargs...) + return gaussian_approximation(prior, obslik; kwargs...) +end + +@is_primitive MinimalCtx Tuple{typeof(gaussian_approximation_notangent), ChordalGMRF, ObservationLikelihood} +@is_primitive MinimalCtx Tuple{typeof(Core.kwcall), Any, typeof(gaussian_approximation_notangent), ChordalGMRF, ObservationLikelihood} + +function Mooncake.rrule!!( + ::CoDual{typeof(gaussian_approximation_notangent)}, + cdprior::CoDual{<:ChordalGMRF}, + cdobslik::CoDual{<:ObservationLikelihood}, +) + prior = primal(cdprior) + obslik = primal(cdobslik) + posterior = gaussian_approximation_notangent(prior, obslik) + + function pullback!!(::NoRData) + return NoRData(), Mooncake.zero_rdata(prior), Mooncake.zero_rdata(obslik) + end + + return CoDual(posterior, fdata(zero_tangent(posterior))), pullback!! +end + +function Mooncake.rrule!!( + ::CoDual{typeof(Core.kwcall)}, + cdkwargs::CoDual, + ::CoDual{typeof(gaussian_approximation_notangent)}, + cdprior::CoDual{<:ChordalGMRF}, + cdobslik::CoDual{<:ObservationLikelihood}, +) + prior = primal(cdprior) + obslik = primal(cdobslik) + kwargs = primal(cdkwargs) + posterior = gaussian_approximation_notangent(prior, obslik; kwargs...) + + function pullback!!(::NoRData) + return NoRData(), NoRData(), NoRData(), Mooncake.zero_rdata(prior), Mooncake.zero_rdata(obslik) + end + + return CoDual(posterior, fdata(zero_tangent(posterior))), pullback!! +end + +@mooncake_overlay function gaussian_approximation( + prior::ChordalGMRF, + obslik::ObservationLikelihood; + kwargs... +) + posterior = gaussian_approximation_notangent(prior, obslik; kwargs...) + x_star = mean(posterior) + + grad = ∇ₓ_neg_log_posterior(prior, obslik, x_star) + x_corrected = x_star - posterior.F \ grad + + Q_post = hermdiff(precision_matrix(prior), loghessian(x_corrected, obslik)) + + return ChordalGMRF(x_corrected, Q_post, posterior.F) +end diff --git a/src/chordal_gmrf.jl b/src/chordal_gmrf.jl new file mode 100644 index 00000000..43c8cbbf --- /dev/null +++ b/src/chordal_gmrf.jl @@ -0,0 +1,77 @@ +using CliqueTrees.Multifrontal: ChordalCholesky, selinv as mselinv, logdet +using LinearAlgebra: Hermitian, cholesky!, diag, ldiv!, axpy!, dot +using SparseArrays: SparseMatrixCSC +using Random: AbstractRNG, randn + +export ChordalGMRF + +struct ChordalGMRF{T <: Real, Hrm <: Hermitian, Fac <: ChordalCholesky, Mea <: AbstractVector{T}} <: AbstractGMRF{T, Hrm} + μ::Mea + Q::Hrm + F::Fac +end + +function ChordalGMRF(μ::AbstractVector, Q::SparseMatrixCSC, F::ChordalCholesky) + return ChordalGMRF(μ, Hermitian(Q, :L), F) +end + +function ChordalGMRF(μ::AbstractVector, Q::SparseMatrixCSC; kw...) + H = Hermitian(Q, :L) + F = cholesky!(ChordalCholesky(H; kw...)) + return ChordalGMRF(μ, H, F) +end + +function Base.length(d::ChordalGMRF) + return length(d.μ) +end + +function mean(d::ChordalGMRF) + return d.μ +end + +function precision_map(d::ChordalGMRF) + return d.Q +end + +function precision_matrix(d::ChordalGMRF) + return d.Q +end + +function logdetcov(d::ChordalGMRF) + return -logdet(d.Q, d.F) +end + +function sqmahal(d::ChordalGMRF, x::AbstractVector) + r = x - d.μ + return dot(r, d.Q, r) +end + +function gradlogpdf(d::ChordalGMRF, x::AbstractVector) + return d.Q * (d.μ - x) +end + +function var(d::ChordalGMRF) + Σ = mselinv(d.Q, d.F) + return diag(Σ) +end + +function _rand!(rng::AbstractRNG, d::ChordalGMRF{T}, x::AbstractVector) where {T} + z = randn(rng, T, length(x)) + return axpy!(true, d.μ, d.F.P \ ldiv!(d.F.U, z)) +end + +function Base.show(io::IO, d::ChordalGMRF{T}) where {T} + return print(io, "ChordalGMRF{$T}(n=$(length(d)))") +end + +function Base.show(io::IO, ::MIME"text/plain", d::ChordalGMRF{T}) where {T} + println(io, "ChordalGMRF{$T} with $(length(d)) variables") + + μ = d.μ + + return if length(μ) <= 6 + print(io, " Mean: $μ") + else + print(io, " Mean: [$(μ[1]), $(μ[2]), $(μ[3]), ..., $(μ[end - 2]), $(μ[end - 1]), $(μ[end])]") + end +end diff --git a/src/solvers/selinv.jl b/src/solvers/selinv.jl index c700fca4..521c12e2 100644 --- a/src/solvers/selinv.jl +++ b/src/solvers/selinv.jl @@ -104,12 +104,12 @@ _selinv_impl(linsolve, alg) = error("Full selected inversion not implemented for function _selinv_impl(linsolve, ::LinearSolve.CHOLMODFactorization) factorization = LinearSolve.@get_cacheval(linsolve, :CHOLMODFactorization) - return SelectedInversion.selinv(factorization; depermute = true).Z + return Symmetric(sparse(SelectedInversion.selinv(factorization; depermute = true).Z)) end function _selinv_impl(linsolve, ::LinearSolve.CholeskyFactorization) factorization = LinearSolve.@get_cacheval(linsolve, :CholeskyFactorization) - return SelectedInversion.selinv(factorization; depermute = true).Z + return Symmetric(sparse(SelectedInversion.selinv(factorization; depermute = true).Z)) end function _selinv_impl(linsolve, ::LinearSolve.DiagonalFactorization) @@ -119,7 +119,7 @@ end function _selinv_impl(linsolve, ::LinearSolve.LDLtFactorization) factorization = LinearSolve.@get_cacheval(linsolve, :LDLtFactorization) - return SelectedInversion.selinv(factorization; depermute = true).Z + return Symmetric(sparse(SelectedInversion.selinv(factorization; depermute = true).Z)) end function _selinv_impl(linsolve, ::LinearSolve.PardisoJL) diff --git a/test/autodiff/runtests.jl b/test/autodiff/runtests.jl index 6ec7c85e..0506e8b3 100644 --- a/test/autodiff/runtests.jl +++ b/test/autodiff/runtests.jl @@ -3,3 +3,4 @@ include("test_forwarddiff_extension.jl") include("test_logpdf.jl") include("test_constructors.jl") include("test_gaussian_approximation.jl") +include("test_gaussian_approximation_chordal.jl") diff --git a/test/autodiff/test_gaussian_approximation_chordal.jl b/test/autodiff/test_gaussian_approximation_chordal.jl new file mode 100644 index 00000000..f9ee263e --- /dev/null +++ b/test/autodiff/test_gaussian_approximation_chordal.jl @@ -0,0 +1,312 @@ +using Test +using GaussianMarkovRandomFields +using Distributions: logpdf, Poisson, Normal +using SparseArrays +using LinearAlgebra +using Random + +using DifferentiationInterface +using FiniteDiff, Mooncake + +backends = Any[("Mooncake", AutoMooncake())] + +@testset "$backend_name ChordalGMRF autodiff tests" for (backend_name, backend) in backends + # Set seed for reproducibility + Random.seed!(42) + fd_backend = AutoFiniteDiff() + + # Helper function to create simple AR(1) precision matrix + function ar_precision(ρ, k) + return spdiagm(-1 => -ρ * ones(k - 1), 0 => ones(k) .+ ρ^2, 1 => -ρ * ones(k - 1)) + end + + # Test pipeline: hyperparameters → ChordalGMRF → gaussian_approximation → logpdf + function test_gauss_approx_pipeline(θ::Vector, y::Vector, x::Vector, k::Int) + # Extract hyperparameters + ρ = θ[1] # AR parameter + μ_const = θ[2] # constant mean + + # Create precision matrix + Q = ar_precision(ρ, k) + + # Create constant mean vector + μ = μ_const * ones(k) + + # Create prior ChordalGMRF + prior_gmrf = ChordalGMRF(μ, Q) + + # Create Poisson observation likelihood + obs_model = ExponentialFamily(Poisson) + poisson_obs = PoissonObservations(y) + obs_lik = obs_model(poisson_obs) + + # Find Gaussian approximation + posterior_gmrf = gaussian_approximation(prior_gmrf, obs_lik) + + # Compute logpdf at evaluation point x + return logpdf(posterior_gmrf, x) + end + + @testset "Poisson likelihood with gaussian_approximation" begin + k = 8 + θ = [0.4, 0.5] # [ρ, μ_const] + y = [2, 1, 3, 2, 1, 4, 2, 1] # Poisson count data + x = randn(k) .+ 0.5 # Evaluation point + + grad_test = DifferentiationInterface.gradient( + θ -> test_gauss_approx_pipeline(θ, y, x, k), + backend, + θ + ) + + grad_fd = DifferentiationInterface.gradient( + θ -> test_gauss_approx_pipeline(θ, y, x, k), + fd_backend, + θ + ) + + abs_error = abs.(grad_test - grad_fd) + rel_error = abs_error ./ (abs.(grad_fd) .+ 1.0e-10) + + @test maximum(abs_error) < 1.0e-3 + @test maximum(rel_error) < 5.0e-2 + end + + @testset "Different hyperparameter values" begin + k = 6 + y = [1, 2, 1, 3, 1, 2] + x = randn(k) .+ 0.3 + + # Test different ρ and μ values + for ρ in [0.2, 0.5] + for μ_const in [0.3, 0.8] + θ = [ρ, μ_const] + + grad_test = DifferentiationInterface.gradient( + θ -> test_gauss_approx_pipeline(θ, y, x, k), + backend, + θ + ) + + grad_fd = DifferentiationInterface.gradient( + θ -> test_gauss_approx_pipeline(θ, y, x, k), + fd_backend, + θ + ) + + abs_error = abs.(grad_test - grad_fd) + rel_error = abs_error ./ (abs.(grad_fd) .+ 1.0e-10) + + @test maximum(abs_error) < 2.0e-2 # Relaxed for complex optimization + @test maximum(rel_error) < 5.0e-2 + end + end + end + + @testset "Gaussian (conjugate) likelihood" begin + # Test with Gaussian likelihood - should also work through rrule + k = 6 + θ = [0.3, 0.1] + y = randn(k) .* 0.3 .+ 0.2 + x = randn(k) + + function gaussian_lik_pipeline(θ, y, x, k) + ρ, μ_const = θ + Q = ar_precision(ρ, k) + μ = μ_const * ones(k) + + prior_gmrf = ChordalGMRF(μ, Q) + + obs_model = ExponentialFamily(Normal) + obs_lik = obs_model(y; σ = 0.5) + posterior_gmrf = gaussian_approximation(prior_gmrf, obs_lik) + return logpdf(posterior_gmrf, x) + end + + grad_test = DifferentiationInterface.gradient( + θ -> gaussian_lik_pipeline(θ, y, x, k), + backend, + θ + ) + + grad_fd = DifferentiationInterface.gradient( + θ -> gaussian_lik_pipeline(θ, y, x, k), + fd_backend, + θ + ) + + abs_error = abs.(grad_test - grad_fd) + rel_error = abs_error ./ (abs.(grad_fd) .+ 1.0e-10) + + @test maximum(abs_error) < 1.0e-3 + @test maximum(rel_error) < 5.0e-2 + end + + @testset "Small system" begin + # Test with very small system + k = 4 + θ = [0.6, 0.4] + y = [1, 2, 1, 1] + x = randn(k) .+ 0.4 + + grad_test = DifferentiationInterface.gradient( + θ -> test_gauss_approx_pipeline(θ, y, x, k), + backend, + θ + ) + + grad_fd = DifferentiationInterface.gradient( + θ -> test_gauss_approx_pipeline(θ, y, x, k), + fd_backend, + θ + ) + + abs_error = abs.(grad_test - grad_fd) + rel_error = abs_error ./ (abs.(grad_fd) .+ 1.0e-10) + + @test maximum(abs_error) < 1.0e-3 + @test maximum(rel_error) < 5.0e-2 + end + + @testset "Basic logpdf autodiff" begin + k = 10 + z = randn(k) + + function test_logpdf_pipeline(θ::AbstractVector, z::AbstractVector, k) + ρ = θ[1] + μ_const = θ[2] + + Q = ar_precision(ρ, k) + μ = μ_const * ones(k) + + gmrf = ChordalGMRF(μ, Q) + return logpdf(gmrf, z) + end + + θ = [0.5, 0.1] + + # Compute gradients using AD backend + grad_test = DifferentiationInterface.gradient( + θ -> test_logpdf_pipeline(θ, z, k), + backend, + θ + ) + + # Compute gradients using finite differences + grad_fd = DifferentiationInterface.gradient( + θ -> test_logpdf_pipeline(θ, z, k), + fd_backend, + θ + ) + + # Check AD gradients match finite differences + abs_error = abs.(grad_test - grad_fd) + rel_error = abs_error ./ (abs.(grad_fd) .+ 1.0e-10) + + @test maximum(abs_error) < 1.0e-4 + @test maximum(rel_error) < 1.0e-2 + end + + @testset "2D grid precision" begin + # Build 2D grid precision matrix using spdiagm (Zygote-compatible) + function grid_precision(α, grid_size) + n = grid_size^2 + + # Main diagonal + diag_main = fill(4.0 + α, n) + + # Horizontal neighbors (±1 diagonals), but skip row boundaries + horiz = [-1.0 * (mod(i, grid_size) != 0) for i in 1:(n - 1)] + + # Vertical neighbors (±grid_size diagonals) + vert = fill(-1.0, n - grid_size) + + return spdiagm(-grid_size => vert, -1 => horiz, 0 => diag_main, 1 => horiz, grid_size => vert) + end + + grid_size = 4 + n = grid_size^2 + z = randn(n) + + function test_grid_pipeline(θ::AbstractVector, z::AbstractVector, grid_size) + α = θ[1] + μ_const = θ[2] + + Q = grid_precision(α, grid_size) + n = grid_size^2 + μ = μ_const * ones(n) + + gmrf = ChordalGMRF(μ, Q) + return logpdf(gmrf, z) + end + + θ = [0.5, 0.1] + + grad_test = DifferentiationInterface.gradient( + θ -> test_grid_pipeline(θ, z, grid_size), + backend, + θ + ) + + grad_fd = DifferentiationInterface.gradient( + θ -> test_grid_pipeline(θ, z, grid_size), + fd_backend, + θ + ) + + abs_error = abs.(grad_test - grad_fd) + rel_error = abs_error ./ (abs.(grad_fd) .+ 1.0e-10) + + @test maximum(abs_error) < 1.0e-4 + @test maximum(rel_error) < 1.0e-2 + end + + @testset "2D grid with Poisson gaussian_approximation" begin + function grid_precision(α, grid_size) + n = grid_size^2 + diag_main = fill(4.0 + α, n) + horiz = [-1.0 * (mod(i, grid_size) != 0) for i in 1:(n - 1)] + vert = fill(-1.0, n - grid_size) + return spdiagm(-grid_size => vert, -1 => horiz, 0 => diag_main, 1 => horiz, grid_size => vert) + end + + grid_size = 3 + n = grid_size^2 + y = [2, 1, 3, 2, 1, 4, 2, 1, 2] # Poisson count data + x = randn(n) .+ 0.5 + + function test_grid_gauss_approx(θ, y, x, grid_size) + α, μ_const = θ + Q = grid_precision(α, grid_size) + n = grid_size^2 + μ = μ_const * ones(n) + + prior_gmrf = ChordalGMRF(μ, Q) + + obs_model = ExponentialFamily(Poisson) + obs_lik = obs_model(PoissonObservations(y)) + posterior_gmrf = gaussian_approximation(prior_gmrf, obs_lik) + return logpdf(posterior_gmrf, x) + end + + θ = [0.5, 0.3] + + grad_test = DifferentiationInterface.gradient( + θ -> test_grid_gauss_approx(θ, y, x, grid_size), + backend, + θ + ) + + grad_fd = DifferentiationInterface.gradient( + θ -> test_grid_gauss_approx(θ, y, x, grid_size), + fd_backend, + θ + ) + + abs_error = abs.(grad_test - grad_fd) + rel_error = abs_error ./ (abs.(grad_fd) .+ 1.0e-10) + + @test maximum(abs_error) < 1.0e-2 + @test maximum(rel_error) < 5.0e-2 + end +end diff --git a/test/gaussian_approximation/runtests.jl b/test/gaussian_approximation/runtests.jl index 406b1d3d..2abcb515 100644 --- a/test/gaussian_approximation/runtests.jl +++ b/test/gaussian_approximation/runtests.jl @@ -1 +1,2 @@ include("test_gaussian_approximation.jl") +include("test_gaussian_approximation_chordal.jl") diff --git a/test/gaussian_approximation/test_gaussian_approximation_chordal.jl b/test/gaussian_approximation/test_gaussian_approximation_chordal.jl new file mode 100644 index 00000000..658fadd9 --- /dev/null +++ b/test/gaussian_approximation/test_gaussian_approximation_chordal.jl @@ -0,0 +1,194 @@ +using Test +using GaussianMarkovRandomFields +using LinearAlgebra +using SparseArrays +using Distributions + +@testset "Gaussian Approximation - ChordalGMRF" begin + + @testset "Gaussian Likelihood - Analytical Solution" begin + # For Gaussian likelihood, the Gaussian approximation should be exact + n = 5 + Q_prior = spdiagm(0 => 2.0 * ones(n)) + μ_prior = zeros(n) + prior_gmrf = ChordalGMRF(μ_prior, Q_prior) + + obs_model = ExponentialFamily(Normal) + y = [0.1, 0.2, -0.1, 0.3, -0.2] + + obs_lik = obs_model(y; σ = 0.5) + result = gaussian_approximation(prior_gmrf, obs_lik) + + # Should return a ChordalGMRF + @test result isa ChordalGMRF + @test length(mean(result)) == n + + # For Gaussian case, verify against analytical solution + σ = 0.5 + Q_obs = sparse(I, n, n) / σ^2 + Q_analytical = Q_prior + Q_obs + μ_analytical = Q_analytical \ (Q_prior * μ_prior + Q_obs * y) + + @test precision_matrix(result) ≈ Q_analytical atol = 1.0e-8 + @test mean(result) ≈ μ_analytical atol = 1.0e-8 + end + + @testset "Bernoulli Likelihood - Mathematical Properties" begin + # Test with Bernoulli observation model (non-linear) + n = 8 + Q_prior = spdiagm(0 => ones(n), 1 => fill(-0.3, n - 1), -1 => fill(-0.3, n - 1)) + μ_prior = zeros(n) + prior_gmrf = ChordalGMRF(μ_prior, Q_prior) + + obs_model = ExponentialFamily(Bernoulli) + y = [1, 1, 0, 1, 0, 0, 1, 0] + + obs_lik = obs_model(y) + result = gaussian_approximation(prior_gmrf, obs_lik) + + # Should return a ChordalGMRF + @test result isa ChordalGMRF + @test length(mean(result)) == n + + # Mode should reflect the data pattern + μ_result = mean(result) + @test μ_result[1] > 0 # First observation is 1 + @test μ_result[3] < 0 # Third observation is 0 + end + + @testset "Poisson Likelihood - Mathematical Properties" begin + # Test with Poisson observation model + n = 6 + Q_prior = spdiagm(0 => ones(n)) + μ_prior = zeros(n) + prior_gmrf = ChordalGMRF(μ_prior, Q_prior) + + obs_model = ExponentialFamily(Poisson) + y = PoissonObservations([1, 3, 0, 2, 4, 1]) + + obs_lik = obs_model(y) + result = gaussian_approximation(prior_gmrf, obs_lik) + + # Should return a ChordalGMRF + @test result isa ChordalGMRF + @test length(mean(result)) == n + + # Mode should be reasonable for Poisson data + μ_result = mean(result) + @test all(isfinite.(μ_result)) + + # Higher counts should correspond to higher modes + @test μ_result[5] > μ_result[3] # y[5]=4 > y[3]=0 + @test μ_result[2] > μ_result[3] # y[2]=3 > y[3]=0 + end + + @testset "Consistency with GMRF" begin + # Results should match between GMRF and ChordalGMRF + n = 5 + Q_prior = spdiagm(0 => 2.0 * ones(n), 1 => fill(-0.5, n - 1), -1 => fill(-0.5, n - 1)) + μ_prior = zeros(n) + + gmrf_prior = GMRF(μ_prior, Q_prior) + chordal_prior = ChordalGMRF(μ_prior, Q_prior) + + obs_model = ExponentialFamily(Poisson) + y = PoissonObservations([2, 5, 1, 3, 4]) + obs_lik = obs_model(y) + + result_gmrf = gaussian_approximation(gmrf_prior, obs_lik) + result_chordal = gaussian_approximation(chordal_prior, obs_lik) + + @test mean(result_gmrf) ≈ mean(result_chordal) atol = 1.0e-6 + @test precision_matrix(result_gmrf) ≈ precision_matrix(result_chordal) atol = 1.0e-6 + end + + @testset "Sparse precision - tridiagonal" begin + # Test with tridiagonal precision (common in GMRFs) + n = 10 + Q_prior = spdiagm(0 => 2.0 * ones(n), 1 => -ones(n - 1), -1 => -ones(n - 1)) + μ_prior = zeros(n) + prior_gmrf = ChordalGMRF(μ_prior, Q_prior) + + obs_model = ExponentialFamily(Normal) + y = randn(n) + obs_lik = obs_model(y; σ = 0.5) + + result = gaussian_approximation(prior_gmrf, obs_lik) + + @test result isa ChordalGMRF + @test length(mean(result)) == n + @test all(isfinite.(mean(result))) + end + + @testset "Warm-start with x0" begin + n = 6 + Q_prior = spdiagm(0 => ones(n)) + prior_gmrf = ChordalGMRF(zeros(n), Q_prior) + + obs_model = ExponentialFamily(Poisson) + y = PoissonObservations([5, 12, 2, 8, 15, 3]) + obs_lik = obs_model(y) + + # Cold-start result + result_cold = gaussian_approximation(prior_gmrf, obs_lik) + x_star = mean(result_cold) + + # Warm-start from converged mode + result_warm = gaussian_approximation(prior_gmrf, obs_lik; x0 = x_star) + @test mean(result_warm) ≈ x_star atol = 1.0e-4 + end + + @testset "Adaptive stepsize - extreme Poisson" begin + # Test case where adaptive stepsize is needed + n = 3 + Q_prior = spdiagm(0 => 0.01 * ones(n)) # Weak prior + prior_gmrf = ChordalGMRF(zeros(n), Q_prior) + + obs_model = ExponentialFamily(Poisson) + y = PoissonObservations([100, 500, 50]) + obs_lik = obs_model(y) + + result = gaussian_approximation(prior_gmrf, obs_lik) + + @test result isa ChordalGMRF + μ_result = mean(result) + @test all(isfinite.(μ_result)) + + # Each component should be close to log of its count + for i in 1:n + @test abs(μ_result[i] - log(y.counts[i])) < 1.5 + end + end + + @testset "Non-convergence path" begin + # Force non-convergence with max_iter = 1 + n = 5 + Q_prior = spdiagm(0 => ones(n)) + prior_gmrf = ChordalGMRF(zeros(n), Q_prior) + + obs_model = ExponentialFamily(Poisson) + y = PoissonObservations([10, 15, 8, 12, 20]) + obs_lik = obs_model(y) + + result = gaussian_approximation(prior_gmrf, obs_lik; max_iter = 1) + + # Should still return a ChordalGMRF + @test result isa ChordalGMRF + @test all(isfinite.(mean(result))) + end + + @testset "Verbose output" begin + n = 4 + Q_prior = spdiagm(0 => ones(n)) + prior_gmrf = ChordalGMRF(zeros(n), Q_prior) + + obs_model = ExponentialFamily(Poisson) + y = PoissonObservations([2, 3, 1, 4]) + obs_lik = obs_model(y) + + # Should not error with verbose=true + result = gaussian_approximation(prior_gmrf, obs_lik; verbose = true) + @test result isa ChordalGMRF + end + +end