-
Notifications
You must be signed in to change notification settings - Fork 58
feat[next]: Tracer support part 1: tree_map #2586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
SF-N
wants to merge
35
commits into
GridTools:main
Choose a base branch
from
SF-N:tracer_support_tree_map
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 6 commits
Commits
Show all changes
35 commits
Select commit
Hold shift + click to select a range
1b4707c
Tracer prototype
SF-N 902f8a3
Merge branch 'main' into tracer_support
SF-N 02f881f
Introduce GTIR tree_map builtin and transform to make_tuple, also sup…
SF-N 0ec4692
Run pre-commit and fix some tests
SF-N ab84ecc
Run CollapseTuple after UnrollTreeMap
SF-N 36d6956
Merge branch 'main' into tracer_support_tree_map
SF-N 152300e
Address review comments
SF-N d459b0e
Address further review comments
SF-N 97af81e
Apply review comments
SF-N 8d75708
Merge branch 'main' into tracer_support_tree_map
SF-N 067bc29
Merge branch 'main' into tracer_support_tree_map
SF-N 32e5b2d
Rename map_ -> map_list
SF-N a7175d7
Run pre-commit
SF-N 4f89818
Merge branch 'main' into tracer_support_tree_map
SF-N 2779fd0
Refactor tree_map_tuple and add map_tuple with unrolling support
SF-N 80f3273
Rename
SF-N 454e15f
Minor fix
SF-N 8a1febd
Merge branch 'main' into tracer_support_tree_map
SF-N 31b969a
Remove unnecessary CollapseTuple loop
SF-N c7fc102
Reposition UnrollTupleMaps and simplify CollapseTuple usage
SF-N 7993b9c
Merge branch 'main' into tracer_support_tree_map
SF-N b7f8ba9
Refactor tree_map unrolling
SF-N 3d38868
Cleanup
SF-N 7d5c86c
Revert "Cleanup"
SF-N 747f36e
Revert "Refactor tree_map unrolling"
SF-N b767700
Cleanup
SF-N e91f1f1
Merge branch 'origin-main' into tracer_support_tree_map
SF-N 7b270a3
Address review comment
SF-N d3d4e46
Remove CollapseTuple pass after UnrollTupleMaps
SF-N d0272df
Remove program wrapper in tests
SF-N 56f234e
Merge branch 'tracer_support_tree_map' of github.com:SF-N/gt4py into …
SF-N b7bb0b2
Fix test
SF-N 158d540
Merge branch 'main' into tracer_support_tree_map
SF-N 7d8f56c
Also allow itir.Expr in UnrollTupleMaps and run tye_inference when ne…
SF-N 7808b0f
Merge branch 'tracer_support_tree_map' of github.com:SF-N/gt4py into …
SF-N File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| # GT4Py - GridTools Framework | ||
| # | ||
| # Copyright (c) 2014-2024, ETH Zurich | ||
| # All rights reserved. | ||
| # | ||
| # Please, refer to the LICENSE file in the root directory. | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
| import dataclasses | ||
|
|
||
| from gt4py import eve | ||
| from gt4py.next import utils | ||
| from gt4py.next.iterator import ir as itir | ||
| from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im | ||
| from gt4py.next.iterator.type_system import inference as itir_inference | ||
| from gt4py.next.type_system import type_specifications as ts | ||
|
|
||
|
|
||
| def _unroll( | ||
| f: itir.Expr, | ||
| tup_types: list[ts.TupleType], | ||
| tup_exprs: list[itir.Expr], | ||
| ) -> itir.Expr: | ||
| """Recursively expand ``tree_map(f)(tup0, tup1, ...)`` into ``make_tuple`` / ``tuple_get``.""" | ||
| n = len(tup_types[0].types) | ||
|
|
||
| elements: list[itir.Expr] = [] | ||
| for i in range(n): | ||
| child_types = [t.types[i] for t in tup_types] | ||
| child_exprs = [im.tuple_get(i, e) for e in tup_exprs] | ||
|
|
||
| if all(isinstance(ct, ts.TupleType) for ct in child_types): | ||
| nested_types = [ct for ct in child_types if isinstance(ct, ts.TupleType)] | ||
| elements.append(_unroll(f, nested_types, child_exprs)) | ||
| else: | ||
| elements.append(im.call(f)(*child_exprs)) | ||
|
SF-N marked this conversation as resolved.
Outdated
|
||
|
|
||
| return im.make_tuple(*elements) | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class UnrollTreeMap(eve.NodeTranslator): | ||
| PRESERVED_ANNEX_ATTRS = ("domain",) | ||
|
|
||
| uids: utils.IDGeneratorPool | ||
|
|
||
| @classmethod | ||
| def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): | ||
| return cls(uids=uids).visit(program) | ||
|
|
||
| def visit_FunCall(self, node: itir.FunCall): | ||
| node = self.generic_visit(node) | ||
|
|
||
| if not cpm.is_call_to(node.fun, "tree_map"): | ||
| return node | ||
|
|
||
| f = node.fun.args[0] | ||
| tup_args = node.args | ||
| tup_types: list[ts.TupleType] = [] | ||
| for tup in tup_args: | ||
| itir_inference.reinfer(tup) | ||
| assert isinstance(tup.type, ts.TupleType) | ||
| tup_types.append(tup.type) | ||
|
SF-N marked this conversation as resolved.
Outdated
|
||
|
|
||
| tup_refs = [next(self.uids["_utm"]) for _ in tup_args] | ||
| body = _unroll(f, tup_types, [im.ref(r) for r in tup_refs]) | ||
|
|
||
| result = body | ||
| for ref_name, tup in reversed(list(zip(tup_refs, tup_args))): | ||
| result = im.let(ref_name, tup)(result) | ||
|
|
||
| itir_inference.reinfer(result) | ||
| return result | ||
|
SF-N marked this conversation as resolved.
Outdated
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
43 changes: 43 additions & 0 deletions
43
tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_tree_map.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| # GT4Py - GridTools Framework | ||
| # | ||
| # Copyright (c) 2014-2024, ETH Zurich | ||
| # All rights reserved. | ||
| # | ||
| # Please, refer to the LICENSE file in the root directory. | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
| from gt4py.next.iterator.ir_utils import ir_makers as im | ||
| from gt4py.next.iterator.transforms.unroll_tree_map import _unroll | ||
| from gt4py.next.type_system import type_specifications as ts | ||
|
|
||
|
|
||
| T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) | ||
| TT = ts.TupleType(types=[T, T]) | ||
|
|
||
|
|
||
| def test_single_arg(): | ||
| result = _unroll(im.ref("f"), [TT], [im.ref("t")]) | ||
| expected = im.make_tuple(im.call("f")(im.tuple_get(0, "t")), im.call("f")(im.tuple_get(1, "t"))) | ||
|
SF-N marked this conversation as resolved.
Outdated
|
||
| assert result == expected | ||
|
|
||
|
|
||
| def test_multi_arg(): | ||
| result = _unroll(im.ref("f"), [TT, TT], [im.ref("a"), im.ref("b")]) | ||
| expected = im.make_tuple( | ||
| im.call("f")(im.tuple_get(0, "a"), im.tuple_get(0, "b")), | ||
| im.call("f")(im.tuple_get(1, "a"), im.tuple_get(1, "b")), | ||
| ) | ||
| assert result == expected | ||
|
|
||
|
|
||
| def test_nested(): | ||
| outer = ts.TupleType(types=[TT, T]) | ||
| result = _unroll(im.ref("f"), [outer], [im.ref("t")]) | ||
| expected = im.make_tuple( | ||
| im.make_tuple( | ||
| im.call("f")(im.tuple_get(0, im.tuple_get(0, "t"))), | ||
| im.call("f")(im.tuple_get(1, im.tuple_get(0, "t"))), | ||
| ), | ||
| im.call("f")(im.tuple_get(1, "t")), | ||
| ) | ||
| assert result == expected | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tuple-
whereover a per-neighbour (local) condition produces invalid IR.Hardcoding the per-leaf op to
op_as_fieldop("if_")drops the local-field handling the non-tuple path keeps:_lower_and_map("if_", …)→_mapwraps the op inmap_list(if_)and promotes the condition withmake_const_listwhen the leaves are local fields. The single, uniformtree_map_tupleleaf can't do that.This is reachable from the frontend and breaks on roundtrip, gtfn and dace whenever the
wherecondition is itself a local (per-neighbour) field — the mask only has to be aFieldType, and a local field qualifies:The non-tuple equivalent (
neighbor_sum(where(a(V2E)>c(V2E), a(V2E), c(V2E)), …)) works on all three backends; the tuple version fails on all three withAssertionErrorin theif_type synthesizer (isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL), because afterUnrollTupleMapsthe leaf is⇑(λ(c,x,y) → if ·c then ·x else ·y)applied to a list-typed predicate — nomap_list.Broadcast (non-local) conditions happen to pass, since wholesale
if_(scalar, list, list)equals element-wise selection and the consuming reduce absorbs the list — only a varying condition exposes it.This is the same per-leaf type-dependent reason given for not migrating
_visit_astype. It also isn't fixable by "always emitmap_list": with a non-local cond and mixed leaves (one branch local, one not), no single uniform leaf lambda is correct. Options: keepwhereonprocess_elements, or wrapif_inmap_listpost-unroll when the predicate is a list.