Skip to content
74 changes: 64 additions & 10 deletions pyomo/core/expr/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@


import inspect
import itertools
import logging
import sys
from copy import deepcopy
from collections import deque

logger = logging.getLogger('pyomo.core')

from pyomo.common.collections import ComponentSet
from pyomo.common.deprecation import deprecated, deprecation_warning
from pyomo.common.errors import DeveloperError, TemplateExpressionError
from pyomo.common.numeric_types import (
Expand Down Expand Up @@ -1395,32 +1397,75 @@ def identify_components(expr, component_types):


class IdentifyVariableVisitor(StreamBasedExpressionVisitor):
def __init__(self, include_fixed=False, named_expression_cache=None):
def __init__(
self, include_fixed=False, named_expression_cache=None, var_cache=None
):
"""Visitor that collects all unique variables participating in an
expression

Args:
include_fixed (bool): Whether to include fixed variables
named_expression_cache (optional, dict): Dict mapping ids of named
expressions to a tuple of the list of all variables and the
set of all variable ids contained in the named expression.
:meth:`walk_expression` returns a generator of the unique
variables found in the expression. If `var_cache` was
specified, then only the *new* variables found in `expr` are
returned (the full list of all variables is maintained in the
`var_cache` dict).

Parameters
----------
include_fixed : bool
If True, fixed variables will be reported

named_expression_cache : dict
Cache of named expressions that have been visited by this
walker. The value includes the variables within the named
expression as well as information for detecting when the
named expression has changed (for cache invalidation).

var_cache : ComponentSet
ComponentSet for recording all variables that have been
"seen" by this walker. If provided, this ComponentSet is
preserved between calls to :meth:`walk_expression` (so
repeated variables are not returned more than once).

"""
super().__init__()
self._include_fixed = include_fixed
# cache of visited named expressions. This dict maps
# {eid: (seen, exprs)}.
# - eid is the id() of the named expression
# - seen is the processed result for the named expression
# (including any nested named expressions)
# - exprs is used for automatically invalidating the cache (see below).
self._cache = named_expression_cache

# Stack of named expressions. This holds the tuple
# (eid, _seen, _exprs)
# where eid is the id() of the subexpression we are currently
# processing, and _seen and _exprs are from the parent context.
self._expr_stack = []
# The following attributes will be added by initializeWalker:
# self._seen: dict(eid: obj)

# cache of "seen" variables: dict(eid: VarData)
#
# Pyomo encourages the use of ComponentSet to store (ordered)
# sets of Pyomo components (and in particular, Pyomo Vars).
# However, to reduce overhead (this is about a 10-12%
# improvement), we will operate directly on the underlying dict
# data store. This is slightly evil (and definitely violates
# encapsulation), but we accept the risk as identify_variables()
# is a potentially expensive operation.
if isinstance(var_cache, ComponentSet):
var_cache = var_cache._data
self._seen = var_cache

# The following attribute will be added by initializeWalker:
# self._exprs: list of (e, e.expr) for any (nested) named expressions

def initializeWalker(self, expr):
assert not self._expr_stack
self._seen = {}
if self._seen is None:
self._seen = {}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use a component set?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly for historical reasons? I think that identify_variables actually predated the modern implementation of ComponentSet, and I was trying to do a minimal change here.

We should consider reworking the implementation to be based on ComponentSet. There might be a slight overhead increase (I am guessing the cost of an extra function call per leaf node) - we should benchmark that. (Although I am not expecting it to show up, as the core things like model generation, compilation, and writers do not use it).

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A quick test with PMedian (test8) shows that using a ComponentSet within identify_variables is 10-12% slower. One possibility is that we could advertise that the walker takes a ComponentSet, but internally we operate directly on the underlying _data dictionary?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the idea of using a ComponentSet but operating on the underlying _data dictionary. That is provides a cleaner interface without the performance hit.

self._expr_stack.append(None)
else:
self._expr_stack.append(len(self._seen))
self._exprs = None
if not self.beforeChild(None, expr, 0)[0]:
return False, self.finalizeResult(None)
Expand Down Expand Up @@ -1452,8 +1497,17 @@ def exitNode(self, node, data):
self._merge_obj_lists(_seen, _exprs)

def finalizeResult(self, result):
seen = self._seen
initial_num_seen = self._expr_stack.pop()
assert not self._expr_stack
return self._seen.values()
if initial_num_seen is None:
self._seen = None
return seen.values()
else:
# Only return the *new* variables found on this walk. This
# relies on dict iteration being in insertion order (which,
# since python 3.7, it is)
return itertools.islice(seen.values(), initial_num_seen, len(seen))

def _merge_obj_lists(self, _seen, _exprs):
self._seen.update(_seen)
Expand Down
21 changes: 21 additions & 0 deletions pyomo/core/tests/unit/test_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
SimpleExpressionVisitor,
StreamBasedExpressionVisitor,
ExpressionReplacementVisitor,
IdentifyVariableVisitor,
evaluate_expression,
expression_to_string,
replace_expressions,
Expand Down Expand Up @@ -252,6 +253,26 @@ def test_identify_vars_linear_expression(self):
expr = quicksum([m.x, m.x], linear=True)
self.assertEqual(list(identify_variables(expr, include_fixed=False)), [m.x])

def test_identify_vars_var_cache(self):
m = ConcreteModel()
m.x = Var()
m.y = Var()
m.z = Var()

e1 = m.x + m.y
e2 = m.y + m.z

v = IdentifyVariableVisitor()
self.assertEqual(list(v.walk_expression(e1)), [m.x, m.y])
self.assertEqual(list(v.walk_expression(e2)), [m.y, m.z])

seen = {}
v = IdentifyVariableVisitor(var_cache=seen)
self.assertEqual(list(v.walk_expression(e2)), [m.y, m.z])
self.assertEqual(list(seen.values()), [m.y, m.z])
self.assertEqual(list(v.walk_expression(e1)), [m.x])
self.assertEqual(list(seen.values()), [m.y, m.z, m.x])


class TestIdentifyParams(unittest.TestCase):
def test_identify_params_numeric(self):
Expand Down
7 changes: 4 additions & 3 deletions pyomo/devel/initialization/bounds/bound_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
# software. This software is distributed under the 3-clause BSD License.
# ____________________________________________________________________________________

from pyomo.core.base.block import BlockData
from pyomo.contrib.fbbt.fbbt import fbbt
from pyomo.devel.initialization.utils import get_vars
import logging

from pyomo.contrib.fbbt.fbbt import fbbt
from pyomo.core.base.block import BlockData
from pyomo.util.vars_from_expressions import get_vars

logger = logging.getLogger(__name__)


Expand Down
17 changes: 9 additions & 8 deletions pyomo/devel/initialization/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,22 @@
# software. This software is distributed under the 3-clause BSD License.
# ____________________________________________________________________________________

import logging

from enum import Enum
from typing import Optional
from pyomo.core.base.block import BlockData
from enum import Enum
from pyomo.devel.initialization.utils import get_vars
from pyomo.common.collections import ComponentMap
from pyomo.devel.initialization.pwl_init import (
_initialize_with_piecewise_linear_approximation,
)
from pyomo.devel.initialization.lp_approx_init import _initialize_with_LP_approximation
from pyomo.contrib.solver.common.base import SolverBase
from pyomo.devel.initialization.global_init import _initialize_with_global_solver
from pyomo.contrib.solver.common.factory import SolverFactory
from pyomo.contrib.solver.common.results import Results
import logging
from pyomo.contrib.solver.common.results import SolutionStatus
from pyomo.devel.initialization.pwl_init import (
_initialize_with_piecewise_linear_approximation,
)
from pyomo.devel.initialization.global_init import _initialize_with_global_solver
from pyomo.devel.initialization.lp_approx_init import _initialize_with_LP_approximation
from pyomo.util.vars_from_expressions import get_vars

logger = logging.getLogger(__name__)

Expand Down
6 changes: 1 addition & 5 deletions pyomo/devel/initialization/lp_approx_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,7 @@
bound_all_nonlinear_variables,
)
from pyomo.devel.initialization.pwl_init import _minimize_infeasibility
from pyomo.devel.initialization.utils import (
fix_vars_with_equal_bounds,
get_vars,
shallow_clone,
)
from pyomo.devel.initialization.utils import fix_vars_with_equal_bounds, shallow_clone
from pyomo.repn.linear import LinearRepn, LinearRepnVisitor
from pyomo.repn.util import ExitNodeDispatcher
from pyomo.common.dependencies.scipy import stats
Expand Down
10 changes: 3 additions & 7 deletions pyomo/devel/initialization/pwl_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from pyomo.contrib.piecewise.piecewise_linear_function import PiecewiseLinearFunction
from pyomo.contrib.solver.common.base import SolverBase
from pyomo.contrib.solver.common.results import SolutionStatus
from pyomo.contrib.solver.common.results import SolutionStatus, Results
from pyomo.core.base.block import BlockData
from pyomo.core.base.constraint import ConstraintData
from pyomo.core.base.expression import ExpressionData, ScalarExpression
Expand Down Expand Up @@ -52,13 +52,9 @@
from pyomo.devel.initialization.bounds.bound_variables import (
bound_all_nonlinear_variables,
)
from pyomo.devel.initialization.utils import (
fix_vars_with_equal_bounds,
get_vars,
shallow_clone,
)
from pyomo.devel.initialization.utils import fix_vars_with_equal_bounds, shallow_clone
from pyomo.repn.util import ExitNodeDispatcher
from pyomo.contrib.solver.common.results import Results
from pyomo.util.vars_from_expressions import get_vars

logger = logging.getLogger(__name__)

Expand Down
10 changes: 1 addition & 9 deletions pyomo/devel/initialization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,7 @@
from pyomo.common.collections import ComponentSet
from pyomo.core.base.block import BlockData
from pyomo.core.expr.visitor import identify_variables
from pyomo.util.vars_from_expressions import get_vars_from_components


def get_vars(m: BlockData):
return ComponentSet(
get_vars_from_components(
m, ctype=(pyo.Constraint, pyo.Objective), include_fixed=False, active=True
)
)
from pyomo.util.vars_from_expressions import get_vars


def shallow_clone(m1):
Expand Down
59 changes: 59 additions & 0 deletions pyomo/util/tests/test_vars_from_expressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# ____________________________________________________________________________________
#
# Pyomo: Python Optimization Modeling Objects
# Copyright (c) 2008-2026 National Technology and Engineering Solutions of Sandia, LLC
# Under the terms of Contract DE-NA0003525 with National Technology and Engineering
# Solutions of Sandia, LLC, the U.S. Government retains certain rights in this
# software. This software is distributed under the 3-clause BSD License.
# ____________________________________________________________________________________

import pyomo.environ as pyo
from pyomo.common import unittest
from pyomo.util.vars_from_expressions import get_vars, get_vars_from_components


class TestVarsFromExpressions(unittest.TestCase):
def test_get_vars(self):
m = pyo.ConcreteModel()
m.x = pyo.Var(list(range(5)))
m.c1 = pyo.Constraint(expr=m.x[0] + m.x[1] == 0)
m.c2 = pyo.Constraint(expr=m.x[1] + m.x[2] == 0)
m.obj = pyo.Objective(expr=m.x[3] + m.x[4])

self.assertEqual(list(get_vars(m)), [m.x[0], m.x[1], m.x[2], m.x[3], m.x[4]])
# verify the default values for active and include_fixed
m.x[0].fix(0)
m.c2.deactivate()
self.assertEqual(list(get_vars(m)), [m.x[1], m.x[3], m.x[4]])

def test_get_vars_from_components(self):
m = pyo.ConcreteModel()
m.x = pyo.Var(list(range(5)))
m.c1 = pyo.Constraint(expr=m.x[0] + m.x[1] == 0)
m.c2 = pyo.Constraint(expr=m.x[1] + m.x[2] == 0)
m.obj = pyo.Objective(expr=m.x[3] + m.x[4])

self.assertEqual(
list(get_vars_from_components(m, pyo.Constraint)), [m.x[0], m.x[1], m.x[2]]
)
self.assertEqual(
list(get_vars_from_components(m, pyo.Objective)), [m.x[3], m.x[4]]
)
self.assertEqual(
list(get_vars_from_components(m, (pyo.Constraint, pyo.Objective))),
[m.x[0], m.x[1], m.x[2], m.x[3], m.x[4]],
)

# verify the default values for active and include_fixed
m.x[0].fix(0)
m.c2.deactivate()
self.assertEqual(
list(get_vars_from_components(m, pyo.Constraint)), [m.x[0], m.x[1], m.x[2]]
)
self.assertEqual(
list(get_vars_from_components(m, pyo.Objective)), [m.x[3], m.x[4]]
)
self.assertEqual(
list(get_vars_from_components(m, (pyo.Constraint, pyo.Objective))),
[m.x[0], m.x[1], m.x[2], m.x[3], m.x[4]],
)
Loading
Loading