diff --git a/pyomo/common/collections/__init__.py b/pyomo/common/collections/__init__.py index ab29d7070a6..bf4270af39f 100644 --- a/pyomo/common/collections/__init__.py +++ b/pyomo/common/collections/__init__.py @@ -12,6 +12,6 @@ from collections.abc import Mapping, MutableMapping, MutableSet, Sequence, Set from .bunch import Bunch -from .component_map import ComponentMap, DefaultComponentMap -from .component_set import ComponentSet +from .component_map import ComponentMap, DefaultComponentMap, ObjectIdMap +from .component_set import ComponentSet, ObjectIdSet from .orderedset import OrderedSet diff --git a/pyomo/common/collections/_hasher.py b/pyomo/common/collections/_hasher.py index ffbde5d8670..c0d4a4d6993 100644 --- a/pyomo/common/collections/_hasher.py +++ b/pyomo/common/collections/_hasher.py @@ -10,6 +10,19 @@ from collections import defaultdict +class _HashKey: + """Utility class to support hashing by object id() + + This class should never be instantiated, and should never be + accessed referenced by user code. Instead this provides a simple + :class:`type` that we can use as an internal flag to differentiate + between an :class:`int` key and the result from :func:`id()`. + + """ + + pass + + class HashDispatcher(defaultdict): """Dispatch table for generating "universal" hashing of all Python objects. @@ -25,11 +38,18 @@ class HashDispatcher(defaultdict): appropriate hashing strategy to each element within the tuple. """ + __slots__ = () + def __init__(self, *args, **kwargs): super().__init__(lambda: self._missing_impl, *args, **kwargs) self[tuple] = self._tuple def _missing_impl(self, val): + # Inherit the hasher from a base class, if found + for _type in val.__class__.__mro__[1:]: + if _type in self: + self[val.__class__] = ans = self[_type] + return ans(val) try: hash(val) self[val.__class__] = self._hashable @@ -43,10 +63,18 @@ def _hashable(val): @staticmethod def _unhashable(val): - return id(val) + return _HashKey, id(val) def _tuple(self, val): - return tuple(self[i.__class__](i) for i in val) + try: + # if *this tuple* is hashable, then use it as the key + hash(val) + return val + except: + # duplicate the tuple, recursively processing all fields. + # The use of val.__class__ ensures that derived things (like + # namedtuples) have their class preserved. + return val.__class__(self[i.__class__](i) for i in val) def hashable(self, obj, hashable=None): if isinstance(obj, type): @@ -60,6 +88,10 @@ def hashable(self, obj, hashable=None): return fcn is self._hashable self[cls] = self._hashable if hashable else self._unhashable + def __call__(self, obj): + # Make the dispatcher callable so that it can be used in place of id() + return self[obj.__class__](obj) + #: The global 'hasher' instance for managing "universal" hashing. #: diff --git a/pyomo/common/collections/component_map.py b/pyomo/common/collections/component_map.py index 467aa7afe91..b2aedbfe9cf 100644 --- a/pyomo/common/collections/component_map.py +++ b/pyomo/common/collections/component_map.py @@ -7,62 +7,136 @@ # software. This software is distributed under the 3-clause BSD License. # ____________________________________________________________________________________ -from collections.abc import Mapping, MutableMapping +from collections.abc import Set, Mapping, MutableMapping +from functools import partial +from operator import itemgetter from pyomo.common.autoslots import AutoSlots +from pyomo.common.formatting import tostr +from pyomo.common.numeric_types import native_logical_types from ._hasher import hasher -def _rehash_keys(encode, val): +def _rehash_keys(keygen, encode, val): if encode: - return val + return list(val.values()) else: # object id() may have changed after unpickling, # so we rebuild the dictionary keys - return {hasher[obj.__class__](obj): (obj, v) for obj, v in val.values()} + return {keygen(v[0]): v for v in val} + + +class ComponentMap_keys(Set): + """A dictionary keys view object for :class:`ComponentMap`""" + + __slots__ = ('_cm',) + + def __init__(self, cm): + self._cm = cm + + def __iter__(self): + return iter(map(itemgetter(0), self._cm._dict.values())) + + def __contains__(self, key): + return self._cm.__contains__(key) + + def __len__(self): + return self._cm.__len__() + + +class ComponentMap_items(Set): + """A dictionary items view object for :class:`ComponentMap`""" + + __slots__ = ('_cm',) + + def __init__(self, cm): + self._cm = cm + + def __iter__(self): + return iter(self._cm._dict.values()) + + def __contains__(self, item): + try: + key, val = item + except (TypeError, ValueError): + return False + return key in self._cm and self._cm._value_eq(val, self._cm[key]) + + def __len__(self): + return self._cm.__len__() + + +class ComponentMap_values(Set): + """A dictionary values view object for :class:`ComponentMap`""" + + __slots__ = ('_cm',) + + def __init__(self, cm): + self._cm = cm + + def __iter__(self): + return iter(map(itemgetter(1), self._cm._dict.values())) + + def __contains__(self, val): + """Returns True if `val` appears as a value in this ComponentMap + + .. warning:: + + This method is provided for API compatibility and is NOT + efficient (it is a linear scan through the underlying + `dict`). We *do not* recommend using it in large + applications or when performance matters. + + """ + return any(self._cm._value_eq(v, val) for v in self) + + def __len__(self): + return self._cm.__len__() class ComponentMap(AutoSlots.Mixin, MutableMapping): - """ - This class is a replacement for dict that allows Pyomo - modeling components to be used as entry keys. The - underlying mapping is based on the Python id() of the - object, which gets around the problem of hashing - subclasses of NumericValue. This class is meant for - creating mappings from Pyomo components to values. The - use of non-Pyomo components as entry keys should be - avoided. + """Mapping that admits unhashable objects as keys + + This class is a replacement for :py:`dict` that allows Pyomo + modeling components to be used as keys. The underlying mapping is + based on the Python :py:`id()` of the object, which gets around the + problem of hashing subclasses of :py:class:`NumericValue`. This + class is meant for creating mappings from Pyomo components to + values. A reference to the object is kept around as long as it has a corresponding entry in the container, so there is - no need to worry about id() clashes. + no need to worry about id() collisions. + + This class leverages :py:class:`AutoSlots` to update any id() keys + during pickling, restoration, or deepcopying. - We also override __setstate__ so that we can rebuild the - container based on possibly updated object ids after - a deepcopy or pickle. + .. warning:: + + An instance of this class should never be deepcopied/pickled + unless it is done so along with its component entries (e.g., as + part of a block). - *** An instance of this class should never be - deepcopied/pickled unless it is done so along with the - components for which it contains map entries (e.g., as - part of a block). *** """ __slots__ = ("_dict",) - __autoslot_mappers__ = {"_dict": _rehash_keys} - # Expose a "public" interface to the global _hasher dict + __autoslot_mappers__ = {"_dict": partial(_rehash_keys, hasher.__call__)} + # Expose a "public" interface to the global hasher dict (for + # backwards compatibility) hasher = hasher - def __init__(self, *args, **kwds): + def __init__(self, *args, **kwargs): # maps id_hash(obj) -> (obj,val) self._dict = {} # handle the dict-style initialization scenarios - self.update(*args, **kwds) + if args or kwargs: + self.update(*args, **kwargs) def __str__(self): """String representation of the mapping.""" - tmp = {f"{v[0]} (key={k})": v[1] for k, v in self._dict.items()} - return f"ComponentMap({tmp})" + tmp = ', '.join(f"{tostr(v[0])}: {tostr(v[1])}" for v in self._dict.values()) + return f"{self.__class__.__name__}({tmp})" # # Implement MutableMapping abstract methods @@ -72,8 +146,7 @@ def __getitem__(self, obj): try: return self._dict[hasher[obj.__class__](obj)][1] except KeyError: - _id = hasher[obj.__class__](obj) - raise KeyError(f"{obj} (key={_id})") from None + raise KeyError(obj) from None def __setitem__(self, obj, val): self._dict[hasher[obj.__class__](obj)] = (obj, val) @@ -82,11 +155,10 @@ def __delitem__(self, obj): try: del self._dict[hasher[obj.__class__](obj)] except KeyError: - _id = hasher[obj.__class__](obj) - raise KeyError(f"{obj} (key={_id})") from None + raise KeyError(obj) from None def __iter__(self): - return (obj for obj, val in self._dict.values()) + return iter(self.keys()) def __len__(self): return self._dict.__len__() @@ -96,43 +168,64 @@ def __len__(self): # # We want a specialization of update() to avoid unnecessary calls to - # id() when copying / merging ComponentMaps + # the hasher when copying / merging ComponentMaps def update(self, *args, **kwargs): - if len(args) == 1 and not kwargs and isinstance(args[0], ComponentMap): - return self._dict.update(args[0]._dict) + if len(args) == 1 and args[0].__class__ is self.__class__: + self._dict.update(args[0]._dict) + args = () return super().update(*args, **kwargs) + def _rekey_items(self, items): + """Utility method for mapping key-value pairs into local hash keys""" + return ((hasher[key.__class__](key), val) for key, val in items) + + @staticmethod + def _value_eq(a, b): + # Note: check "is" first to help avoid creation of Pyomo + # expressions (for the case that the values contain the same + # Pyomo component) + if a is b: + return True + diff = a != b + return (not diff) if diff.__class__ in native_logical_types else False + # We want to avoid generating Pyomo expressions due to comparing the # keys, so look up each entry from other in this dict. def __eq__(self, other): + """Return self==other.""" if self is other: return True if not isinstance(other, Mapping) or len(self) != len(other): return False # Note we have already verified the dicts are the same size - for key, val in other.items(): - other_id = hasher[key.__class__](key) - if other_id not in self._dict: - return False - self_val = self._dict[other_id][1] - # Note: check "is" first to help avoid creation of Pyomo - # expressions (for the case that the values contain the same - # Pyomo component) - if self_val is not val and self_val != val: - return False - return True + if other.__class__ is self.__class__: + # shortcut for comparing ComponentMaps to each other: avoid + # regenerating any keys + other_items = ((key, val[1]) for key, val in other._dict.items()) + else: + other_items = self._rekey_items(other.items()) + + _dict = self._dict + _eq = self._value_eq + return all(key in _dict and _eq(val, _dict[key][1]) for key, val in other_items) def __ne__(self, other): - return not (self == other) + """Return self!=other.""" + return not self.__eq__(other) # - # The remaining methods have slow default - # implementations for MutableMapping. In particular, - # they rely KeyError catching, which is slow for this - # class because KeyError messages use fully qualified - # names. + # The remaining methods have slow default implementations # + def keys(self): + return ComponentMap_keys(self) + + def values(self): + return ComponentMap_values(self) + + def items(self): + return ComponentMap_items(self) + def __contains__(self, obj): return hasher[obj.__class__](obj) in self._dict @@ -167,6 +260,9 @@ class DefaultComponentMap(ComponentMap): __slots__ = ("default_factory",) def __init__(self, default_factory=None, *args, **kwargs): + if default_factory is not None and not callable(default_factory): + args = (default_factory,) + args + default_factory = None super().__init__(*args, **kwargs) self.default_factory = default_factory @@ -182,3 +278,59 @@ def __getitem__(self, obj): return self._dict[_key][1] else: return self.__missing__(obj) + + +class ObjectIdMap(ComponentMap): + """A faster version of :py:class:`ComponentMap` + + :py:class:`ObjectIdMap` is a lighter-weight version of + :py:class:`ComponentMap`. By unconditionally using :py:`id()` to + generate all keys, this class performs approximately 50% faster than + :py:class:`ComponentMap` at the expense of being slightly more + fragile. + + It is _strongly_ recommended to only use Pyomo components as + :class:`ObjectIdMap` keys. + + .. warning:: + + **DO NOT** store keys that do not return persistent + :py:func:`id()` values. In particular, avoid certain immutable + data types like :class:`tuple` or other immutable objects, + strings, and long integers. Doing so may result in failed + lookups or duplicate entries. + + If you want to mix immutable data types with other unhashable + objects (like Pyomo :class:`Var` or :class:`Param` components), + please use :class:`ComponentMap`. + + """ + + __slots__ = () + __autoslot_mappers__ = {"_dict": partial(_rehash_keys, id)} + + def __getitem__(self, obj): + try: + return self._dict[id(obj)][1] + except KeyError: + raise KeyError(obj) from None + + def __setitem__(self, obj, val): + self._dict[id(obj)] = (obj, val) + + def __delitem__(self, obj): + try: + del self._dict[id(obj)] + except KeyError: + raise KeyError(obj) from None + + def __contains__(self, obj): + return id(obj) in self._dict + + def _rekey_items(self, items): + return ((id(key), val) for key, val in items) + + def __str__(self): + """String representation of the mapping.""" + tmp = [f"{v[0]} (key={k}): {v[1]}" for k, v in self._dict.items()] + return f"{self.__class__.__name__}({', '.join(tmp)})" diff --git a/pyomo/common/collections/component_set.py b/pyomo/common/collections/component_set.py index 3216530959c..67c8562e84f 100644 --- a/pyomo/common/collections/component_set.py +++ b/pyomo/common/collections/component_set.py @@ -8,13 +8,15 @@ # ____________________________________________________________________________________ from collections.abc import MutableSet, Set +from functools import partial from pyomo.common.autoslots import AutoSlots +from pyomo.common.formatting import tostr from ._hasher import hasher -def _rehash_keys(encode, val): +def _rehash_keys(keygen, encode, val): if encode: # TBD [JDS 2/2024]: if we # @@ -26,40 +28,41 @@ def _rehash_keys(encode, val): # autoslots.fast_deepcopy, but couldn't find an obvious bug. # There is no error if we just return the original dict, or if # we return a tuple(val.values) - return val + return tuple(val.values()) else: # object id() may have changed after unpickling, # so we rebuild the dictionary keys - return {hasher[obj.__class__](obj): obj for obj in val.values()} + return {keygen(obj): obj for obj in val} class ComponentSet(AutoSlots.Mixin, MutableSet): - """ - This class is a replacement for set that allows Pyomo - modeling components to be used as entries. The - underlying hash is based on the Python id() of the - object, which gets around the problem of hashing - subclasses of NumericValue. This class is meant for - creating sets of Pyomo components. The use of non-Pyomo - components as entries should be avoided (as the behavior - is undefined). + """Set that admits unhashable objects. + + This class is a replacement for :py:`set` that allows Pyomo modeling + components to be used as entries. The underlying hash is based on + the Python :py:`id()` of the object, which gets around the problem + of hashing subclasses of :py:class:`NumericValue`. This class is + meant for creating sets of Pyomo components. References to objects are kept around as long as they are entries in the container, so there is no need to - worry about id() clashes. + worry about id() collisions. + + This class leverages :py:class:`AutoSlots` to update any id() keys + during pickling, restoration, or deepcopying. - We also override __setstate__ so that we can rebuild the - container based on possibly updated object ids after - a deepcopy or pickle. + .. warning:: + + An instance of this class should never be deepcopied/pickled + unless it is done so along with its component entries (e.g., as + part of a block). - *** An instance of this class should never be - deepcopied/pickled unless it is done so along with - its component entries (e.g., as part of a block). *** """ __slots__ = ("_data",) - __autoslot_mappers__ = {"_data": _rehash_keys} - # Expose a "public" interface to the global _hasher dict + __autoslot_mappers__ = {"_data": partial(_rehash_keys, hasher.__call__)} + # Expose a "public" interface to the global hasher dict (for + # backwards compatibility) hasher = hasher def __init__(self, iterable=None): @@ -69,16 +72,18 @@ def __init__(self, iterable=None): self.update(iterable) def __str__(self): - """String representation of the mapping.""" - tmp = [f"{v} (key={k})" for k, v in self._data.items()] - return f"ComponentSet({tmp})" + """String representation of the set.""" + tmp = (tostr(k) for k in self._data.values()) + return f"{self.__class__.__name__}({', '.join(tmp)})" - def update(self, iterable): + def update(self, *iterables): """Update a set with the union of itself and others.""" - if isinstance(iterable, ComponentSet): - self._data.update(iterable._data) - else: - self._data.update((hasher[val.__class__](val), val) for val in iterable) + for iterable in iterables: + if iterable.__class__ is self.__class__: + self._data.update(iterable._data) + else: + for val in iterable: + self.add(val) # # Implement MutableSet abstract methods @@ -99,9 +104,10 @@ def add(self, val): def discard(self, val): """Remove an element. Do not raise an exception if absent.""" - _id = hasher[val.__class__](val) - if _id in self._data: - del self._data[_id] + try: + del self._data[hasher[val.__class__](val)] + except KeyError: + pass # # Overload MutableSet default implementations @@ -110,11 +116,12 @@ def discard(self, val): def __eq__(self, other): if self is other: return True - if not isinstance(other, Set): + if not isinstance(other, Set) or len(self._data) != len(other): return False - return len(self) == len(other) and all( - hasher[val.__class__](val) in self._data for val in other - ) + if other.__class__ is self.__class__: + return all(key in self._data for key in other._data) + else: + return all(map(self.__contains__, other)) def __ne__(self, other): return not (self == other) @@ -129,9 +136,59 @@ def clear(self): self._data.clear() def remove(self, val): - """Remove an element. If not a member, raise a KeyError.""" + """Remove an element. If not a member, raise a :class:`KeyError`.""" try: del self._data[hasher[val.__class__](val)] except KeyError: - _id = hasher[val.__class__](val) - raise KeyError(f"{val} (key={_id})") from None + raise KeyError(val) + + +class ObjectIdSet(ComponentSet): + """A faster version of :py:class:`ComponentSet` + + :py:class:`ObjectIdSet` is a lighter-weight version of + :py:class:`ComponentSet`. By unconditionally using :py:`id()` to + hash all members, this class performs approximately 50% faster than + :py:class:`ComponentSet` at the expense of being slightly more + fragile. + + It is _strongly_ recommended to only store Pyomo components in + :class:`ObjectIdSet` containers. + + .. warning:: + + **DO NOT** store objects that do not return persistent + :py:func:`id()` values. In particular, avoid certain immutable + data types like :class:`tuple` or other immutable objects, + strings, and long integers. Doing so may result in failed + lookups or duplicate entries. + + If you want to mix immutable data types with other unhashable + objects (like Pyomo :class:`Var` or :class:`Param` components), + please use :class:`ComponentSet`. + + """ + + __slots__ = () + __autoslot_mappers__ = {"_data": partial(_rehash_keys, id)} + + def __contains__(self, val): + return id(val) in self._data + + def add(self, val): + """Add an element.""" + self._data[id(val)] = val + + def discard(self, val): + """Remove an element. Do not raise an exception if absent.""" + try: + del self._data[id(val)] + except KeyError: + pass + + def remove(self, val): + """Remove an element. If not a member, raise a KeyError.""" + try: + del self._data[id(val)] + except KeyError: + raise KeyError(val) from None diff --git a/pyomo/common/tests/test_component_map.py b/pyomo/common/tests/test_component_map.py index 19a0d192e5c..69e19ed81af 100644 --- a/pyomo/common/tests/test_component_map.py +++ b/pyomo/common/tests/test_component_map.py @@ -7,18 +7,233 @@ # software. This software is distributed under the 3-clause BSD License. # ____________________________________________________________________________________ +import pickle import pyomo.common.unittest as unittest -from pyomo.common.collections import ComponentMap, ComponentSet, DefaultComponentMap -from pyomo.environ import ConcreteModel, Block, Var, Constraint +from pyomo.common.collections._hasher import _HashKey +from pyomo.common.collections.component_map import ( + ComponentMap, + DefaultComponentMap, + ObjectIdMap, +) +from pyomo.common.collections.component_set import ComponentSet +from pyomo.common.envvar import is_pypy +from pyomo.environ import ConcreteModel, Block, Var, Constraint, Param -class TestComponentMap(unittest.TestCase): +class ComponentMapBaseTests: + + def test_str(self): + m = ConcreteModel() + m.x = Var() + m.y = Param([1], mutable=True) + cm = self.CM() + cm[m.x] = m.y[1] + _id = id(m.x) + cm[_id] = 42 + cm[(5, m.x)] = 7 + self.assertEqual( + f"{self.CM.__name__}" f"(x: y[1], {_id}: 42, (5, x): 7)", str(cm) + ) + + def test_get_del_item(self): + m = ConcreteModel() + m.x = Var() + + cm = self.CM() + cm[m] = 10 + cm[m.x] = 20 + cm[3] = 30 + self.assertEqual(3, len(cm)) + self.assertEqual(cm[m], 10) + self.assertEqual(cm[m.x], 20) + self.assertEqual(cm[3], 30) + + del cm[m.x] + self.assertEqual(2, len(cm)) + self.assertEqual(cm[m], 10) + self.assertEqual(cm[3], 30) + + self.assertEqual(cm.get(m), 10) + self.assertEqual(cm.get(m, 100), 10) + self.assertEqual(cm.get(m.x), None) + self.assertEqual(cm.get(m.x, 100), 100) + self.assertEqual(cm.get(3), 30) + self.assertEqual(cm.get(3, 100), 30) + + with self.assertRaisesRegex(KeyError, repr(m.x)): + cm[m.x] + + with self.assertRaisesRegex(KeyError, repr(m.x)): + del cm[m.x] + + self.assertEqual(2, len(cm)) + self.assertEqual(cm[m], 10) + self.assertEqual(cm[3], 30) + + def test_iters(self): + m = ConcreteModel() + m.x = Var() + + cm = self.CM() + cm[m] = 10 + cm[m.x] = 20 + cm[3] = 10 + + self.assertEqual([m, m.x, 3], list(cm)) + + k = cm.keys() + self.assertEqual([m, m.x, 3], list(k)) + self.assertEqual(3, len(k)) + self.assertIn(m, k) + self.assertIn(m.x, k) + self.assertNotIn(4, k) + + v = cm.values() + self.assertEqual([10, 20, 10], list(v)) + self.assertEqual(3, len(v)) + self.assertIn(10, v) + self.assertIn(20, v) + self.assertNotIn(30, v) + + i = cm.items() + self.assertEqual([(m, 10), (m.x, 20), (3, 10)], list(i)) + self.assertEqual(3, len(i)) + self.assertIn((m, 10), i) + self.assertIn((m.x, 20), i) + self.assertIn((3, 10), i) + self.assertNotIn((3, 30), i) + self.assertNotIn((4, 10), i) + self.assertNotIn('hi', i) + self.assertNotIn(50, i) + self.assertNotIn((1, 2, 3), i) + + # These are views... and should update to reflect the current state + del cm[m] + + self.assertEqual([m.x, 3], list(k)) + self.assertEqual(2, len(k)) + self.assertNotIn(m, k) + self.assertIn(m.x, k) + self.assertNotIn(4, k) + + self.assertEqual([20, 10], list(v)) + self.assertEqual(2, len(v)) + self.assertIn(10, v) + self.assertIn(20, v) + self.assertNotIn(30, v) + + self.assertEqual([(m.x, 20), (3, 10)], list(i)) + self.assertEqual(2, len(i)) + self.assertNotIn((m, 10), i) + self.assertIn((m.x, 20), i) + self.assertIn((3, 10), i) + self.assertNotIn((3, 30), i) + self.assertNotIn((4, 10), i) + self.assertNotIn('hi', i) + self.assertNotIn(50, i) + self.assertNotIn((1, 2, 3), i) + + def test_eq(self): + m = ConcreteModel() + m.x = Var() + m.y = Var() + m.c = Constraint() + + cm1 = self.CM() + cm1[m] = 10 + cm1[m.x] = 20 + cm1[m.c] = 30 + + self.assertEqual(cm1, cm1) + + cm2 = self.CM() + cm2[m] = 10 + cm2[m.c] = 30 + self.assertNotEqual(cm1, cm2) + + cm2[m.y] = 20 + self.assertNotEqual(cm1, cm2) + + del cm2[m.y] + cm2[m.x] = 20 + self.assertEqual(cm1, cm2) + + self.assertNotEqual(cm1, {m: 10, m.c: 30}) + del cm1[m.x] + self.assertEqual(cm1, {m: 10, m.c: 30}) + self.assertNotEqual(cm1, {m: 10, m.c: 40}) + + def test_init_update(self): + m = ConcreteModel() + m.x = Var() + m.c = Constraint() + + cm1 = self.CM() + cm1[m] = 10 + cm1[m.x] = 20 + cm1[m.c] = 30 + + cm2 = self.CM(cm1) + self.assertIsNot(cm1, cm2) + self.assertIsNot(cm1._dict, cm2._dict) + self.assertEqual(cm1, cm2) + + cm3 = self.CM({m: 10, m.c: 30}) + del cm2[m.x] + self.assertEqual(cm2, cm3) + + cm3.update(cm1) + self.assertNotEqual(cm2, cm3) + self.assertEqual(cm1, cm3) + + def test_set_default(self): + m = ConcreteModel() + m.x = Var() + m.c = Constraint() + + cm = self.CM() + self.assertIs(cm.setdefault(m, m.x), m.x) + self.assertEqual(cm, {m: m.x}) + self.assertIs(cm.setdefault(m, m.c), m.x) + self.assertEqual(cm, {m: m.x}) + + cm.clear() + self.assertEqual(cm, {}) + self.assertIs(cm.setdefault(m, m.c), m.c) + self.assertEqual(cm, {m: m.c}) + + +class TestComponentMap(ComponentMapBaseTests, unittest.TestCase): + def setUp(self): + self.CM = ComponentMap + + def test_hasher(self): + m = self.CM() + a = 'str' + m[a] = 5 + self.assertTrue(m.hasher.hashable(a)) + self.assertTrue(m.hasher.hashable(str)) + self.assertEqual(m._dict, {a: (a, 5)}) + del m[a] + + m.hasher.hashable(a, False) + m[a] = 5 + self.assertFalse(m.hasher.hashable(a)) + self.assertFalse(m.hasher.hashable(str)) + self.assertEqual(m._dict, {(_HashKey, id(a)): (a, 5)}) + + class TMP: + pass + + with self.assertRaises(KeyError): + m.hasher.hashable(TMP) + def test_tuple(self): m = ConcreteModel() m.v = Var() m.c = Constraint(expr=m.v >= 0) - m.cm = cm = ComponentMap() + m.cm = cm = self.CM() cm[(1, 2)] = 5 self.assertEqual(len(cm), 1) @@ -47,31 +262,79 @@ def test_tuple(self): self.assertIn((1, (2, m.v)), m.cm) self.assertNotIn((1, (2, m.v)), i.cm) - def test_hasher(self): - m = ComponentMap() - a = 'str' - m[a] = 5 - self.assertTrue(m.hasher.hashable(a)) - self.assertTrue(m.hasher.hashable(str)) - self.assertEqual(m._dict, {a: (a, 5)}) - del m[a] + def test_id_int_collision(self): + m = ConcreteModel() + m.x = Var() + cm = self.CM() - m.hasher.hashable(a, False) - m[a] = 5 - self.assertFalse(m.hasher.hashable(a)) - self.assertFalse(m.hasher.hashable(str)) - self.assertEqual(m._dict, {id(a): (a, 5)}) + cm[m.x] = 1 + cm[id(m.x)] = 2 + self.assertEqual(len(cm), 2) + self.assertIn(m.x, cm) + self.assertIn(id(m.x), cm) # Note: different from ObjectIdMap + self.assertEqual(cm[m.x], 1) - class TMP: - pass + a = (1, (m.x, 3)) + b = (1, (m.x, 3)) + self.assertNotEqual(id(a), id(b)) - with self.assertRaises(KeyError): - m.hasher.hashable(TMP) + cm[a] = 3 + cm[b] = 4 + self.assertEqual(len(cm), 3) # Note: different from ObjectIdMap + self.assertIn(a, cm) + self.assertIn(b, cm) + self.assertEqual(cm[a], 4) # Note: different from ObjectIdMap + self.assertEqual(cm[b], 4) + self.assertIn((1, (m.x, 3)), cm) # Note: different from ObjectIdMap + + def test_pickle(self): + m = ConcreteModel() + m.x = Var() + m.c = Constraint() + + cm = self.CM() + cm[1] = 10 + cm[m.x] = 20 + cm[(1, (2, m.x))] = 30 + cm[m.c] = 40 + m.cm = cm + + i = pickle.loads(pickle.dumps(m)) + self.assertIsNot(i, m) + self.assertIsNot(i.cm, m.cm) + self.assertIn(1, i.cm) + self.assertEqual(i.cm[1], 10) + self.assertNotIn(m.x, i.cm) + self.assertIn(i.x, i.cm) + self.assertEqual(i.cm[i.x], 20) + self.assertNotIn((1, (2, m.x)), i.cm) + self.assertIn((1, (2, i.x)), i.cm) + self.assertEqual(i.cm[(1, (2, i.x))], 30) + self.assertNotIn(m.c, i.cm) + self.assertIn(i.c, i.cm) + self.assertEqual(i.cm[i.c], 40) + _items = iter(i.cm._dict.items()) + k, v = next(_items) + self.assertEqual(k, 1) + self.assertEqual(v, (1, 10)) + k, v = next(_items) + self.assertEqual(k, (_HashKey, id(i.x))) + self.assertEqual(v, (i.x, 20)) + k, v = next(_items) + self.assertEqual(k, (1, (2, (_HashKey, id(i.x))))) + self.assertEqual(v, ((1, (2, i.x)), 30)) + k, v = next(_items) + self.assertEqual(k, i.c) + self.assertEqual(v, (i.c, 40)) + + +class TestDefaultComponentMap(ComponentMapBaseTests, unittest.TestCase): + def setUp(self): + self.CM = DefaultComponentMap -class TestDefaultComponentMap(unittest.TestCase): def test_default_component_map(self): - dcm = DefaultComponentMap(ComponentSet) + dcm = self.CM(ComponentSet) m = ConcreteModel() m.x = Var() @@ -98,7 +361,7 @@ def test_default_component_map(self): self.assertIn(m.b, dcm[m.b.y]) def test_no_default_factory(self): - dcm = DefaultComponentMap() + dcm = self.CM() dcm['found'] = 5 self.assertEqual(len(dcm), 1) @@ -107,3 +370,94 @@ def test_no_default_factory(self): with self.assertRaisesRegex(KeyError, "'missing'"): dcm["missing"] + + +class TestObjectIdMap(ComponentMapBaseTests, unittest.TestCase): + def setUp(self): + self.CM = ObjectIdMap + + def test_str(self): + m = ConcreteModel() + m.x = Var() + m.y = Param([1], mutable=True) + cm = self.CM() + cm[m.x] = m.y[1] + _id = id(m.x) + cm[_id] = 42 + _idid = id(_id) + _tup = (5, m.x) + cm[_tup] = 7 + self.assertEqual( + f"{self.CM.__name__}" + f"(x (key={_id}): y[1], {_id} (key={_idid}): 42, {_tup} (key={id(_tup)}): 7)", + str(cm), + ) + + def test_id_int_collision(self): + m = ConcreteModel() + m.x = Var() + cm = self.CM() + + cm[m.x] = 1 + cm[id(m.x)] = 2 + self.assertEqual(len(cm), 2) + self.assertIn(m.x, cm) + # In pypy, ints from id() hash consistently; in cpython they do not + if is_pypy: + self.assertIn(id(m.x), cm) # Note: different from ComponentMap + else: + self.assertNotIn(id(m.x), cm) # Note: different from ComponentMap + self.assertEqual(cm[m.x], 1) + + a = (1, (m.x, 3)) + b = (1, (m.x, 3)) + self.assertNotEqual(id(a), id(b)) + + cm[a] = 3 + cm[b] = 4 + self.assertEqual(len(cm), 4) # Note: different from ComponentMap + self.assertIn(a, cm) + self.assertIn(b, cm) + self.assertEqual(cm[a], 3) # Note: different from ComponentMap + self.assertEqual(cm[b], 4) + self.assertNotIn((1, (m.x, 3)), cm) # Note: different from ComponentMap + + def test_pickle(self): + m = ConcreteModel() + m.x = Var() + m.c = Constraint() + + cm = self.CM() + cm[1] = 10 + cm[m.x] = 20 + cm[(1, (2, m.x))] = 30 + cm[m.c] = 40 + m.cm = cm + + i = pickle.loads(pickle.dumps(m)) + self.assertIsNot(i, m) + self.assertIsNot(i.cm, m.cm) + self.assertIn(1, i.cm) + self.assertEqual(i.cm[1], 10) + self.assertNotIn(m.x, i.cm) + self.assertIn(i.x, i.cm) + self.assertEqual(i.cm[i.x], 20) + self.assertNotIn((1, (2, m.x)), i.cm) + self.assertNotIn((1, (2, i.x)), i.cm) # Note: different from ComponentMap + self.assertNotIn(m.c, i.cm) + self.assertIn(i.c, i.cm) + self.assertEqual(i.cm[i.c], 40) + + _items = iter(i.cm._dict.items()) + k, v = next(_items) + self.assertEqual(k, id(1)) + self.assertEqual(v, (1, 10)) + k, v = next(_items) + self.assertEqual(k, id(i.x)) + self.assertEqual(v, (i.x, 20)) + k, v = next(_items) + self.assertEqual(k, id(v[0])) + self.assertEqual(v, ((1, (2, i.x)), 30)) + k, v = next(_items) + self.assertEqual(k, id(i.c)) + self.assertEqual(v, (i.c, 40)) diff --git a/pyomo/common/tests/test_component_set.py b/pyomo/common/tests/test_component_set.py new file mode 100644 index 00000000000..26f9dff299e --- /dev/null +++ b/pyomo/common/tests/test_component_set.py @@ -0,0 +1,206 @@ +# ___________________________________________________________________________ +# +# Pyomo: Python Optimization Modeling Objects +# Copyright (c) 2008-2025 +# National Technology and Engineering Solutions of Sandia, LLC +# Under the terms of Contract DE-NA0003525 with National Technology and +# Engineering Solutions of Sandia, LLC, the U.S. Government retains certain +# rights in this software. +# This software is distributed under the 3-clause BSD License. +# ___________________________________________________________________________ + +import pickle +import pyomo.common.unittest as unittest + +from pyomo.common.collections._hasher import _HashKey +from pyomo.common.collections.component_set import ComponentSet, ObjectIdSet +from pyomo.environ import ConcreteModel, Var, Constraint + + +class ComponentSetBaseTests: + + def test_str(self): + m = ConcreteModel() + m.x = Var() + cs = self.CS() + cs.add(m.x) + _id = id(m.x) + cs.add(_id) + cs.add((5, m.x)) + self.assertEqual(f"{self.CS.__name__}" f"(x, {_id}, (5, x))", str(cs)) + + def test_add_remove_discard(self): + m = ConcreteModel() + m.x = Var() + + cs = self.CS() + cs.add(m) + cs.add(m.x) + cs.add(3) + self.assertEqual(3, len(cs)) + self.assertIn(m, cs) + self.assertIn(m.x, cs) + self.assertIn(3, cs) + + # Re-adding doesn't change anything + cs.add(m) + cs.add(m.x) + cs.add(3) + self.assertEqual(3, len(cs)) + self.assertIn(m, cs) + self.assertIn(m.x, cs) + self.assertIn(3, cs) + + cs.remove(m.x) + self.assertEqual(2, len(cs)) + self.assertIn(m, cs) + self.assertIn(3, cs) + + with self.assertRaisesRegex(KeyError, repr(m.x)): + cs.remove(m.x) + cs.discard(m.x) + self.assertEqual(2, len(cs)) + self.assertIn(m, cs) + self.assertIn(3, cs) + + cs.discard(m) + self.assertEqual(1, len(cs)) + self.assertIn(3, cs) + + def test_iter(self): + m = ConcreteModel() + m.x = Var() + + cs = self.CS([m, m.x, 3]) + self.assertEqual([m, m.x, 3], list(cs)) + + def test_eq(self): + m = ConcreteModel() + m.x = Var() + m.y = Var() + m.c = Constraint() + + cs1 = self.CS([m, m.x, m.c]) + self.assertEqual(cs1, cs1) + + cs2 = self.CS([m, m.c]) + self.assertNotEqual(cs1, cs2) + + cs2.add(m.y) + self.assertNotEqual(cs1, cs2) + + cs2.remove(m.y) + cs2.add(m.x) + self.assertEqual(cs1, cs2) + + self.assertNotEqual(cs1, {m, m.c}) + cs1.remove(m.x) + self.assertEqual(cs1, {m, m.c}) + + def test_clear(self): + m = ConcreteModel() + m.x = Var() + m.c = Constraint() + + cs1 = self.CS([m, m.x, m.c]) + cs2 = self.CS(cs1) + self.assertEqual(cs1, cs2) + cs1.clear() + self.assertNotEqual(cs1, cs2) + self.assertEqual(cs1, set()) + + def test_init_update(self): + m = ConcreteModel() + m.x = Var() + m.c = Constraint() + + cs1 = self.CS([m, m.x, m.c]) + + cs2 = self.CS(cs1) + self.assertIsNot(cs1, cs2) + self.assertIsNot(cs1._data, cs2._data) + self.assertEqual(cs1, cs2) + + cs3 = self.CS({m, m.c}) + cs2.discard(m.x) + self.assertEqual(cs2, cs3) + + cs3.update(cs1) + self.assertNotEqual(cs2, cs3) + self.assertEqual(cs1, cs3) + + +class TestComponentSet(ComponentSetBaseTests, unittest.TestCase): + def setUp(self): + self.CS = ComponentSet + + def test_pickle(self): + m = ConcreteModel() + m.x = Var() + m.c = Constraint() + + cs = self.CS([1, m.x, (1, (2, m.x)), m.c]) + m.cs = cs + + i = pickle.loads(pickle.dumps(m)) + self.assertIsNot(i, m) + self.assertIsNot(i.cs, m.cs) + self.assertIn(1, i.cs) + self.assertNotIn(m.x, i.cs) + self.assertIn(i.x, i.cs) + self.assertNotIn((1, (2, m.x)), i.cs) + self.assertIn((1, (2, i.x)), i.cs) + self.assertNotIn(m.c, i.cs) + self.assertIn(i.c, i.cs) + + _items = iter(i.cs._data.items()) + k, v = next(_items) + self.assertEqual(k, 1) + self.assertEqual(v, 1) + k, v = next(_items) + self.assertEqual(k, (_HashKey, id(i.x))) + self.assertEqual(v, i.x) + k, v = next(_items) + self.assertEqual(k, (1, (2, (_HashKey, id(i.x))))) + self.assertEqual(v, (1, (2, i.x))) + k, v = next(_items) + self.assertEqual(k, i.c) + self.assertEqual(v, i.c) + + +class TestObjectIdSet(ComponentSetBaseTests, unittest.TestCase): + def setUp(self): + self.CS = ObjectIdSet + + def test_pickle(self): + m = ConcreteModel() + m.x = Var() + m.c = Constraint() + + cs = self.CS([1, m.x, (1, (2, m.x)), m.c]) + m.cs = cs + + i = pickle.loads(pickle.dumps(m)) + self.assertIsNot(i, m) + self.assertIsNot(i.cs, m.cs) + self.assertIn(1, i.cs) # Note: different from ComponentMap + self.assertNotIn(m.x, i.cs) + self.assertIn(i.x, i.cs) + self.assertNotIn((1, (2, m.x)), i.cs) + self.assertNotIn((1, (2, i.x)), i.cs) # Note: different from ComponentMap + self.assertNotIn(m.c, i.cs) + self.assertIn(i.c, i.cs) + + _items = iter(i.cs._data.items()) + k, v = next(_items) + self.assertEqual(k, id(v)) + self.assertEqual(v, 1) + k, v = next(_items) + self.assertEqual(k, id(i.x)) + self.assertEqual(v, i.x) + k, v = next(_items) + self.assertEqual(k, id(v)) + self.assertEqual(v, (1, (2, i.x))) + k, v = next(_items) + self.assertEqual(k, id(i.c)) + self.assertEqual(v, i.c) diff --git a/pyomo/contrib/solver/tests/solvers/test_gurobi_minlp_walker.py b/pyomo/contrib/solver/tests/solvers/test_gurobi_minlp_walker.py index 19a8498f08e..658ad555ff4 100644 --- a/pyomo/contrib/solver/tests/solvers/test_gurobi_minlp_walker.py +++ b/pyomo/contrib/solver/tests/solvers/test_gurobi_minlp_walker.py @@ -91,13 +91,13 @@ def test_var_domains(self): # we need to update here in order to be able to test expr. visitor.grb_model.update() - x1 = visitor.var_map[id(m.x1)] - x2 = visitor.var_map[id(m.x2)] - x3 = visitor.var_map[id(m.x3)] - y1 = visitor.var_map[id(m.y1)] - y2 = visitor.var_map[id(m.y2)] - y3 = visitor.var_map[id(m.y3)] - z1 = visitor.var_map[id(m.z1)] + x1 = visitor.var_map[m.x1] + x2 = visitor.var_map[m.x2] + x3 = visitor.var_map[m.x3] + y1 = visitor.var_map[m.y1] + y2 = visitor.var_map[m.y2] + y3 = visitor.var_map[m.y3] + z1 = visitor.var_map[m.z1] self.assertEqual(x1.lb, 0) self.assertEqual(x1.ub, float('inf')) @@ -144,11 +144,11 @@ def test_var_bounds(self): # we need to update here in order to be able to test expr. visitor.grb_model.update() - x2 = visitor.var_map[id(m.x2)] - x3 = visitor.var_map[id(m.x3)] - y1 = visitor.var_map[id(m.y1)] - y2 = visitor.var_map[id(m.y2)] - z1 = visitor.var_map[id(m.z1)] + x2 = visitor.var_map[m.x2] + x3 = visitor.var_map[m.x3] + y1 = visitor.var_map[m.y1] + y2 = visitor.var_map[m.y2] + z1 = visitor.var_map[m.z1] self.assertEqual(x2.lb, -34) self.assertEqual(x2.ub, 45) @@ -174,8 +174,8 @@ def test_write_addition(self): visitor = self.get_visitor() _, expr = visitor.walk_expression(m.c.body) - x1 = visitor.var_map[id(m.x1)] - x2 = visitor.var_map[id(m.x2)] + x1 = visitor.var_map[m.x1] + x2 = visitor.var_map[m.x2] # This is a linear expression self.assertEqual(expr.size(), 2) @@ -191,8 +191,8 @@ def test_write_subtraction(self): visitor = self.get_visitor() _, expr = visitor.walk_expression(m.c.body) - x1 = visitor.var_map[id(m.x1)] - x2 = visitor.var_map[id(m.x2)] + x1 = visitor.var_map[m.x1] + x2 = visitor.var_map[m.x2] # Also linear, whoot! self.assertEqual(expr.size(), 2) @@ -208,8 +208,8 @@ def test_write_product(self): visitor = self.get_visitor() _, expr = visitor.walk_expression(m.c.body) - x1 = visitor.var_map[id(m.x1)] - x2 = visitor.var_map[id(m.x2)] + x1 = visitor.var_map[m.x1] + x2 = visitor.var_map[m.x2] # This is quadratic self.assertEqual(expr.size(), 1) @@ -227,7 +227,7 @@ def test_write_product_with_fixed_var(self): visitor = self.get_visitor() _, expr = visitor.walk_expression(m.c.body) - x1 = visitor.var_map[id(m.x1)] + x1 = visitor.var_map[m.x1] # this is linear self.assertEqual(expr.size(), 1) @@ -279,8 +279,8 @@ def test_write_division_linear(self): visitor = self.get_visitor() _, expr = visitor.walk_expression(m.c.body) - x1 = visitor.var_map[id(m.x1)] - x2 = visitor.var_map[id(m.x2)] + x1 = visitor.var_map[m.x1] + x2 = visitor.var_map[m.x2] # linear self.assertEqual(expr.size(), 2) @@ -297,7 +297,7 @@ def test_write_linear_power_expression_var_const(self): visitor = self.get_visitor() _, expr = visitor.walk_expression(m.c.body) - x1 = visitor.var_map[id(m.x1)] + x1 = visitor.var_map[m.x1] # It's just a single var self.assertIs(expr, x1) @@ -335,7 +335,7 @@ def test_write_quadratic_power_expression_var_const(self): _, expr = visitor.walk_expression(m.c.body) # This is quadratic - x1 = visitor.var_map[id(m.x1)] + x1 = visitor.var_map[m.x1] self.assertEqual(expr.size(), 1) lin_expr = expr.getLinExpr() @@ -386,8 +386,8 @@ def test_write_power_expression_var_var(self): # You can't actually use this in a model in Gurobi 12, but you can build the # expression... (It fails during the solve.) - x1 = visitor.var_map[id(m.x1)] - x2 = visitor.var_map[id(m.x2)] + x1 = visitor.var_map[m.x1] + x2 = visitor.var_map[m.x2] opcode, data, parent = self._get_nl_expr_tree(visitor, expr) @@ -402,7 +402,7 @@ def test_write_power_expression_const_var(self): visitor = self.get_visitor() _, expr = visitor.walk_expression(m.c.body) - x2 = visitor.var_map[id(m.x2)] + x2 = visitor.var_map[m.x2] opcode, data, parent = self._get_nl_expr_tree(visitor, expr) @@ -438,7 +438,7 @@ def test_write_absolute_value_of_var(self): # expr is actually an auxiliary variable. We should # get a constraint: # expr == abs(x1) - x1 = visitor.var_map[id(m.x1)] + x1 = visitor.var_map[m.x1] self.assertIsInstance(expr, gurobipy.Var) grb_model = visitor.grb_model @@ -507,7 +507,7 @@ def test_write_expression_with_mutable_param(self): _, expr = visitor.walk_expression(m.c.body) # expr is nonlinear - x2 = visitor.var_map[id(m.x2)] + x2 = visitor.var_map[m.x2] opcode, data, parent = self._get_nl_expr_tree(visitor, expr) @@ -526,7 +526,7 @@ def test_monomial_expression(self): visitor = self.get_visitor() _, expr = visitor.walk_expression(const_expr) - x1 = visitor.var_map[id(m.x1)] + x1 = visitor.var_map[m.x1] self.assertEqual(expr.size(), 1) self.assertEqual(expr.getConstant(), 0.0) self.assertIs(expr.getVar(0), x1) @@ -552,7 +552,7 @@ def test_log_expression(self): _, expr = visitor.walk_expression(m.c.body) # expr is nonlinear - x1 = visitor.var_map[id(m.x1)] + x1 = visitor.var_map[m.x1] opcode, data, parent = self._get_nl_expr_tree(visitor, expr) diff --git a/pyomo/contrib/solver/tests/solvers/test_gurobi_minlp_writer.py b/pyomo/contrib/solver/tests/solvers/test_gurobi_minlp_writer.py index 5b4f3cd33ed..43dc1dfa51b 100644 --- a/pyomo/contrib/solver/tests/solvers/test_gurobi_minlp_writer.py +++ b/pyomo/contrib/solver/tests/solvers/test_gurobi_minlp_writer.py @@ -87,13 +87,13 @@ def test_small_model(self): ).write(m, symbolic_solver_labels=True) self.assertEqual(len(var_map), 7) - x1 = var_map[id(m.x1)] - x2 = var_map[id(m.x2)] - x3 = var_map[id(m.x3)] - y1 = var_map[id(m.y1)] - y2 = var_map[id(m.y2)] - y3 = var_map[id(m.y3)] - z1 = var_map[id(m.z1)] + x1 = var_map[m.x1] + x2 = var_map[m.x2] + x3 = var_map[m.x3] + y1 = var_map[m.y1] + y2 = var_map[m.y2] + y3 = var_map[m.y3] + z1 = var_map[m.z1] self.assertEqual(grb_model.numVars, 9) self.assertEqual(grb_model.numIntVars, 4) @@ -199,7 +199,7 @@ def test_write_NPV_negation_in_RHS(self): ).write(m, symbolic_solver_labels=True) self.assertEqual(len(var_map), 1) - x1 = var_map[id(m.x1)] + x1 = var_map[m.x1] self.assertEqual(grb_model.numVars, 1) self.assertEqual(grb_model.numIntVars, 0) @@ -245,7 +245,7 @@ def test_writer_ignores_deactivated_logical_constraints(self): ).write(m, symbolic_solver_labels=True) self.assertEqual(len(var_map), 1) - x1 = var_map[id(m.x1)] + x1 = var_map[m.x1] self.assertEqual(grb_model.numVars, 1) self.assertEqual(grb_model.numIntVars, 0) @@ -289,8 +289,8 @@ def test_named_expression_quadratic(self): ).write(m, symbolic_solver_labels=True) self.assertEqual(len(var_map), 2) - x = var_map[id(m.x)] - y = var_map[id(m.y)] + x = var_map[m.x] + y = var_map[m.y] self.assertEqual(grb_model.numVars, 2) self.assertEqual(grb_model.numIntVars, 0) @@ -352,8 +352,8 @@ def test_named_expression_nonlinear(self): ).write(m, symbolic_solver_labels=True) self.assertEqual(len(var_map), 2) - x = var_map[id(m.x)] - y = var_map[id(m.y)] + x = var_map[m.x] + y = var_map[m.y] reverse_var_map = {grbv: pyov for pyov, grbv in var_map.items()} self.assertEqual(grb_model.numVars, 4) @@ -454,9 +454,9 @@ def test_unbounded_because_of_multiplying_by_0(self): ).write(m, symbolic_solver_labels=True) self.assertEqual(len(var_map), 3) - x1 = var_map[id(m.x1)] - x2 = var_map[id(m.x2)] - x3 = var_map[id(m.x3)] + x1 = var_map[m.x1] + x2 = var_map[m.x2] + x3 = var_map[m.x3] self.assertEqual(grb_model.numVars, 4) self.assertEqual(grb_model.numIntVars, 0) diff --git a/pyomo/core/expr/visitor.py b/pyomo/core/expr/visitor.py index 210c0ec32e7..50b59f4330f 100644 --- a/pyomo/core/expr/visitor.py +++ b/pyomo/core/expr/visitor.py @@ -16,6 +16,7 @@ logger = logging.getLogger('pyomo.core') +from pyomo.common.collections import ComponentMap from pyomo.common.deprecation import deprecated, deprecation_warning from pyomo.common.errors import DeveloperError, TemplateExpressionError from pyomo.common.numeric_types import ( @@ -967,7 +968,7 @@ def replace_expressions( ---------- expr : Pyomo expression The source expression - substitution_map : dict + substitution_map : dict | ComponentMap A dictionary mapping object ids in the source to the replacement objects. descend_into_named_expressions : bool True if replacement should go into named expression objects, False to halt at @@ -996,6 +997,16 @@ def __init__( ): if substitute is None: substitute = {} + elif isinstance(substitute, ComponentMap): + # ComponentMaps hold references to the keys that they took + # the id() of. Those *could* be the only references to + # those objects, so we want to keep a reference to the + # ComponentMap to guarantee that they don't fall out of + # scope and are collected. + self._cm_substitute = substitute + substitute = { + k if k.__class__ is int else id(k): v for k, v in substitute.items() + } # Note: preserving the attribute names from the previous # implementation of the expression walker. self.substitute = substitute diff --git a/pyomo/core/tests/unit/kernel/test_component_map.py b/pyomo/core/tests/unit/kernel/test_component_map.py index 1af6b855bd5..c4596908a0a 100644 --- a/pyomo/core/tests/unit/kernel/test_component_map.py +++ b/pyomo/core/tests/unit/kernel/test_component_map.py @@ -100,7 +100,7 @@ def test_type(self): def test_str(self): cmap = ComponentMap() - self.assertEqual(str(cmap), "ComponentMap({})") + self.assertEqual(str(cmap), "ComponentMap()") cmap.update(self._components) str(cmap) diff --git a/pyomo/core/tests/unit/kernel/test_component_set.py b/pyomo/core/tests/unit/kernel/test_component_set.py index fe90d1912ea..05ebfe6fc75 100644 --- a/pyomo/core/tests/unit/kernel/test_component_set.py +++ b/pyomo/core/tests/unit/kernel/test_component_set.py @@ -101,7 +101,7 @@ def test_type(self): def test_str(self): cset = ComponentSet() - self.assertEqual(str(cset), "ComponentSet([])") + self.assertEqual(str(cset), "ComponentSet()") cset.update(self._components) str(cset) diff --git a/pyomo/gdp/tests/test_hull.py b/pyomo/gdp/tests/test_hull.py index 15bb5463e82..34553cda3c8 100644 --- a/pyomo/gdp/tests/test_hull.py +++ b/pyomo/gdp/tests/test_hull.py @@ -2335,7 +2335,7 @@ def test_mapping_method_errors(self): with LoggingIntercept(log, 'pyomo.gdp.hull', logging.ERROR): self.assertRaisesRegex( KeyError, - r".*disjunction", + r".*disjunct.ScalarDisjunction object", hull.get_disaggregation_constraint, m.d[1].transformation_block.disaggregatedVars.w, m.disjunction, diff --git a/pyomo/util/components.py b/pyomo/util/components.py index 5cd9fbf7d3b..1d451f18740 100644 --- a/pyomo/util/components.py +++ b/pyomo/util/components.py @@ -33,7 +33,7 @@ def rename_components(model, component_list, prefix): >>> c_list = list(model.component_objects(ctype=pyo.Var, descend_into=True)) >>> new = rename_components(model, component_list=c_list, prefix='special_') >>> str(new) - "ComponentMap({'special_x (key=...)': 'x', 'special_y (key=...)': 'y'})" + 'ComponentMap(special_x: x, special_y: y)' Returns ------- @@ -46,7 +46,8 @@ def rename_components(model, component_list, prefix): generator since this can lead to an infinite loop """ - # Need to collect any Reference first so that we can record the old mapping of data objects before renaming + # Need to collect any Reference first so that we can record the old + # mapping of data objects before renaming refs = {} for c in component_list: if c.is_reference():