-
Notifications
You must be signed in to change notification settings - Fork 58
feat[next]: Tracer support part 1: tree_map #2586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 7 commits
1b4707c
902f8a3
02f881f
0ec4692
ab84ecc
36d6956
152300e
d459b0e
97af81e
8d75708
067bc29
32e5b2d
a7175d7
4f89818
2779fd0
80f3273
454e15f
8a1febd
31b969a
c7fc102
7993b9c
b7f8ba9
3d38868
7d5c86c
747f36e
b767700
e91f1f1
7b270a3
d3d4e46
d0272df
56f234e
b7bb0b2
158d540
7d8f56c
7808b0f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| prune_empty_concat_where, | ||
| remove_broadcast, | ||
| symbol_ref_utils, | ||
| unroll_tree_map, | ||
| ) | ||
| from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet | ||
| from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple | ||
|
|
@@ -176,6 +177,26 @@ def apply_common_transforms( | |
| ) # domain inference does not support dynamic offsets yet | ||
| ir = infer_domain_ops.InferDomainOps.apply(ir) | ||
| ir = concat_where.canonicalize_domain_argument(ir) | ||
| ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) | ||
|
|
||
| # After UnrollTreeMap, collapse `tuple_get(i, let(...)(make_tuple(...)))` patterns so that | ||
| # domain inference does not encounter `as_fieldop` nodes inside dead tuple elements | ||
| # (which would receive NEVER domain). Do multiple iterations for nested `let`s. | ||
| for _ in range(10): | ||
| collapsed = ir | ||
| ir = CollapseTuple.apply( | ||
| ir, | ||
| enabled_transformations=( | ||
| CollapseTuple.Transformation.PROPAGATE_TUPLE_GET | ||
| | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE | ||
| ), | ||
| uids=uids, | ||
| offset_provider_type=offset_provider_type, | ||
| ) # type: ignore[assignment] # always an itir.Program | ||
| if ir == collapsed: | ||
| break | ||
|
SF-N marked this conversation as resolved.
Outdated
|
||
| else: | ||
| raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.") | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Without this
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note: this is probably also the test case where the loop is required. I'll check whether another configuration of the pass helps to avoid the loop.
SF-N marked this conversation as resolved.
Outdated
|
||
|
|
||
| ir = infer_domain.infer_program( | ||
| ir, | ||
|
|
@@ -290,6 +311,23 @@ def apply_fieldview_transforms( | |
|
|
||
| ir = infer_domain_ops.InferDomainOps.apply(ir) | ||
| ir = concat_where.canonicalize_domain_argument(ir) | ||
| ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids) | ||
| for _ in range(10): | ||
| prev = ir | ||
| ir = CollapseTuple.apply( | ||
| ir, | ||
| enabled_transformations=( | ||
| CollapseTuple.Transformation.PROPAGATE_TUPLE_GET | ||
| | CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE | ||
| ), | ||
| uids=uids, | ||
| offset_provider_type=offset_provider_type, | ||
| ) # type: ignore[assignment] # always an itir.Program | ||
| if ir == prev: | ||
| break | ||
|
SF-N marked this conversation as resolved.
Outdated
|
||
| else: | ||
| raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.") | ||
|
|
||
| ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program | ||
|
|
||
| ir = infer_domain.infer_program( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| # GT4Py - GridTools Framework | ||
| # | ||
| # Copyright (c) 2014-2024, ETH Zurich | ||
| # All rights reserved. | ||
| # | ||
| # Please, refer to the LICENSE file in the root directory. | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
| import dataclasses | ||
|
|
||
| from gt4py import eve | ||
| from gt4py.next import utils | ||
| from gt4py.next.iterator import ir as itir | ||
| from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im | ||
| from gt4py.next.iterator.type_system import inference as itir_inference | ||
| from gt4py.next.type_system import type_specifications as ts | ||
|
|
||
|
|
||
| def _unroll( | ||
| f: itir.Expr, | ||
| tup_types: list[ts.TupleType], | ||
| tup_exprs: list[itir.Expr], | ||
| ) -> itir.Expr: | ||
| """Recursively expand ``tree_map(f)(tup0, tup1, ...)`` into ``make_tuple`` / ``tuple_get``.""" | ||
| assert tup_types, "tree_map requires at least one tuple argument." | ||
|
SF-N marked this conversation as resolved.
Outdated
|
||
| n = len(tup_types[0].types) | ||
| if any(len(t.types) != n for t in tup_types[1:]): | ||
| raise ValueError( | ||
| f"All tree_map arguments must have the same tuple structure at each level, " | ||
| f"got {[len(t.types) for t in tup_types]}." | ||
| ) | ||
|
|
||
| elements: list[itir.Expr] = [] | ||
| for i in range(n): | ||
| child_types = [t.types[i] for t in tup_types] | ||
| child_exprs = [im.tuple_get(i, e) for e in tup_exprs] | ||
|
|
||
| all_tuples = all(isinstance(ct, ts.TupleType) for ct in child_types) | ||
| all_leaves = all(not isinstance(ct, ts.TupleType) for ct in child_types) | ||
| if all_tuples: | ||
| nested_types = [ct for ct in child_types if isinstance(ct, ts.TupleType)] | ||
| elements.append(_unroll(f, nested_types, child_exprs)) | ||
| elif all_leaves: | ||
| elements.append(im.call(f)(*child_exprs)) | ||
| else: | ||
| raise ValueError( | ||
| "All tree_map arguments must have the same tree structure " | ||
| "(all leaves must be reached simultaneously)." | ||
| ) | ||
|
|
||
| return im.make_tuple(*elements) | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class UnrollTreeMap(eve.NodeTranslator): | ||
| PRESERVED_ANNEX_ATTRS = ("domain",) | ||
|
|
||
| uids: utils.IDGeneratorPool | ||
|
|
||
| @classmethod | ||
| def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): | ||
| return cls(uids=uids).visit(program) | ||
|
|
||
| def visit_FunCall(self, node: itir.FunCall): | ||
| node = self.generic_visit(node) | ||
|
|
||
| if not cpm.is_call_to(node.fun, "tree_map"): | ||
| return node | ||
|
|
||
| f = node.fun.args[0] | ||
| tup_args = node.args | ||
| tup_types: list[ts.TupleType] = [] | ||
| for tup in tup_args: | ||
| itir_inference.reinfer(tup) | ||
| assert isinstance(tup.type, ts.TupleType) | ||
| tup_types.append(tup.type) | ||
|
SF-N marked this conversation as resolved.
Outdated
|
||
|
|
||
| tup_refs = [next(self.uids["_utm"]) for _ in tup_args] | ||
| body = _unroll(f, tup_types, [im.ref(r) for r in tup_refs]) | ||
|
|
||
| result = body | ||
| for ref_name, tup in reversed(list(zip(tup_refs, tup_args))): | ||
| result = im.let(ref_name, tup)(result) | ||
|
|
||
| itir_inference.reinfer(result) | ||
| return result | ||
|
SF-N marked this conversation as resolved.
Outdated
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| # GT4Py - GridTools Framework | ||
| # | ||
| # Copyright (c) 2014-2024, ETH Zurich | ||
| # All rights reserved. | ||
| # | ||
| # Please, refer to the LICENSE file in the root directory. | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
| from gt4py.next import common, utils | ||
| from gt4py.next.iterator import ir as itir | ||
| from gt4py.next.iterator.ir_utils import ir_makers as im | ||
| from gt4py.next.iterator.transforms.unroll_tree_map import UnrollTreeMap | ||
| from gt4py.next.type_system import type_specifications as ts | ||
|
|
||
| IDim = common.Dimension("IDim") | ||
| T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) | ||
| i_field = ts.FieldType(dims=[IDim], dtype=T) | ||
| i_tuple_field = ts.TupleType(types=[i_field, i_field]) | ||
| i_nested_tuple_field = ts.TupleType(types=[i_tuple_field, i_field]) | ||
|
|
||
| i_domain = im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="IDim"), 0, 1)) | ||
|
|
||
|
|
||
| def _make_program(params: list[itir.Sym], expr: itir.Expr) -> itir.Program: | ||
| return itir.Program( | ||
| id="testee", | ||
| function_definitions=[], | ||
| params=[*params, im.sym("out", i_field)], | ||
| declarations=[], | ||
| body=[ | ||
| itir.SetAt( | ||
| expr=expr, | ||
| domain=i_domain, | ||
| target=im.ref("out"), | ||
| ) | ||
|
SF-N marked this conversation as resolved.
Outdated
|
||
| ], | ||
| ) | ||
|
|
||
|
|
||
| def _neg(): | ||
| return im.lambda_("__a")(im.op_as_fieldop("neg")("__a")) | ||
|
|
||
|
|
||
| def _plus(): | ||
| return im.lambda_("__a", "__b")(im.op_as_fieldop("plus")("__a", "__b")) | ||
|
|
||
|
|
||
| def test_multi_arg(): | ||
| uids = utils.IDGeneratorPool() | ||
| program = _make_program( | ||
| [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], | ||
| im.call(im.call("tree_map")(_plus()))( | ||
| im.ref("a", i_tuple_field), im.ref("b", i_tuple_field) | ||
| ), | ||
| ) | ||
| result = UnrollTreeMap.apply(program, uids=uids) | ||
|
|
||
| expected = _make_program( | ||
| [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)], | ||
| im.let("_utm_0", "a")( | ||
| im.let("_utm_1", "b")( | ||
| im.make_tuple( | ||
| im.call(_plus())(im.tuple_get(0, "_utm_0"), im.tuple_get(0, "_utm_1")), | ||
| im.call(_plus())(im.tuple_get(1, "_utm_0"), im.tuple_get(1, "_utm_1")), | ||
| ) | ||
| ) | ||
| ), | ||
| ) | ||
| assert result == expected | ||
|
|
||
|
|
||
| def test_nested(): | ||
| uids = utils.IDGeneratorPool() | ||
| program = _make_program( | ||
| [im.sym("t", i_nested_tuple_field)], | ||
| im.call(im.call("tree_map")(_neg()))(im.ref("t", i_nested_tuple_field)), | ||
| ) | ||
| result = UnrollTreeMap.apply(program, uids=uids) | ||
|
|
||
| expected = _make_program( | ||
| [im.sym("t", i_nested_tuple_field)], | ||
| im.let("_utm_0", "t")( | ||
| im.make_tuple( | ||
| im.make_tuple( | ||
| im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "_utm_0"))), | ||
| im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "_utm_0"))), | ||
| ), | ||
| im.call(_neg())(im.tuple_get(1, "_utm_0")), | ||
| ) | ||
| ), | ||
| ) | ||
| assert result == expected | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tuple-`where` over a per-neighbour (local) condition produces invalid IR.

Hardcoding the per-leaf op to `op_as_fieldop("if_")` drops the local-field handling the non-tuple path keeps: `_lower_and_map("if_", …)` → `_map` wraps the op in `map_list(if_)` and promotes the condition with `make_const_list` when the leaves are local fields. The single, uniform `tree_map_tuple` leaf can't do that.

This is reachable from the frontend and breaks on roundtrip, gtfn and dace whenever the `where` condition is itself a local (per-neighbour) field — the mask only has to be a `FieldType`, and a local field qualifies:

The non-tuple equivalent (`neighbor_sum(where(a(V2E)>c(V2E), a(V2E), c(V2E)), …)`) works on all three backends; the tuple version fails on all three with `AssertionError` in the `if_` type synthesizer (`isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL`), because after `UnrollTupleMaps` the leaf is `⇑(λ(c,x,y) → if ·c then ·x else ·y)` applied to a list-typed predicate — no `map_list`.

Broadcast (non-local) conditions happen to pass, since wholesale `if_(scalar, list, list)` equals element-wise selection and the consuming reduce absorbs the list — only a varying condition exposes it.

This is the same per-leaf type-dependent reason given for not migrating `_visit_astype`. It also isn't fixable by "always emit `map_list`": with a non-local cond and mixed leaves (one branch local, one not), no single uniform leaf lambda is correct. Options: keep `where` on `process_elements`, or wrap `if_` in `map_list` post-unroll when the predicate is a list.