diff --git a/brian2/equations/equations.py b/brian2/equations/equations.py index 9b987e587..9e49788bc 100644 --- a/brian2/equations/equations.py +++ b/brian2/equations/equations.py @@ -1025,8 +1025,7 @@ def _sort_subexpressions(self): they should be updated """ - # Get a dictionary of all the dependencies on other subexpressions, - # i.e. ignore dependencies on parameters and differential equations + # Get dependencies for all subexpressions static_deps = {} for eq in self._equations.values(): if eq.type == SUBEXPRESSION: @@ -1037,13 +1036,83 @@ def _sort_subexpressions(self): and self._equations[dep].type == SUBEXPRESSION ] + # Try to sort all subexpressions together (preserves old behavior for non-cyclical cases) try: sorted_eqs = topsort(static_deps) except ValueError: - raise ValueError( - "Cannot resolve dependencies between static " - "equations, dependencies contain a cycle." - ) from None + # There's a cycle. Check if it's only among "constant over dt" subexpressions + constant_over_dt_subexprs = { + eq.varname + for eq in self._equations.values() + if eq.type == SUBEXPRESSION and "constant over dt" in eq.flags + } + + # Check if the cycle involves only "constant over dt" subexpressions + def has_cycle_involving_non_constant( + subexpr, visited=None, rec_stack=None, visited_in_path=None + ): + """ + Check if there's a cycle that involves a non-constant subexpression. + + Returns True if the cycle involves at least one non-constant subexpression. + """ + if visited is None: + visited = set() + if rec_stack is None: + rec_stack = set() + if visited_in_path is None: + visited_in_path = set() + + visited.add(subexpr) + rec_stack.add(subexpr) + visited_in_path.add(subexpr) + + for dep in static_deps.get(subexpr, []): + # If dependency is not in static_deps, it's not a subexpression + if dep not in static_deps: + continue + + # If this dependency is not "constant over dt", check if we found a cycle + if dep not in constant_over_dt_subexprs: + if dep in rec_stack: + # Found a cycle involving a non-constant subexpression + return True + if dep not in visited: + if has_cycle_involving_non_constant( + dep, visited, rec_stack, visited_in_path + ): + return True + else: + # Dependency is "constant over dt" + if dep not in visited: + if has_cycle_involving_non_constant( + dep, visited, rec_stack, visited_in_path + ): + return True + elif dep in rec_stack: + # Found a cycle, but it only involves constant subexpressions so far + # Continue checking other paths + pass + + rec_stack.remove(subexpr) + return False + + # Check all subexpressions for cycles involving non-constant ones + has_bad_cycle = any( + has_cycle_involving_non_constant(subexpr) + for subexpr in static_deps.keys() + ) + + if has_bad_cycle: + # Cycle involves non-constant subexpressions, not allowed + raise ValueError( + "Cannot resolve dependencies between static " + "equations, dependencies contain a cycle." + ) from None + + # Cycle is only among "constant over dt" subexpressions, which is allowed + # Use cycle-tolerant sort for these + sorted_eqs = self._topsort_with_cycles(static_deps) # put the equations objects in the correct order for order, static_variable in enumerate(sorted_eqs): @@ -1056,6 +1125,49 @@ def _sort_subexpressions(self): elif eq.type == PARAMETER: eq.update_order = len(sorted_eqs) + 1 + @staticmethod + def _topsort_with_cycles(dependencies): + """ + Topological sort that can handle cycles by removing edges. + + Parameters + ---------- + dependencies : dict + Dictionary mapping variable names to lists of dependencies + + Returns + ------- + list + Sorted list of variable names (cycles are broken arbitrarily) + """ + # Make a copy to avoid modifying the original + deps = {k: list(v) for k, v in dependencies.items()} + + # Kahn's algorithm with cycle handling + sorted_vars = [] + no_deps = [var for var, deps_list in deps.items() if not deps_list] + + while no_deps: + var = no_deps.pop(0) + sorted_vars.append(var) + + # Remove this variable from all dependency lists + for other_var in deps: + if var in deps[other_var]: + deps[other_var].remove(var) + + # Check for new variables with no dependencies + for other_var in list(deps.keys()): + if other_var not in sorted_vars and not deps[other_var]: + no_deps.append(other_var) + + # If there are still variables left, they have cycles + # Add them in any order (cycle is allowed for "constant over dt") + remaining = [var for var in deps if var not in sorted_vars] + sorted_vars.extend(remaining) + + return sorted_vars + @property def dependencies(self): """ diff --git a/brian2/tests/test_equations.py b/brian2/tests/test_equations.py index ed985c465..1640295b9 100644 --- a/brian2/tests/test_equations.py +++ b/brian2/tests/test_equations.py @@ -214,13 +214,13 @@ def test_correct_replacements(): # replace a variable name with a new name eqs = Equations("dv/dt = -v / tau : 1", v="V") # Correct left hand side - assert ("V" in eqs) and not ("v" in eqs) + assert ("V" in eqs) and "v" not in eqs # Correct right hand side - assert ("V" in eqs["V"].identifiers) and not ("v" in eqs["V"].identifiers) + assert ("V" in eqs["V"].identifiers) and "v" not in eqs["V"].identifiers # replace a variable name with a value eqs = Equations("dv/dt = -v / tau : 1", tau=10 * ms) - assert not "tau" in eqs["v"].identifiers + assert "tau" not in eqs["v"].identifiers @pytest.mark.codegen_independent @@ -567,6 +567,20 @@ def test_extract_subexpressions(): assert constant["s2"].type == SUBEXPRESSION +@pytest.mark.codegen_independent +def test_cyclical_subexpressions(): + with pytest.raises(ValueError): + # dependency cycle + Equations("""dv/dt = (-v + s1)/ (10*ms) : 1 + s1 = 2 * s2 : 1 + s2 = s1/2 : 1""") + + # With constant over dt on BOTH, the cycle is allowed + Equations("""dv/dt = (-v + s1)/ (10*ms) : 1 + s1 = 2 * s2 : 1 (constant over dt) + s2 = s1/2 : 1 (constant over dt)""") + + @pytest.mark.codegen_independent def test_repeated_construction(): eqs1 = Equations("dx/dt = x : 1") @@ -674,5 +688,6 @@ def test_ipython_pprint(): test_unit_checking() test_properties() test_extract_subexpressions() + test_cyclical_subexpressions() test_repeated_construction() test_str_repr() diff --git a/brian2/tests/test_neurongroup.py b/brian2/tests/test_neurongroup.py index d9de59192..d81a3c75a 100644 --- a/brian2/tests/test_neurongroup.py +++ b/brian2/tests/test_neurongroup.py @@ -396,18 +396,18 @@ def test_linked_variable_incorrect(): # incorrect unit with pytest.raises(DimensionMismatchError): - setattr(G3, "l", linked_var(G1.y)) + G3.l = linked_var(G1.y) # incorrect group size with pytest.raises(ValueError): - setattr(G3, "l", linked_var(G2.x)) + G3.l = linked_var(G2.x) # incorrect use of linked_var with pytest.raises(ValueError): - setattr(G3, "l", linked_var(G1.x, "x")) + G3.l = linked_var(G1.x, "x") with pytest.raises(ValueError): - setattr(G3, "l", linked_var(G1)) + G3.l = linked_var(G1) # Not a linked variable with pytest.raises(TypeError): - setattr(G3, "not_linked", linked_var(G1.x)) + G3.not_linked = linked_var(G1.x) @pytest.mark.standalone_compatible @@ -728,15 +728,15 @@ def test_linked_variable_indexed_incorrect(): G.x = np.arange(10) * 0.1 with pytest.raises(TypeError): - setattr(G, "y", linked_var(G.x, index=np.arange(10) * 1.0)) + G.y = linked_var(G.x, index=np.arange(10) * 1.0) with pytest.raises(TypeError): - setattr(G, "y", linked_var(G.x, index=np.arange(10).reshape(5, 2))) + G.y = linked_var(G.x, index=np.arange(10).reshape(5, 2)) with pytest.raises(TypeError): - setattr(G, "y", linked_var(G.x, index=np.arange(5))) + G.y = linked_var(G.x, index=np.arange(5)) with pytest.raises(ValueError): - setattr(G, "y", linked_var(G.x, index=np.arange(10) - 1)) + G.y = linked_var(G.x, index=np.arange(10) - 1) with pytest.raises(ValueError): - setattr(G, "y", linked_var(G.x, index=np.arange(10) + 1)) + G.y = linked_var(G.x, index=np.arange(10) + 1) @pytest.mark.codegen_independent @@ -749,7 +749,7 @@ def test_linked_synapses(): S.connect() G2 = NeuronGroup(100, "x : 1 (linked)") with pytest.raises(NotImplementedError): - setattr(G2, "x", linked_var(S, "w")) + G2.x = linked_var(S, "w") @pytest.mark.standalone_compatible @@ -904,8 +904,12 @@ def test_namespace_warnings(): net = Network(G2) with catch_logs() as l: net.run(0 * ms) - assert len(l) == 1, f"got {str(l)} as warnings" - assert l[0][1].endswith(".resolution_conflict") + # Filter to only count resolution_conflict warnings + conflict_warnings = [ + log for log in l if log[1].endswith(".resolution_conflict") + ] + assert len(conflict_warnings) == 1, f"got {str(l)} as warnings" + assert conflict_warnings[0][1].endswith(".resolution_conflict") del y i = 5 @@ -1402,7 +1406,7 @@ def test_unknown_state_variables(): # variable are handled G = NeuronGroup(10, "v : 1") with pytest.raises(AttributeError): - setattr(G, "unknown", 42) + G.unknown = 42 # Creating a new private attribute should be fine G._unknown = 42 @@ -1648,6 +1652,222 @@ def test_constant_subexpression_order(): assert code_lines[2].startswith("s2") +@pytest.mark.codegen_independent +def test_mixed_subexpression_order(): + """ + Test that the order of operations is correct when regular subexpressions + depend on 'constant over dt' subexpressions, and vice versa. + + This verifies that the fix for issue #1187 doesn't change the behavior + for non-cyclical cases. + """ + # Case 1: Regular subexpression depends on "constant over dt" subexpression + G = NeuronGroup( + 10, + """ + dv/dt = -v / (10*ms) : 1 + s_constant = v + 1 : 1 (constant over dt) + s_regular = 2 * s_constant : 1 + """, + ) + run(0 * ms) + code_lines = G.subexpression_updater.abstract_code.split("\n") + # Only s_constant is in SubexpressionUpdater (s_regular is inline) + assert len(code_lines) == 1 + assert code_lines[0].startswith("s_constant") + + # Case 2: "Constant over dt" subexpression depends on regular subexpression + # This is a bit tricky, but should work: the regular subexpression will + # be evaluated once (when the "constant over dt" is computed) + G = NeuronGroup( + 10, + """ + dv/dt = -v / (10*ms) : 1 + s_regular = v + 1 : 1 + s_constant = 2 * s_regular : 1 (constant over dt) + """, + ) + run(0 * ms) + code_lines = G.subexpression_updater.abstract_code.split("\n") + # Only s_constant is in SubexpressionUpdater (s_regular is inline) + assert len(code_lines) == 1 + assert code_lines[0].startswith("s_constant") + + # Case 3: Chain of dependencies across types + G = NeuronGroup( + 10, + """ + dv/dt = -v / (10*ms) : 1 + s1 = v : 1 (constant over dt) + s2 = s1 + 1 : 1 + s3 = s2 * 2 : 1 (constant over dt) + s4 = s3 + 1 : 1 + """, + ) + run(0 * ms) + code_lines = G.subexpression_updater.abstract_code.split("\n") + # Only s1 and s3 are in SubexpressionUpdater (constant over dt) + # s2 and s4 are regular (inlined) + assert len(code_lines) == 2 + assert code_lines[0].startswith("s1") + assert code_lines[1].startswith("s3") + + +@pytest.mark.codegen_independent +def test_mixed_subexpression_dependencies_order(): + """ + Test that subexpression order respects cross-dependencies + between regular and 'constant over dt' subexpressions. + + This verifies the fix for issue #1187 where dependencies + across subexpression types were broken. + """ + # Case 1: Constant → Regular → Constant → Regular + G = NeuronGroup( + 1, + """ + dv/dt = -v / (10*ms) : 1 + s1 = v + 1 : 1 (constant over dt) + s2 = s1 * 2 : 1 + s3 = s2 + 5 : 1 (constant over dt) + s4 = s3 * 3 : 1 + """, + method="euler", + ) + net = Network(G) + net.run(0 * ms) + code_lines = G.subexpression_updater.abstract_code.split("\n") + # Should compute s1 first, then s3 (in dependency order) + assert code_lines[0].startswith("s1") + assert code_lines[1].startswith("s3") + + # Case 2: Regular → Constant + G2 = NeuronGroup( + 1, + """ + dv/dt = -v / (10*ms) : 1 + s_regular = v + 1 : 1 + s_constant = s_regular * 2 : 1 (constant over dt) + """, + method="euler", + ) + net2 = Network(G2) + net2.run(0 * ms) + code_lines2 = G2.subexpression_updater.abstract_code.split("\n") + # Should compute s_constant (s_regular is inlined) + assert len(code_lines2) == 1 + assert code_lines2[0].startswith("s_constant") + + +@pytest.mark.standalone_compatible +@pytest.mark.multiple_runs +def test_mixed_subexpression_dependencies_values(): + """ + Test that subexpression values are computed correctly when there are + cross-dependencies between regular and 'constant over dt' subexpressions. + + This verifies that the fix for issue #1187 produces correct values. + """ + # Test case: constant → regular → constant + G = NeuronGroup( + 1, + """ + v : 1 + s1 = v + 1 : 1 (constant over dt) + s2 = s1 * 2 : 1 + s3 = s2 + 5 : 1 (constant over dt) + """, + method="euler", + ) + G.v = 5 + + # Expected values: s1 = 6, s2 = 12, s3 = 17 + net = Network(G) + net.run(0.1 * ms) # Run for a short duration to trigger SubexpressionUpdater + + # Verify s1 and s3 are computed correctly + # Note: v may have changed due to no differential equation, but s1 should + # still be based on the initial v value at start of timestep + assert_allclose(G.s1[0], 6.0) + assert_allclose(G.s3[0], 17.0) + + # Verify s2 (regular subexpression) also works + assert_allclose(G.s2[0], 12.0) + + device.build(direct_call=False, **device.build_options) + + +@pytest.mark.codegen_independent +def test_cyclical_constant_subexpressions_allowed(): + """ + Test that cyclical dependencies among 'constant over dt' subexpressions + are allowed (new feature from issue #1187). + """ + # This should NOT raise an error + G = NeuronGroup( + 1, + """ + dv/dt = -v / (10*ms) : 1 + s1 = 2 * s2 : 1 (constant over dt) + s2 = s1 / 2 : 1 (constant over dt) + """, + method="euler", + ) + net = Network(G) + net.run(0 * ms) + # Should successfully create the group and run + assert G.subexpression_updater is not None + + +@pytest.mark.codegen_independent +def test_cyclical_regular_subexpressions_still_error(): + """ + Test that cyclical dependencies among regular subexpressions + still raise an error (unchanged behavior). + """ + with pytest.raises(ValueError, match="cycle"): + NeuronGroup( + 1, + """ + dv/dt = -v / (10*ms) : 1 + s1 = 2 * s2 : 1 + s2 = s1 / 2 : 1 + """, + ) + + +@pytest.mark.codegen_independent +def test_mixed_cyclical_subexpressions_error(): + """ + Test that cyclical dependencies involving both regular and + 'constant over dt' subexpressions still raise an error. + + Only cycles that are ENTIRELY within 'constant over dt' are allowed. + """ + # Case 1: Regular depends on constant, constant depends on regular (cycle) + with pytest.raises(ValueError, match="cycle"): + NeuronGroup( + 1, + """ + dv/dt = -v / (10*ms) : 1 + s1 = s2 + 1 : 1 (constant over dt) + s2 = s1 * 2 : 1 + """, + ) + + # Case 2: Chain that creates a cycle involving regular + with pytest.raises(ValueError, match="cycle"): + NeuronGroup( + 1, + """ + dv/dt = -v / (10*ms) : 1 + s1 = s3 + 1 : 1 (constant over dt) + s2 = s1 * 2 : 1 + s3 = s2 + 5 : 1 + """, + ) + + @pytest.mark.codegen_independent def test_subexpression_checks(): group = NeuronGroup(