Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 39 additions & 26 deletions brian2/stateupdaters/exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,25 @@ def get_linear_system(eqs, variables):
ValueError
If the equations cannot be converted into an M * X + B form.
'''
diff_eqs = eqs.get_substituted_expressions(variables)
diff_eq_names = [name for name, _ in diff_eqs]
diff_eqs = {name: str_to_sympy(expr.code, variables).expand()
for name, expr in eqs.get_substituted_expressions(variables)}

symbols = [Symbol(name, real=True) for name in diff_eq_names]
# Sometimes, in particular in testing, variables defined as differential
# equations are actually constant (e.g. `dv/dt = 0/second`). We ignore
# them here
symbols = [Symbol(name, real=True) for name, expr in diff_eqs.items()
if expr != 0]

coefficients = sp.zeros(len(diff_eq_names))
constants = sp.zeros(len(diff_eq_names), 1)

for row_idx, (name, expr) in enumerate(diff_eqs):
s_expr = str_to_sympy(expr.code, variables).expand()
coefficients = sp.zeros(len(symbols))
constants = sp.zeros(len(symbols), 1)

for row_idx, symbol in enumerate(symbols):
s_expr = diff_eqs[symbol.name]
current_s_expr = s_expr
for col_idx, symbol in enumerate(symbols):
current_s_expr = current_s_expr.collect(symbol)
constant_wildcard = Wild('c', exclude=[symbol])
factor_wildcard = Wild('c_'+name, exclude=symbols)
factor_wildcard = Wild('c_'+symbol.name, exclude=symbols)
one_pattern = factor_wildcard*symbol + constant_wildcard
matches = current_s_expr.match(one_pattern)
if matches is None:
Expand All @@ -64,15 +67,16 @@ def get_linear_system(eqs, variables):
'%s, could not be '
'separated into linear '
'components.') %
(expr, name))
(sympy_to_str(s_expr),
symbol.name))

coefficients[row_idx, col_idx] = matches[factor_wildcard]
current_s_expr = matches[constant_wildcard]

# The remaining constant should be a true constant
constants[row_idx] = current_s_expr

return (diff_eq_names, coefficients, constants)
return [s.name for s in symbols], coefficients, constants


class IndependentStateUpdater(StateUpdateMethod):
Expand Down Expand Up @@ -191,40 +195,49 @@ def __call__(self, equations, variables=None, method_options=None):
('Expression "{}" is not guaranteed to be constant over a '
'time step').format(sympy_to_str(entry)))

symbols = [Symbol(variable, real=True) for variable in varnames]
solution = sp.solve_linear_system(matrix.row_join(constants), *symbols)
if solution is None or set(symbols) != set(solution.keys()):
raise UnsupportedEquationsException('Cannot solve the given '
'equations with this '
'stateupdater.')
b = sp.ImmutableMatrix([solution[symbol] for symbol in symbols])

# Solve the system
dt = Symbol('dt', real=True, positive=True)
# Add the constant terms as new variables
const_vars = []
const_terms = []
for idx, (varname, const_term) in enumerate(zip(varnames, constants)):
if const_term != 0:
matrix = matrix.col_insert(matrix.cols, sp.Matrix([1 if i == idx else 0
for i in range(matrix.rows)]))
matrix = matrix.row_insert(matrix.rows, sp.zeros(1, matrix.cols))
const_vars.append('_const_term_' + varname)
const_terms.append(const_term)

try:
A = (matrix * dt).exp()
except NotImplementedError:
raise UnsupportedEquationsException('Cannot solve the given '
'equations with this '
'stateupdater.')

if method_options['simplify']:
A = A.applyfunc(lambda x:
sp.factor_terms(sp.cancel(sp.signsimp(x))))
C = sp.ImmutableMatrix(A * b) - b
_S = sp.MatrixSymbol('_S', len(varnames), 1)
updates = A * _S + C

_S = sp.MatrixSymbol('_S', len(varnames) + len(const_vars), 1)
updates = A * _S
updates = updates.as_explicit()
abstract_code = []

# Add code for the constant terms:
for const_var, const_term in zip(const_vars, const_terms):
abstract_code.append(const_var + ' = ' + sympy_to_str(const_term))

# The solution contains _S[0, 0], _S[1, 0] etc. for the state variables,
# replace them with the state variable names
abstract_code = []
for idx, (variable, update) in enumerate(zip(varnames, updates)):
# replace them with the state variable names
for variable, update in zip(varnames, updates[:len(varnames)]):
rhs = update
if rhs.has(I, re, im):
raise UnsupportedEquationsException('The solution to the linear system '
'contains complex values '
'which is currently not implemented.')
for row_idx, varname in enumerate(varnames):

for row_idx, varname in enumerate(itertools.chain(varnames, const_vars)):
rhs = rhs.subs(_S[row_idx, 0], varname)

# Do not overwrite the real state variables yet, the update step
Expand Down
5 changes: 4 additions & 1 deletion brian2/tests/test_synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,8 @@ def test_event_driven_dependency_error():

@pytest.mark.codegen_independent
def test_event_driven_dependency_error2():
pytest.xfail("This will be fixed with the rewrite of the equation "
"dependency check.")
stim = SpikeGeneratorGroup(1, [0], [0]*ms, period=5*ms)
tau = 5*ms
syn = Synapses(stim, stim, '''
Expand Down Expand Up @@ -1962,7 +1964,8 @@ def test_vectorisation_STDP_like():
neurons = NeuronGroup(6, '''dv/dt = rate : 1
ge : 1
rate : Hz
dA/dt = -A/(1*ms) : 1''', threshold='v>1', reset='v=0')
dA/dt = -A/(1*ms) : 1''', threshold='v>1',
reset='v=0', method='euler')
# Note that the synapse does not actually increase the target v, we want
# to have simple control about when neurons spike. Also, we separate the
# "depression" and "facilitation" completely. The example also uses
Expand Down