Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ LibGEOS = "a90b1aa1-3769-5649-ba7e-abc5a9d163eb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SelectedInversion = "043bf095-3f01-458a-9f1c-8cf4448fe908"
Expand All @@ -30,6 +31,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4"

[weakdeps]
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Expand All @@ -40,7 +42,6 @@ SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
GaussianMarkovRandomFieldsAutoDiff = ["ForwardDiff", "Zygote"]
Expand All @@ -57,7 +58,7 @@ GaussianMarkovRandomFieldsSparseJacobian = ["Symbolics", "SparseDiffTools"]
AMD = "0.5"
Aqua = "0.8"
ChainRulesCore = "1"
CliqueTrees = "1.18.0"
CliqueTrees = "1.19.1"
DataStructures = "0.14 - 0.19"
DifferentiationInterface = "0.6, 0.7"
Distributions = "0.25"
Expand All @@ -75,6 +76,7 @@ LinearAlgebra = "<0.0.1, 1"
LinearMaps = "3.11"
LinearSolve = "2, 3"
Makie = "0.19 - 0.22"
Mooncake = "0.5.25"
NearestNeighbors = "0.4"
Pardiso = "1"
Random = "<0.0.1, 1"
Expand All @@ -92,7 +94,7 @@ StatsModels = "0.7"
Symbolics = "5"
Tensors = "1.16"
Test = "<0.0.1, 1"
Zygote = "0.6"
Zygote = "0.6, 0.7"
julia = "1.10"

[extras]
Expand Down
2 changes: 2 additions & 0 deletions benchmarks/Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CliqueTrees = "60701a23-6482-424a-84db-faee86b9b1f8"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
GaussianMarkovRandomFields = "d5f06795-35bb-4323-9f0b-405ef76cfc5b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
MatrixDepot = "b51810bb-c9f3-55da-ae3c-350fc1fbce05"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
171 changes: 144 additions & 27 deletions benchmarks/autodiff_comparison.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ using LinearSolve
using Printf
using Random

using Zygote, Enzyme, FiniteDiff
using Zygote, Enzyme, FiniteDiff, Mooncake

using CliqueTrees.Multifrontal: symbolic, chordal

println("="^80)
println("AUTODIFF BACKEND COMPARISON: HIGH-DIMENSIONAL HYPERPARAMETER SPACE")
Expand All @@ -37,7 +39,7 @@ println(" - $n mean parameters (one per time point)")
println(" - 1 precision parameter (τ)")

# Workflow: θ → GMRF → gaussian_approximation → logpdf
function benchmark_workflow(θ::Vector{Float64}, y::Vector{Int}, x_eval::Vector{Float64})
function benchmark_workflow(θ::Vector{Float64}, y::PoissonObservations, x_eval::Vector{Float64})
# Extract hyperparameters
μ = θ[1:n] # Mean field (100 params)
log_τ = θ[n + 1] # Log precision
Expand All @@ -61,6 +63,31 @@ function benchmark_workflow(θ::Vector{Float64}, y::Vector{Int}, x_eval::Vector{
return logpdf(posterior, x_eval)
end

# ChordalGMRF workflow (only supports Mooncake)
function benchmark_workflow_chordal(θ::Vector{Float64}, y::PoissonObservations, x_eval::Vector{Float64})
# Extract hyperparameters
μ = θ[1:n] # Mean field (100 params)
log_τ = θ[n + 1] # Log precision
τ = exp(log_τ)

# Build precision matrix using RW1 latent model
rw1_model = RW1Model(n)
Q = sparse(precision_matrix(rw1_model; τ = τ))

# Prior ChordalGMRF with custom mean
prior = ChordalGMRF(μ, Q)

# Poisson observation likelihood
obs_model = ExponentialFamily(Poisson)
obs_lik = obs_model(y)

# Gaussian approximation
posterior = gaussian_approximation(prior, obs_lik)

# Evaluate log-density
return logpdf(posterior, x_eval)
end

# Generate test data
println("\nGenerating test data...")
Random.seed!(123)
Expand All @@ -73,19 +100,24 @@ Random.seed!(123)
# Simulate observations (Poisson counts from smooth latent field)
x_true = μ_true .+ cumsum(randn(n)) .* sqrt(1 / τ_true) .* 0.5
x_true .-= mean(x_true) # Center
y_obs = rand.(Poisson.(exp.(x_true .+ 0.5)))
y_counts = rand.(Poisson.(exp.(x_true .+ 0.5)))
y_obs = PoissonObservations(y_counts)
x_eval = randn(n) .+ 0.3

# Initial parameter values (perturbed from truth)
θ_init = θ_true .+ randn(n_params) .* 0.1

println(" ✓ Generated $(length(y_obs)) Poisson observations")
println(" ✓ Generated $(length(y_counts)) Poisson observations")
println(" ✓ Initial parameters: $(n_params)-dimensional vector")

# Verify workflow works
println("\nVerifying workflow...")
# Verify workflows work
println("\nVerifying workflows...")
f_val = benchmark_workflow(θ_init, y_obs, x_eval)
println(" ✓ Function value: $(@sprintf("%.4f", f_val))")
println(" ✓ GMRF function value: $(@sprintf("%.4f", f_val))")

f_val_chordal = benchmark_workflow_chordal(θ_init, y_obs, x_eval)
println(" ✓ ChordalGMRF function value: $(@sprintf("%.4f", f_val_chordal))")
println(" ✓ Difference: $(@sprintf("%.2e", abs(f_val - f_val_chordal)))")

# Define backends
backends = [
Expand All @@ -95,7 +127,7 @@ backends = [
]

println("\n" * "="^80)
println("BENCHMARKING GRADIENT COMPUTATION")
println("BENCHMARKING GRADIENT COMPUTATION (via DifferentiationInterface, prepared)")
println("="^80)

results = Dict()
Expand All @@ -105,22 +137,22 @@ for (name, backend) in backends
println("-"^40)

try
# Warmup
print(" Warming up... ")
grad = DifferentiationInterface.gradient(
θ -> benchmark_workflow(θ, y_obs, x_eval),
backend,
θ_init
)
# Define loss function
loss = θ -> benchmark_workflow(θ, y_obs, x_eval)

# Prepare gradient (includes warmup/compilation)
print(" Preparing... ")
prep = DifferentiationInterface.prepare_gradient(loss, backend, θ_init)
println("✓")

# Compute gradient once
print(" Computing gradient... ")
grad = DifferentiationInterface.gradient(loss, prep, backend, θ_init)
println("✓")

# Benchmark
# Benchmark with prepared gradient
print(" Benchmarking... ")
bench = @benchmark DifferentiationInterface.gradient(
θ -> benchmark_workflow(θ, y_obs, x_eval),
$backend,
$θ_init
) samples = 10 seconds = 30
bench = @benchmark DifferentiationInterface.gradient($loss, $prep, $backend, $θ_init) samples = 10 seconds = 30

results[name] = (
gradient = grad,
Expand All @@ -135,23 +167,72 @@ for (name, backend) in backends
println(" Memory: $(@sprintf("%.2f", bench.memory / 1.0e6)) MB")

catch e
println(" ✗ Failed: $e")
println(" ✗ Failed: $(typeof(e).name.name)")
if e isa ErrorException
println(" $(first(split(e.msg, '\n')))")
end
results[name] = nothing
end
end

# ChordalGMRF benchmark (Mooncake only)
println("\n" * "="^80)
println("BENCHMARKING ChordalGMRF (Mooncake only)")
println("="^80)

println("\nChordalGMRF + Mooncake:")
println("-"^40)

try
# Define loss function
loss_chordal = θ -> benchmark_workflow_chordal(θ, y_obs, x_eval)

# Prepare gradient (includes warmup/compilation)
print(" Preparing... ")
prep_chordal = DifferentiationInterface.prepare_gradient(loss_chordal, AutoMooncake(; config=nothing), θ_init)
println("✓")

# Compute gradient once
print(" Computing gradient... ")
grad_chordal = DifferentiationInterface.gradient(loss_chordal, prep_chordal, AutoMooncake(; config=nothing), θ_init)
println("✓")

# Benchmark with prepared gradient
print(" Benchmarking... ")
bench_chordal = @benchmark DifferentiationInterface.gradient($loss_chordal, $prep_chordal, AutoMooncake(; config=nothing), $θ_init) samples = 10 seconds = 30

results["ChordalGMRF+Mooncake"] = (
gradient = grad_chordal,
time = minimum(bench_chordal.times) / 1.0e6,
bench = bench_chordal,
)

println("✓")
println(" Time (min): $(@sprintf("%.2f", results["ChordalGMRF+Mooncake"].time)) ms")
println(" Time (median): $(@sprintf("%.2f", median(bench_chordal.times) / 1.0e6)) ms")
println(" Allocations: $(bench_chordal.allocs)")
println(" Memory: $(@sprintf("%.2f", bench_chordal.memory / 1.0e6)) MB")

catch e
println(" ✗ Failed: $(typeof(e).name.name)")
if e isa ErrorException
println(" $(first(split(e.msg, '\n')))")
end
results["ChordalGMRF+Mooncake"] = nothing
end

# Summary comparison
println("\n" * "="^80)
println("SUMMARY")
println("="^80)

if all(v !== nothing for v in values(results))
if results["FiniteDiff"] !== nothing
# Verify gradients match
println("\nGradient verification (comparing to FiniteDiff):")
fd_grad = results["FiniteDiff"].gradient

for name in ["Zygote", "Enzyme"]
if results[name] !== nothing
for name in ["Zygote", "Enzyme", "ChordalGMRF+Mooncake"]
if get(results, name, nothing) !== nothing
grad = results[name].gradient
abs_error = abs.(grad - fd_grad)
max_error = maximum(abs_error)
Expand All @@ -164,7 +245,7 @@ if all(v !== nothing for v in values(results))
end

# Performance comparison table
println("\nPerformance comparison:")
println("\nPerformance comparison (GMRF backends):")
println(" " * "─"^76)
println(@sprintf(" %-20s %12s %12s %12s %12s", "Backend", "Time (ms)", "Speedup", "Allocs", "Memory (MB)"))
println(" " * "─"^76)
Expand All @@ -191,10 +272,46 @@ if all(v !== nothing for v in values(results))
enzyme_vs_zygote = results["Zygote"].time / results["Enzyme"].time
winner = enzyme_vs_zygote > 1 ? "Enzyme" : "Zygote"
ratio = max(enzyme_vs_zygote, 1.0 / enzyme_vs_zygote)
println("\n 🏆 $winner is fastest: $(@sprintf("%.1f", ratio))× faster than the other")
println("\n GMRF winner: $winner ($(@sprintf("%.1f", ratio))× faster)")
end
end

# ChordalGMRF vs GMRF comparison (Zygote only)
if get(results, "ChordalGMRF+Mooncake", nothing) !== nothing && get(results, "Zygote", nothing) !== nothing
println("\n" * "="^80)
println("GMRF vs ChordalGMRF COMPARISON (Zygote vs Mooncake)")
println("="^80)

r_gmrf = results["Zygote"]
r_chordal = results["ChordalGMRF+Mooncake"]

println("\n " * "─"^76)
println(@sprintf(" %-20s %12s %12s %12s %12s", "Implementation", "Time (ms)", "Speedup", "Allocs", "Memory (MB)"))
println(" " * "─"^76)

println(
@sprintf(
" %-20s %12.2f %12s %12d %12.2f",
"GMRF", r_gmrf.time, "1.0×", r_gmrf.bench.allocs, r_gmrf.bench.memory / 1.0e6
)
)

chordal_speedup = r_gmrf.time / r_chordal.time
println(
@sprintf(
" %-20s %12.2f %12s %12d %12.2f",
"ChordalGMRF", r_chordal.time, @sprintf("%.1f×", chordal_speedup),
r_chordal.bench.allocs, r_chordal.bench.memory / 1.0e6
)
)

println(" " * "─"^76)

winner = chordal_speedup > 1 ? "ChordalGMRF" : "GMRF"
ratio = max(chordal_speedup, 1.0 / chordal_speedup)
println("\n Winner: $winner ($(@sprintf("%.1f", ratio))× faster)")
end

println("\n" * "="^80)
println("BENCHMARK COMPLETE")
println("="^80)
Loading
Loading