diff --git a/loopy/transform/data.py b/loopy/transform/data.py index 0961f643c..ca2f737f5 100644 --- a/loopy/transform/data.py +++ b/loopy/transform/data.py @@ -888,6 +888,7 @@ def __init__(self, self.subst_rule_name = subst_rule_name self.inames = inames self.inames_set = frozenset(inames) + self.n_mapped_redns = 0 @override def map_reduction(self, expr: Reduction, /, @@ -905,20 +906,22 @@ def map_reduction(self, expr: Reduction, /, self.rule_mapping_context.make_unique_var_name(subst_rule_prefix) else: proposed_name = self.subst_rule_name + if self.n_mapped_redns: + raise LoopyError("substitution rule '%s' already exists" + % proposed_name) - actual_name = self.rule_mapping_context.register_subst_rule( + intermediate_name = self.rule_mapping_context.register_subst_rule( proposed_name, tuple(self.inames), expr.expr) - if proposed_name != actual_name and self.subst_rule_name: - raise LoopyError("substitution rule '%s' already exists" - % proposed_name) from pymbolic import var iname_vars = [var(iname) for iname in self.inames] + self.n_mapped_redns += 1 + return type(expr)( operation=expr.operation, inames=tuple(expr.inames), - expr=var(proposed_name)(*iname_vars), + expr=var(intermediate_name)(*iname_vars), allow_simultaneous=expr.allow_simultaneous) diff --git a/test/test_transform.py b/test/test_transform.py index b6f87c108..219158172 100644 --- a/test/test_transform.py +++ b/test/test_transform.py @@ -314,6 +314,84 @@ def test_extract_subst(ctx_factory: cl.CtxFactory): assert insn.expression == parse("bsquare(23) + bsquare(25)") +def test_reduction_arg_to_subst_rule_single(): + from loopy.transform.data import reduction_arg_to_subst_rule + + t_unit = lp.make_kernel( + "{[i,j]: 0<=i,j