Skip to content
116 changes: 107 additions & 9 deletions brian2/equations/codestrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
`Expression` and `Statements` are the ones that are actually used.
"""

import re
from collections.abc import Hashable

import sympy
Expand Down Expand Up @@ -58,23 +59,120 @@ def __hash__(self):

class Statements(CodeString):
"""
Class for representing statements.
Class for representing statements with support for substitution.

Parameters
----------
code : str
The statement or statements. Several statements can be given as a
multi-line string or separated by semicolons.

Notes
-----
Currently, the implementation of this class does not add anything to
`~brian2.equations.codestrings.CodeString`, but it should be used instead
of that class for clarity and to allow for future functionality that is
only relevant to statements and not to expressions.
**substitutions
Substitutions to apply to the code. Can be either strings (to replace
a name with another name) or values (to replace a name with a value).

Examples
--------
>>> Statements('g += w')
Statements('g += w')
>>> Statements('g += k*w', k=0.3)
Statements('g += (0.3)*w')
>>> Statements('g += k*w', g='g_ampa')
Statements('g_ampa += k*w')
>>> Statements('g += k*w', g='g_ampa', k=0.3)
Statements('g_ampa += (0.3)*w')
"""

pass
def __init__(self, code, **substitutions):
if len(substitutions) > 0:
code = self._substitute(code, substitutions)
super().__init__(code)

@staticmethod
def _substitute(code, substitutions):
"""
Perform substitutions in the code.

Parameters
----------
code : str
The original code string
substitutions : dict
Dictionary mapping identifiers to replacements (strings or values)

Returns
-------
new_code : str
Code with substitutions applied

Raises
------
ValueError
If a value substitution is attempted on a variable that appears
on the left-hand side of an assignment (which would create invalid code)
"""
# Check for invalid value substitutions (LHS assignments)
for identifier, replacement in substitutions.items():
if not isinstance(replacement, str):
# This is a value substitution - check if identifier is on LHS
# We need to parse the statement to find LHS variables
lines = code.split("\n")
for line in lines:
# Strip comments
line_no_comment = line.split("#")[0].strip()
if not line_no_comment:
continue

# Check if this is an assignment (contains +=, -=, *=, /=, or =)
if any(
op in line_no_comment for op in ["+=", "-=", "*=", "/=", "="]
):
# Extract the LHS (before the operator)
for op in ["+=", "-=", "*=", "/=", "="]:
if op in line_no_comment:
lhs = line_no_comment.split(op)[0].strip()
# Check if the identifier being substituted is the LHS
if lhs == identifier:
raise ValueError(
f"Cannot substitute value for '{identifier}' "
f"on left-hand side of assignment '{line_no_comment}'. "
f"Use a string substitution instead."
)
break

new_code = code
for identifier, replacement in substitutions.items():
if isinstance(replacement, str):
# Replace identifier with another identifier
# Replace in both code and comments
new_code = re.sub(r"\b" + identifier + r"\b", replacement, new_code)
else:
# Replace identifier with a value
# Only replace in code, not in comments
lines = new_code.split("\n")
new_lines = []
for line in lines:
if "#" in line:
# Split into code and comment parts
code_part, comment_part = line.split("#", 1)
# Apply substitution only to code part
code_part = re.sub(
r"\b" + identifier + r"\b",
"(" + repr(replacement) + ")",
code_part,
)
# Keep comment as is (don't substitute values in comments)
new_lines.append(code_part + "#" + comment_part)
else:
# No comment, apply substitution to whole line
line = re.sub(
r"\b" + identifier + r"\b",
"(" + repr(replacement) + ")",
line,
)
new_lines.append(line)
new_code = "\n".join(new_lines)

return new_code


class Expression(CodeString):
Expand Down
109 changes: 97 additions & 12 deletions brian2/tests/test_codestrings.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
import numpy as np
import pytest
import sympy
from numpy.testing import assert_equal

import brian2
from brian2 import (
DimensionMismatchError,
Expression,
Hz,
Statements,
get_dimensions,
ms,
mV,
second,
volt,
)
from brian2.core.preferences import prefs
from brian2.utils.logger import catch_logs


def sympy_equals(expr1, expr2):
Expand All @@ -39,15 +29,14 @@ def test_expr_creation():
assert (
"v" in expr.identifiers
and "mV" in expr.identifiers
and not "V" in expr.identifiers
and "V" not in expr.identifiers
)
with pytest.raises(SyntaxError):
Expression("v 5 * mV")


@pytest.mark.codegen_independent
def test_split_stochastic():
tau = 5 * ms
expr = Expression("(-v + I) / tau")
# No stochastic part
assert expr.split_stochastic() == (expr, None)
Expand Down Expand Up @@ -103,7 +92,103 @@ def test_str_repr():
assert repr(statement) == "Statements('v += w')"


@pytest.mark.codegen_independent
def test_statements_substitution():
"""
Test that Statements correctly handles substitutions.
"""
# Test string substitution (rename variable)
stmt = Statements("v += w", v="x")
assert str(stmt) == "x += w"

# Test value substitution
stmt = Statements("v += k*w", k=0.3)
assert "0.3" in str(stmt)

# Test both types of substitutions
stmt = Statements("v += k*w", v="x", k=0.3)
assert "x" in str(stmt)
assert "0.3" in str(stmt)


@pytest.mark.codegen_independent
def test_statements_substitution_lhs_error():
"""
Test that Statements raises an error when trying to substitute a value
for a variable on the left-hand side of an assignment.
"""
# Trying to replace LHS variable with a value should raise an error
with pytest.raises(ValueError, match="Cannot substitute value"):
Statements("v += x", v=3 * mV)

with pytest.raises(ValueError, match="Cannot substitute value"):
Statements("v = x", v=5)

# This should work fine (string substitution on LHS)
stmt = Statements("v += x", v="y")
assert str(stmt) == "y += x"

# This should work fine (value substitution on RHS)
stmt = Statements("v += x", x=3 * mV)
assert "(3. * mvolt)" in str(stmt)


@pytest.mark.codegen_independent
def test_statements_substitution_comments():
"""
Test that value substitutions do not affect comments, but name
substitutions do.
"""
# Value substitution should not affect comments
stmt = Statements("x += weight # Use a small weight", weight=1 * brian2.nS)
code = str(stmt)
# Comment should remain unchanged
assert "# Use a small weight" in code
# Code should have the substitution
assert "(1. * nsiemens)" in code

# Name substitution should affect both code and comments
stmt = Statements("x += weight # x is the post-synaptic target variable", x="y")
assert str(stmt) == "y += weight # y is the post-synaptic target variable"

# Multiple lines with comments
stmt = Statements(
"""
x += weight
y += x # x is the variable
""",
x="z",
weight=0.5,
)
code = str(stmt)
assert "z" in code
assert "0.5" in code
assert "# z is the variable" in code


@pytest.mark.codegen_independent
def test_statements_substitution_multiple_lines():
"""
Test substitutions in multi-line statements.
"""
stmt = Statements(
"""
v += w
u += v
""",
v="x",
)
code = str(stmt)
# Both occurrences of v should be replaced
assert "x += w" in code
assert "u += x" in code


if __name__ == "__main__":
test_expr_creation()
test_split_stochastic()
test_str_repr()
test_statements_substitution()
test_statements_substitution_lhs_error()
test_statements_substitution_comments()
test_statements_substitution_multiple_lines()
Loading