diff --git a/brian2/stateupdaters/exact.py b/brian2/stateupdaters/exact.py index 1e806bdf3..37891fbe9 100644 --- a/brian2/stateupdaters/exact.py +++ b/brian2/stateupdaters/exact.py @@ -40,22 +40,25 @@ def get_linear_system(eqs, variables): ValueError If the equations cannot be converted into an M * X + B form. ''' - diff_eqs = eqs.get_substituted_expressions(variables) - diff_eq_names = [name for name, _ in diff_eqs] + diff_eqs = {name: str_to_sympy(expr.code, variables).expand() + for name, expr in eqs.get_substituted_expressions(variables)} - symbols = [Symbol(name, real=True) for name in diff_eq_names] + # Sometimes, in particular in testing, variables defined as differential + # equations are actually constant (e.g. `dv/dt = 0/second`). We ignore + # them here + symbols = [Symbol(name, real=True) for name, expr in diff_eqs.items() + if expr != 0] - coefficients = sp.zeros(len(diff_eq_names)) - constants = sp.zeros(len(diff_eq_names), 1) - - for row_idx, (name, expr) in enumerate(diff_eqs): - s_expr = str_to_sympy(expr.code, variables).expand() + coefficients = sp.zeros(len(symbols)) + constants = sp.zeros(len(symbols), 1) + for row_idx, symbol in enumerate(symbols): + s_expr = diff_eqs[symbol.name] current_s_expr = s_expr for col_idx, symbol in enumerate(symbols): current_s_expr = current_s_expr.collect(symbol) constant_wildcard = Wild('c', exclude=[symbol]) - factor_wildcard = Wild('c_'+name, exclude=symbols) + factor_wildcard = Wild('c_'+symbol.name, exclude=symbols) one_pattern = factor_wildcard*symbol + constant_wildcard matches = current_s_expr.match(one_pattern) if matches is None: @@ -64,7 +67,8 @@ def get_linear_system(eqs, variables): '%s, could not be ' 'separated into linear ' 'components.') % - (expr, name)) + (sympy_to_str(s_expr), + symbol.name)) coefficients[row_idx, col_idx] = matches[factor_wildcard] current_s_expr = matches[constant_wildcard] @@ -72,7 +76,7 @@ def get_linear_system(eqs, variables): # The remaining constant should be a true constant constants[row_idx] = current_s_expr - return (diff_eq_names, coefficients, constants) + return [s.name for s in symbols], coefficients, constants class IndependentStateUpdater(StateUpdateMethod): @@ -191,40 +195,49 @@ def __call__(self, equations, variables=None, method_options=None): ('Expression "{}" is not guaranteed to be constant over a ' 'time step').format(sympy_to_str(entry))) - symbols = [Symbol(variable, real=True) for variable in varnames] - solution = sp.solve_linear_system(matrix.row_join(constants), *symbols) - if solution is None or set(symbols) != set(solution.keys()): - raise UnsupportedEquationsException('Cannot solve the given ' - 'equations with this ' - 'stateupdater.') - b = sp.ImmutableMatrix([solution[symbol] for symbol in symbols]) - # Solve the system dt = Symbol('dt', real=True, positive=True) + # Add the constant terms as new variables + const_vars = [] + const_terms = [] + for idx, (varname, const_term) in enumerate(zip(varnames, constants)): + if const_term != 0: + matrix = matrix.col_insert(matrix.cols, sp.Matrix([1 if i == idx else 0 + for i in range(matrix.rows)])) + matrix = matrix.row_insert(matrix.rows, sp.zeros(1, matrix.cols)) + const_vars.append('_const_term_' + varname) + const_terms.append(const_term) + try: A = (matrix * dt).exp() except NotImplementedError: raise UnsupportedEquationsException('Cannot solve the given ' 'equations with this ' 'stateupdater.') + if method_options['simplify']: A = A.applyfunc(lambda x: sp.factor_terms(sp.cancel(sp.signsimp(x)))) - C = sp.ImmutableMatrix(A * b) - b - _S = sp.MatrixSymbol('_S', len(varnames), 1) - updates = A * _S + C + + _S = sp.MatrixSymbol('_S', len(varnames) + len(const_vars), 1) + updates = A * _S updates = updates.as_explicit() + abstract_code = [] + + # Add code for the constant terms: + for const_var, const_term in zip(const_vars, const_terms): + abstract_code.append(const_var + ' = ' + sympy_to_str(const_term)) # The solution contains _S[0, 0], _S[1, 0] etc. for the state variables, - # replace them with the state variable names - abstract_code = [] - for idx, (variable, update) in enumerate(zip(varnames, updates)): + # replace them with the state variable names + for variable, update in zip(varnames, updates[:len(varnames)]): rhs = update if rhs.has(I, re, im): raise UnsupportedEquationsException('The solution to the linear system ' 'contains complex values ' 'which is currently not implemented.') - for row_idx, varname in enumerate(varnames): + + for row_idx, varname in enumerate(itertools.chain(varnames, const_vars)): rhs = rhs.subs(_S[row_idx, 0], varname) # Do not overwrite the real state variables yet, the update step diff --git a/brian2/tests/test_synapses.py b/brian2/tests/test_synapses.py index 768116704..dd13c52d1 100644 --- a/brian2/tests/test_synapses.py +++ b/brian2/tests/test_synapses.py @@ -1585,6 +1585,8 @@ def test_event_driven_dependency_error(): @pytest.mark.codegen_independent def test_event_driven_dependency_error2(): + pytest.xfail("This will be fixed with the rewrite of the equation " + "dependency check.") stim = SpikeGeneratorGroup(1, [0], [0]*ms, period=5*ms) tau = 5*ms syn = Synapses(stim, stim, ''' @@ -1962,7 +1964,8 @@ def test_vectorisation_STDP_like(): neurons = NeuronGroup(6, '''dv/dt = rate : 1 ge : 1 rate : Hz - dA/dt = -A/(1*ms) : 1''', threshold='v>1', reset='v=0') + dA/dt = -A/(1*ms) : 1''', threshold='v>1', + reset='v=0', method='euler') # Note that the synapse does not actually increase the target v, we want # to have simple control about when neurons spike. Also, we separate the # "depression" and "facilitation" completely. The example also uses