Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 118 additions & 6 deletions brian2/equations/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,8 +1025,7 @@ def _sort_subexpressions(self):
they should be updated
"""

# Get a dictionary of all the dependencies on other subexpressions,
# i.e. ignore dependencies on parameters and differential equations
# Get dependencies for all subexpressions
static_deps = {}
for eq in self._equations.values():
if eq.type == SUBEXPRESSION:
Expand All @@ -1037,13 +1036,83 @@ def _sort_subexpressions(self):
and self._equations[dep].type == SUBEXPRESSION
]

# Try to sort all subexpressions together (preserves old behavior for non-cyclical cases)
try:
sorted_eqs = topsort(static_deps)
except ValueError:
raise ValueError(
"Cannot resolve dependencies between static "
"equations, dependencies contain a cycle."
) from None
# There's a cycle. Check if it's only among "constant over dt" subexpressions
constant_over_dt_subexprs = {
eq.varname
for eq in self._equations.values()
if eq.type == SUBEXPRESSION and "constant over dt" in eq.flags
}

# Check if the cycle involves only "constant over dt" subexpressions
def has_cycle_involving_non_constant(
subexpr, visited=None, rec_stack=None, visited_in_path=None
):
"""
Check if there's a cycle that involves a non-constant subexpression.

Returns True if the cycle involves at least one non-constant subexpression.
"""
if visited is None:
visited = set()
if rec_stack is None:
rec_stack = set()
if visited_in_path is None:
visited_in_path = set()

visited.add(subexpr)
rec_stack.add(subexpr)
visited_in_path.add(subexpr)

for dep in static_deps.get(subexpr, []):
# If dependency is not in static_deps, it's not a subexpression
if dep not in static_deps:
continue

# If this dependency is not "constant over dt", check if we found a cycle
if dep not in constant_over_dt_subexprs:
if dep in rec_stack:
# Found a cycle involving a non-constant subexpression
return True
if dep not in visited:
if has_cycle_involving_non_constant(
dep, visited, rec_stack, visited_in_path
):
return True
else:
# Dependency is "constant over dt"
if dep not in visited:
if has_cycle_involving_non_constant(
dep, visited, rec_stack, visited_in_path
):
return True
elif dep in rec_stack:
# Found a cycle, but it only involves constant subexpressions so far
# Continue checking other paths
pass

rec_stack.remove(subexpr)
return False

# Check all subexpressions for cycles involving non-constant ones
has_bad_cycle = any(
has_cycle_involving_non_constant(subexpr)
for subexpr in static_deps.keys()
)

if has_bad_cycle:
# Cycle involves non-constant subexpressions, not allowed
raise ValueError(
"Cannot resolve dependencies between static "
"equations, dependencies contain a cycle."
) from None

# Cycle is only among "constant over dt" subexpressions, which is allowed
# Use cycle-tolerant sort for these
sorted_eqs = self._topsort_with_cycles(static_deps)

# put the equations objects in the correct order
for order, static_variable in enumerate(sorted_eqs):
Expand All @@ -1056,6 +1125,49 @@ def _sort_subexpressions(self):
elif eq.type == PARAMETER:
eq.update_order = len(sorted_eqs) + 1

@staticmethod
def _topsort_with_cycles(dependencies):
"""
Topological sort that can handle cycles by removing edges.

Parameters
----------
dependencies : dict
Dictionary mapping variable names to lists of dependencies

Returns
-------
list
Sorted list of variable names (cycles are broken arbitrarily)
"""
# Make a copy to avoid modifying the original
deps = {k: list(v) for k, v in dependencies.items()}

# Kahn's algorithm with cycle handling
sorted_vars = []
no_deps = [var for var, deps_list in deps.items() if not deps_list]

while no_deps:
var = no_deps.pop(0)
sorted_vars.append(var)

# Remove this variable from all dependency lists
for other_var in deps:
if var in deps[other_var]:
deps[other_var].remove(var)

# Check for new variables with no dependencies
for other_var in list(deps.keys()):
if other_var not in sorted_vars and not deps[other_var]:
no_deps.append(other_var)

# If there are still variables left, they have cycles
# Add them in any order (cycle is allowed for "constant over dt")
remaining = [var for var in deps if var not in sorted_vars]
sorted_vars.extend(remaining)

return sorted_vars

@property
def dependencies(self):
"""
Expand Down
21 changes: 18 additions & 3 deletions brian2/tests/test_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,13 @@ def test_correct_replacements():
# replace a variable name with a new name
eqs = Equations("dv/dt = -v / tau : 1", v="V")
# Correct left hand side
assert ("V" in eqs) and not ("v" in eqs)
assert ("V" in eqs) and "v" not in eqs
# Correct right hand side
assert ("V" in eqs["V"].identifiers) and not ("v" in eqs["V"].identifiers)
assert ("V" in eqs["V"].identifiers) and "v" not in eqs["V"].identifiers

# replace a variable name with a value
eqs = Equations("dv/dt = -v / tau : 1", tau=10 * ms)
assert not "tau" in eqs["v"].identifiers
assert "tau" not in eqs["v"].identifiers


@pytest.mark.codegen_independent
Expand Down Expand Up @@ -567,6 +567,20 @@ def test_extract_subexpressions():
assert constant["s2"].type == SUBEXPRESSION


@pytest.mark.codegen_independent
def test_cyclical_subexpressions():
with pytest.raises(ValueError):
# dependency cycle
Equations("""dv/dt = (-v + s1)/ (10*ms) : 1
s1 = 2 * s2 : 1
s2 = s1/2 : 1""")

# With constant over dt on BOTH, the cycle is allowed
Equations("""dv/dt = (-v + s1)/ (10*ms) : 1
s1 = 2 * s2 : 1 (constant over dt)
s2 = s1/2 : 1 (constant over dt)""")


@pytest.mark.codegen_independent
def test_repeated_construction():
eqs1 = Equations("dx/dt = x : 1")
Expand Down Expand Up @@ -674,5 +688,6 @@ def test_ipython_pprint():
test_unit_checking()
test_properties()
test_extract_subexpressions()
test_cyclical_subexpressions()
test_repeated_construction()
test_str_repr()
Loading