Skip to content

Commit b5da94d

Browse files
committed
Fix type stability
1 parent 4b71e3b commit b5da94d

5 files changed

Lines changed: 38 additions & 27 deletions

File tree

src/coloring_compat.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@
1212
1313
Wrapper around TreeSetColoringResult that also stores local_indices mapping.
1414
"""
15-
struct ColoringResult{M<:AbstractMatrix,T<:Integer,G<:SparseMatrixColorings.AdjacencyGraph{T},GT<:SparseMatrixColorings.AbstractGroups{T},R}
16-
result::SparseMatrixColorings.TreeSetColoringResult{M,T,G,GT,R}
15+
struct ColoringResult{R<:SparseMatrixColorings.AbstractColoringResult}
16+
result::R
1717
local_indices::Vector{Int} # map from local to global indices
1818
end
1919

2020
"""
2121
_hessian_color_preprocess(
2222
edgelist,
2323
num_total_var,
24+
algo::SparseMatrixColorings.GreedyColoringAlgorithm,
2425
seen_idx = IndexedSet(0),
2526
)
2627
@@ -34,6 +35,7 @@ SparseMatrixColorings.
3435
function _hessian_color_preprocess(
3536
edgelist,
3637
num_total_var,
38+
algo::SparseMatrixColorings.GreedyColoringAlgorithm,
3739
seen_idx = IndexedSet(0),
3840
)
3941
resize!(seen_idx, num_total_var)
@@ -56,7 +58,6 @@ function _hessian_color_preprocess(
5658
# Note: This case should rarely occur in practice
5759
S = SparseArrays.spdiagm(0 => [true])
5860
problem = SparseMatrixColorings.ColoringProblem(; structure=:symmetric, partition=:column)
59-
algo = SparseMatrixColorings.GreedyColoringAlgorithm(; decompression=:substitution)
6061
tree_result = SparseMatrixColorings.coloring(S, problem, algo)
6162
result = ColoringResult(tree_result, Int[])
6263
return I, J, result
@@ -68,7 +69,6 @@ function _hessian_color_preprocess(
6869
n = length(local_indices)
6970
S = SparseArrays.spdiagm(0 => trues(n))
7071
problem = SparseMatrixColorings.ColoringProblem(; structure=:symmetric, partition=:column)
71-
algo = SparseMatrixColorings.GreedyColoringAlgorithm(; decompression=:substitution)
7272
tree_result = SparseMatrixColorings.coloring(S, problem, algo)
7373
result = ColoringResult(tree_result, local_indices)
7474
# I and J are already empty, which is correct for no off-diagonal elements
@@ -96,7 +96,6 @@ function _hessian_color_preprocess(
9696

9797
# Perform coloring using SparseMatrixColorings
9898
problem = SparseMatrixColorings.ColoringProblem(; structure=:symmetric, partition=:column)
99-
algo = SparseMatrixColorings.GreedyColoringAlgorithm(; decompression=:substitution)
10099
tree_result = SparseMatrixColorings.coloring(S, problem, algo)
101100

102101
# Reconstruct I and J from the tree structure (matching original _indirect_recover_structure)

src/forward_over_reverse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function _eval_hessian(
6565
)
6666
end
6767
# TODO(odow): consider reverting to a view.
68-
output_slice = _UnsafeVectorView(nzcount, length(ex.hess_I), pointer(H))
68+
output_slice = _UnsafeVectorView{Float64}(nzcount, length(ex.hess_I), pointer(H))::_UnsafeVectorView{Float64}
6969
_recover_from_matmat!(
7070
output_slice,
7171
ex.seed_matrix,

src/mathoptinterface_api.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function MOI.features_available(d::NLPEvaluator)
1919
return [:Grad, :Jac, :JacVec, :Hess, :HessVec]
2020
end
2121

22-
function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
22+
function MOI.initialize(d::NLPEvaluator{R}, requested_features::Vector{Symbol}) where {R}
2323
# Check that we support the features requested by the user.
2424
available_features = MOI.features_available(d)
2525
for feature in requested_features
@@ -38,7 +38,7 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
3838
d.objective = nothing
3939
d.user_output_buffer = zeros(largest_user_input_dimension)
4040
d.jac_storage = zeros(max(N, largest_user_input_dimension))
41-
d.constraints = _FunctionStorage[]
41+
d.constraints = _FunctionStorage{R}[]
4242
d.last_x = fill(NaN, N)
4343
d.want_hess = :Hess in requested_features
4444
want_hess_storage = (:HessVec in requested_features) || d.want_hess
@@ -111,11 +111,11 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
111111
shared_partials_storage_ϵ,
112112
d,
113113
)
114-
objective = _FunctionStorage(
114+
objective = _FunctionStorage{R}(
115115
subexpr,
116116
N,
117117
coloring_storage,
118-
d.want_hess,
118+
d.want_hess ? d.coloring_algorithm : nothing,
119119
d.subexpressions,
120120
individual_order[1],
121121
subexpression_edgelist,
@@ -137,11 +137,11 @@ function MOI.initialize(d::NLPEvaluator, requested_features::Vector{Symbol})
137137
)
138138
push!(
139139
d.constraints,
140-
_FunctionStorage(
140+
_FunctionStorage{R}(
141141
subexpr,
142142
N,
143143
coloring_storage,
144-
d.want_hess,
144+
d.want_hess ? d.coloring_algorithm : nothing,
145145
d.subexpressions,
146146
individual_order[idx],
147147
subexpression_edgelist,

src/types.jl

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,28 +63,28 @@ function _subexpression_and_linearity(
6363
linearity
6464
end
6565

66-
struct _FunctionStorage
66+
struct _FunctionStorage{R<:SparseMatrixColorings.AbstractColoringResult}
6767
expr::_SubexpressionStorage
6868
grad_sparsity::Vector{Int}
6969
# Nonzero pattern of Hessian matrix
7070
hess_I::Vector{Int}
7171
hess_J::Vector{Int}
72-
rinfo::Union{ColoringResult,Nothing} # coloring info for hessians
72+
rinfo::Union{Nothing,ColoringResult{R}}
7373
seed_matrix::Matrix{Float64}
7474
# subexpressions which this function depends on, ordered for forward pass.
7575
dependent_subexpressions::Vector{Int}
7676

77-
function _FunctionStorage(
77+
function _FunctionStorage{R}(
7878
expr::_SubexpressionStorage,
7979
num_variables,
8080
coloring_storage::IndexedSet,
81-
want_hess::Bool,
81+
coloring_algorithm::Union{Nothing,SparseMatrixColorings.GreedyColoringAlgorithm},
8282
subexpressions::Vector{_SubexpressionStorage},
8383
dependent_subexpressions,
8484
subexpression_edgelist,
8585
subexpression_variables,
8686
linearity::Vector{Linearity},
87-
)
87+
) where {R}
8888
empty!(coloring_storage)
8989
_compute_gradient_sparsity!(coloring_storage, expr.nodes)
9090
for k in dependent_subexpressions
@@ -95,7 +95,7 @@ struct _FunctionStorage
9595
end
9696
grad_sparsity = sort!(collect(coloring_storage))
9797
empty!(coloring_storage)
98-
if want_hess
98+
if !isnothing(coloring_algorithm)
9999
edgelist = _compute_hessian_sparsity(
100100
expr.nodes,
101101
expr.adj,
@@ -107,10 +107,11 @@ struct _FunctionStorage
107107
hess_I, hess_J, rinfo = _hessian_color_preprocess(
108108
edgelist,
109109
num_variables,
110+
coloring_algorithm,
110111
coloring_storage,
111112
)
112113
seed_matrix = _seed_matrix(rinfo)
113-
return new(
114+
return new{R}(
114115
expr,
115116
grad_sparsity,
116117
hess_I,
@@ -120,7 +121,7 @@ struct _FunctionStorage
120121
dependent_subexpressions,
121122
)
122123
else
123-
return new(
124+
return new{R}(
124125
expr,
125126
grad_sparsity,
126127
Int[],
@@ -137,6 +138,7 @@ end
137138
NLPEvaluator(
138139
model::Nonlinear.Model,
139140
ordered_variables::Vector{MOI.VariableIndex},
141+
coloring_algorithm::SparseMatrixColorings.AbstractColoringAlgorithm = SparseMatrixColorings.GreedyColoringAlgorithm(; decompression=:substitution),
140142
)
141143
142144
Return an `NLPEvaluator` object that implements the `MOI.AbstractNLPEvaluator`
@@ -145,12 +147,13 @@ interface.
145147
!!! warning
146148
Before using, you must initialize the evaluator using `MOI.initialize`.
147149
"""
148-
mutable struct NLPEvaluator <: MOI.AbstractNLPEvaluator
150+
mutable struct NLPEvaluator{R,C<:SparseMatrixColorings.GreedyColoringAlgorithm} <: MOI.AbstractNLPEvaluator
149151
data::Nonlinear.Model
150152
ordered_variables::Vector{MOI.VariableIndex}
153+
coloring_algorithm::C
151154

152-
objective::Union{Nothing,_FunctionStorage}
153-
constraints::Vector{_FunctionStorage}
155+
objective::Union{Nothing,_FunctionStorage{R}}
156+
constraints::Vector{_FunctionStorage{R}}
154157
subexpressions::Vector{_SubexpressionStorage}
155158
subexpression_order::Vector{Int}
156159
# Storage for the subexpressions in reverse-mode automatic differentiation.
@@ -183,8 +186,17 @@ mutable struct NLPEvaluator <: MOI.AbstractNLPEvaluator
183186

184187
function NLPEvaluator(
185188
data::Nonlinear.Model,
186-
ordered_variables::Vector{MOI.VariableIndex},
189+
ordered_variables::Vector{MOI.VariableIndex};
190+
coloring_algorithm::SparseMatrixColorings.GreedyColoringAlgorithm = SparseMatrixColorings.GreedyColoringAlgorithm(; decompression=:substitution),
187191
)
188-
return new(data, ordered_variables)
192+
problem = SparseMatrixColorings.ColoringProblem(; structure=:symmetric, partition=:column)
193+
C = typeof(coloring_algorithm)
194+
R = Base.promote_op(
195+
SparseMatrixColorings.coloring,
196+
SparseArrays.SparseMatrixCSC{Bool,Int},
197+
typeof(problem),
198+
C,
199+
)
200+
return new{R,C}(data, ordered_variables, coloring_algorithm)
189201
end
190202
end

test/ReverseAD.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import MathOptInterface as MOI
1414
const Nonlinear = MOI.Nonlinear
1515

1616
import ArrayDiff
17-
# Coloring submodule removed - using SparseMatrixColorings instead
1817

1918
function runtests()
2019
for name in names(@__MODULE__; all = true)
@@ -481,7 +480,8 @@ end
481480

482481
function test_coloring_end_to_end_hessian_coloring_and_recovery()
483482
# Test the new coloring API through the compatibility layer
484-
I, J, rinfo = ArrayDiff._hessian_color_preprocess(Set([(1, 2)]), 2, ArrayDiff.IndexedSet(0))
483+
coloring_algorithm = ArrayDiff.SparseMatrixColorings.GreedyColoringAlgorithm(; decompression=:substitution)
484+
I, J, rinfo = ArrayDiff._hessian_color_preprocess(Set([(1, 2)]), 2, coloring_algorithm, ArrayDiff.IndexedSet(0))
485485
R = ArrayDiff._seed_matrix(rinfo)
486486
ArrayDiff._prepare_seed_matrix!(R, rinfo)
487487
@test I == [1, 2, 2]

0 commit comments

Comments
 (0)