Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions loopy/transform/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,7 @@ def __init__(self,
self.subst_rule_name = subst_rule_name
self.inames = inames
self.inames_set = frozenset(inames)
self.n_mapped_redns = 0

@override
def map_reduction(self, expr: Reduction, /,
Expand All @@ -905,20 +906,22 @@ def map_reduction(self, expr: Reduction, /,
self.rule_mapping_context.make_unique_var_name(subst_rule_prefix)
else:
proposed_name = self.subst_rule_name
if self.n_mapped_redns:
raise LoopyError("substitution rule '%s' already exists"
% proposed_name)

actual_name = self.rule_mapping_context.register_subst_rule(
intermediate_name = self.rule_mapping_context.register_subst_rule(
proposed_name, tuple(self.inames), expr.expr)
if proposed_name != actual_name and self.subst_rule_name:
raise LoopyError("substitution rule '%s' already exists"
% proposed_name)

from pymbolic import var
iname_vars = [var(iname) for iname in self.inames]

self.n_mapped_redns += 1

return type(expr)(
operation=expr.operation,
inames=tuple(expr.inames),
expr=var(proposed_name)(*iname_vars),
expr=var(intermediate_name)(*iname_vars),
allow_simultaneous=expr.allow_simultaneous)


Expand Down
78 changes: 78 additions & 0 deletions test/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,84 @@ def test_extract_subst(ctx_factory: cl.CtxFactory):
assert insn.expression == parse("bsquare(23) + bsquare(25)")


def test_reduction_arg_to_subst_rule_single():
from loopy.transform.data import reduction_arg_to_subst_rule

t_unit = lp.make_kernel(
"{[i,j]: 0<=i,j<n}",
"out[i] = sum(j, a[i,j]) {id=red}",
name="red_subst")

from loopy.symbolic import parse

# {{{ auto-generated substitution rule name

auto_t_unit = reduction_arg_to_subst_rule(t_unit, "j")
knl = auto_t_unit["red_subst"]
assert knl.id_to_insn["red"].expression == \
parse("reduce(sum, [j], red_j_arg(j))")

subst = knl.substitutions["red_j_arg"]
assert subst.arguments == ("j",)
assert subst.expression == parse("a[i, j]")

# }}}

# {{{ explicit substitution rule name

named_t_unit = reduction_arg_to_subst_rule(t_unit, "j", subst_rule_name="mysubst")
knl = named_t_unit["red_subst"]
assert knl.id_to_insn["red"].expression == \
parse("reduce(sum, [j], mysubst(j))")

subst = knl.substitutions["mysubst"]
assert subst.arguments == ("j",)
assert subst.expression == parse("a[i, j]")

# }}}


def test_reduction_arg_to_subst_rule_multiple():
from loopy.transform.data import reduction_arg_to_subst_rule

t_unit = lp.make_kernel(
"{[i,j]: 0<=i,j<n}",
"""
out1[i] = sum(j, a[i,j]) {id=red1}
out2[i] = sum(j, b[i,j]) {id=red2}
""",
name="red_subst")

from loopy.symbolic import parse

# {{{ auto-generated names handle multiple matching reductions

# Each matching reduction gets its own (distinct) rule.
auto_t_unit = reduction_arg_to_subst_rule(t_unit, "j")
knl = auto_t_unit["red_subst"]
assert set(knl.substitutions) == {"red_j_arg", "red_j_arg_0"}
assert knl.id_to_insn["red1"].expression == \
parse("reduce(sum, [j], red_j_arg(j))")
assert knl.id_to_insn["red2"].expression == \
parse("reduce(sum, [j], red_j_arg_0(j))")

assert knl.substitutions["red_j_arg"].arguments == ("j",)
assert knl.substitutions["red_j_arg"].expression == parse("a[i, j]")
assert knl.substitutions["red_j_arg_0"].arguments == ("j",)
assert knl.substitutions["red_j_arg_0"].expression == parse("b[i, j]")

# }}}

# {{{ explicit name with multiple matching reductions raises

# An explicit name can only apply to a single reduction, so a second
# matching reduction must raise.
with pytest.raises(lp.LoopyError):
reduction_arg_to_subst_rule(t_unit, "j", subst_rule_name="mysubst")

# }}}


def test_join_inames(ctx_factory: cl.CtxFactory):
ctx = ctx_factory()

Expand Down
Loading