@@ -63,28 +63,28 @@ function _subexpression_and_linearity(
6363 linearity
6464end
6565
66- struct _FunctionStorage
66+ struct _FunctionStorage{R <: SparseMatrixColorings.AbstractColoringResult }
6767 expr:: _SubexpressionStorage
6868 grad_sparsity:: Vector{Int}
6969 # Nonzero pattern of Hessian matrix
7070 hess_I:: Vector{Int}
7171 hess_J:: Vector{Int}
72- rinfo:: Union{ColoringResult, Nothing} # coloring info for hessians
72+ rinfo:: Union{Nothing,ColoringResult{R}}
7373 seed_matrix:: Matrix{Float64}
7474 # subexpressions which this function depends on, ordered for forward pass.
7575 dependent_subexpressions:: Vector{Int}
7676
77- function _FunctionStorage (
77+ function _FunctionStorage {R} (
7878 expr:: _SubexpressionStorage ,
7979 num_variables,
8080 coloring_storage:: IndexedSet ,
81- want_hess :: Bool ,
81+ coloring_algorithm :: Union{Nothing,SparseMatrixColorings.GreedyColoringAlgorithm} ,
8282 subexpressions:: Vector{_SubexpressionStorage} ,
8383 dependent_subexpressions,
8484 subexpression_edgelist,
8585 subexpression_variables,
8686 linearity:: Vector{Linearity} ,
87- )
87+ ) where {R}
8888 empty! (coloring_storage)
8989 _compute_gradient_sparsity! (coloring_storage, expr. nodes)
9090 for k in dependent_subexpressions
@@ -95,7 +95,7 @@ struct _FunctionStorage
9595 end
9696 grad_sparsity = sort! (collect (coloring_storage))
9797 empty! (coloring_storage)
98- if want_hess
98+ if ! isnothing (coloring_algorithm)
9999 edgelist = _compute_hessian_sparsity (
100100 expr. nodes,
101101 expr. adj,
@@ -107,10 +107,11 @@ struct _FunctionStorage
107107 hess_I, hess_J, rinfo = _hessian_color_preprocess (
108108 edgelist,
109109 num_variables,
110+ coloring_algorithm,
110111 coloring_storage,
111112 )
112113 seed_matrix = _seed_matrix (rinfo)
113- return new (
114+ return new {R} (
114115 expr,
115116 grad_sparsity,
116117 hess_I,
@@ -120,7 +121,7 @@ struct _FunctionStorage
120121 dependent_subexpressions,
121122 )
122123 else
123- return new (
124+ return new {R} (
124125 expr,
125126 grad_sparsity,
126127 Int[],
137138 NLPEvaluator(
138139 model::Nonlinear.Model,
139140 ordered_variables::Vector{MOI.VariableIndex},
141+ coloring_algorithm::SparseMatrixColorings.AbstractColoringAlgorithm = SparseMatrixColorings.GreedyColoringAlgorithm(; decompression=:substitution),
140142 )
141143
142144Return an `NLPEvaluator` object that implements the `MOI.AbstractNLPEvaluator`
@@ -145,12 +147,13 @@ interface.
145147!!! warning
146148 Before using, you must initialize the evaluator using `MOI.initialize`.
147149"""
148- mutable struct NLPEvaluator <: MOI.AbstractNLPEvaluator
150+ mutable struct NLPEvaluator{R,C <: SparseMatrixColorings.GreedyColoringAlgorithm } <: MOI.AbstractNLPEvaluator
149151 data:: Nonlinear.Model
150152 ordered_variables:: Vector{MOI.VariableIndex}
153+ coloring_algorithm:: C
151154
152- objective:: Union{Nothing,_FunctionStorage}
153- constraints:: Vector{_FunctionStorage}
155+ objective:: Union{Nothing,_FunctionStorage{R} }
156+ constraints:: Vector{_FunctionStorage{R} }
154157 subexpressions:: Vector{_SubexpressionStorage}
155158 subexpression_order:: Vector{Int}
156159 # Storage for the subexpressions in reverse-mode automatic differentiation.
@@ -183,8 +186,17 @@ mutable struct NLPEvaluator <: MOI.AbstractNLPEvaluator
183186
184187 function NLPEvaluator (
185188 data:: Nonlinear.Model ,
186- ordered_variables:: Vector{MOI.VariableIndex} ,
189+ ordered_variables:: Vector{MOI.VariableIndex} ;
190+ coloring_algorithm:: SparseMatrixColorings.GreedyColoringAlgorithm = SparseMatrixColorings. GreedyColoringAlgorithm (; decompression= :substitution ),
187191 )
188- return new (data, ordered_variables)
192+ problem = SparseMatrixColorings. ColoringProblem (; structure= :symmetric , partition= :column )
193+ C = typeof (coloring_algorithm)
194+ R = Base. promote_op (
195+ SparseMatrixColorings. coloring,
196+ SparseArrays. SparseMatrixCSC{Bool,Int},
197+ typeof (problem),
198+ C,
199+ )
200+ return new {R,C} (data, ordered_variables, coloring_algorithm)
189201 end
190202end
0 commit comments