Skip to content
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
1b4707c
Tracer prototype
SF-N Feb 20, 2026
902f8a3
Merge branch 'main' into tracer_support
SF-N Apr 27, 2026
02f881f
Introduce GTIR tree_map builtin and transform to make_tuple, also sup…
SF-N Apr 27, 2026
0ec4692
Run pre-commit and fix some tests
SF-N Apr 27, 2026
ab84ecc
Run CollapseTuple after UnrollTreeMap
SF-N Apr 28, 2026
36d6956
Merge branch 'main' into tracer_support_tree_map
SF-N Apr 28, 2026
152300e
Address review comments
SF-N Apr 28, 2026
d459b0e
Address further review comments
SF-N Apr 28, 2026
97af81e
Apply review comments
SF-N Apr 29, 2026
8d75708
Merge branch 'main' into tracer_support_tree_map
SF-N Apr 29, 2026
067bc29
Merge branch 'main' into tracer_support_tree_map
SF-N May 4, 2026
32e5b2d
Rename map_ -> map_list
SF-N May 28, 2026
a7175d7
Run pre-commit
SF-N May 28, 2026
4f89818
Merge branch 'main' into tracer_support_tree_map
SF-N May 28, 2026
2779fd0
Refactor tree_map_tuple and add map_tuple with unrolling support
SF-N May 28, 2026
80f3273
Rename
SF-N May 28, 2026
454e15f
Minor fix
SF-N May 28, 2026
8a1febd
Merge branch 'main' into tracer_support_tree_map
SF-N Jun 2, 2026
31b969a
Remove unnecessary CollapseTuple loop
SF-N Jun 2, 2026
c7fc102
Reposition UnrollTupleMaps and simplify CollapseTuple usage
SF-N Jun 2, 2026
7993b9c
Merge branch 'main' into tracer_support_tree_map
SF-N Jun 2, 2026
b7f8ba9
Refactor tree_map unrolling
SF-N Jun 16, 2026
3d38868
Cleanup
SF-N Jun 16, 2026
7d5c86c
Revert "Cleanup"
SF-N Jun 17, 2026
747f36e
Revert "Refactor tree_map unrolling"
SF-N Jun 17, 2026
b767700
Cleanup
SF-N Jun 17, 2026
e91f1f1
Merge branch 'origin-main' into tracer_support_tree_map
SF-N Jun 17, 2026
7b270a3
Address review comment
SF-N Jun 18, 2026
d3d4e46
Remove CollapseTuple pass after UnrollTupleMaps
SF-N Jun 19, 2026
d0272df
Remove program wrapper in tests
SF-N Jun 19, 2026
56f234e
Merge branch 'tracer_support_tree_map' of github.com:SF-N/gt4py into …
SF-N Jun 19, 2026
b7bb0b2
Fix test
SF-N Jun 19, 2026
158d540
Merge branch 'main' into tracer_support_tree_map
SF-N Jun 24, 2026
7d8f56c
Also allow itir.Expr in UnrollTupleMaps and run type_inference when ne…
SF-N Jun 24, 2026
7808b0f
Merge branch 'tracer_support_tree_map' of github.com:SF-N/gt4py into …
SF-N Jun 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 8 additions & 16 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,23 +412,15 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return self._lower_and_map("if_", *node.args)

cond_ = self.visit(node.args[0])
true_ = self.visit(node.args[1])
false_ = self.visit(node.args[2])
cond_symref_name = f"__cond_{cond_.fingerprint()}"

def create_if(
true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec]
) -> itir.FunCall:
return _map(
"if_",
(im.ref(cond_symref_name), true_, false_),
(node.args[0].type, *arg_types),
result = im.tree_map_tuple(
im.lambda_("__a", "__b")(
im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b"))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tuple-where over a per-neighbour (local) condition produces invalid IR.

Hardcoding the per-leaf op to op_as_fieldop("if_") drops the local-field handling that the non-tuple path keeps: _lower_and_map("if_", …) → _map wraps the op in map_list(if_) and promotes the condition with make_const_list when the leaves are local fields. The single, uniform tree_map_tuple leaf can't do that.

This is reachable from the frontend and breaks on roundtrip, gtfn and dace whenever the where condition is itself a local (per-neighbour) field — the mask only has to be a FieldType, and a local field qualifies:

@gtx.field_operator
def tup(a: EdgeF, b: EdgeF, c: EdgeF, d: EdgeF) -> tuple[VertF, VertF]:
    cond = a(V2E) > c(V2E)                               # per-neighbour (local) bool field
    t = where(cond, (a(V2E), b(V2E)), (c(V2E), d(V2E)))
    return (neighbor_sum(t[0], axis=V2EDim), neighbor_sum(t[1], axis=V2EDim))

The non-tuple equivalent (neighbor_sum(where(a(V2E)>c(V2E), a(V2E), c(V2E)), …)) works on all three backends; the tuple version fails on all three with AssertionError in the if_ type synthesizer (isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL), because after UnrollTupleMaps the leaf is ⇑(λ(c,x,y) → if ·c then ·x else ·y) applied to a list-typed predicate — no map_list.

Broadcast (non-local) conditions happen to pass, since wholesale if_(scalar, list, list) equals element-wise selection and the consuming reduce absorbs the list — only a varying condition exposes it.

This is the same per-leaf type-dependent reason given for not migrating _visit_astype. It also isn't fixable by "always emit map_list": with a non-local cond and mixed leaves (one branch local, one not), no single uniform leaf lambda is correct. Options: keep where on process_elements, or wrap if_ in map_list post-unroll when the predicate is a list.

)

result = lowering_utils.process_elements(
create_if,
(self.visit(node.args[1]), self.visit(node.args[2])),
node.type,
arg_types=(node.args[1].type, node.args[2].type),
)
)(true_, false_)

return im.let(cond_symref_name, cond_)(result)

Expand Down Expand Up @@ -534,7 +526,7 @@ def _map(
original_arg_types: tuple[ts.TypeSpec, ...],
) -> itir.FunCall:
"""
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists.
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_list`ing lists.
"""
if all(
isinstance(t, (ts.ScalarType, ts.DimensionType, ts.DomainType))
Expand All @@ -547,7 +539,7 @@ def _map(
promote_to_list(arg_type)(larg)
for arg_type, larg in zip(original_arg_types, lowered_args)
)
op = im.map_(op)
op = im.map_list(op)

return im.op_as_fieldop(op)(*lowered_args)

Expand Down
16 changes: 14 additions & 2 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,17 @@ def neighbors(*args):


@builtin_dispatch
def map_(*args):
def map_list(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def tree_map_tuple(*args):
    """Dispatch stub for the `tree_map_tuple` builtin.

    The body only raises `BackendNotSelectedError`.
    NOTE(review): presumably `builtin_dispatch` routes calls to a registered
    backend implementation before this body is reached -- confirm.
    """
    raise BackendNotSelectedError()


@builtin_dispatch
def map_tuple(*args):
    """Dispatch stub for the `map_tuple` builtin.

    The body only raises `BackendNotSelectedError`.
    NOTE(review): presumably `builtin_dispatch` routes calls to a registered
    backend implementation before this body is reached -- confirm.
    """
    raise BackendNotSelectedError()


Expand Down Expand Up @@ -498,7 +508,9 @@ def get_domain_range(*args):
"lift",
"make_const_list",
"make_tuple",
"map_",
"tree_map_tuple",
"map_tuple",
"map_list",
"named_range",
"neighbors",
"reduce",
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1455,8 +1455,8 @@ def _get_offset(*lists: _List | _ConstList) -> Optional[runtime.Offset]:
raise AssertionError("All lists must have the same offset.")


@builtins.map_.register(EMBEDDED)
def map_(op):
@builtins.map_list.register(EMBEDDED)
def map_list(op):
def impl_(*lists):
offset = _get_offset(*lists)
if offset is None:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def is_applied_map(arg: itir.Node) -> TypeGuard[_FunCallToFunCallToRef]:
isinstance(arg, itir.FunCall)
and isinstance(arg.fun, itir.FunCall)
and isinstance(arg.fun.fun, itir.SymRef)
and arg.fun.fun.id == "map_"
and arg.fun.fun.id == "map_list"
)


Expand Down
16 changes: 13 additions & 3 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,9 +624,19 @@ def index(dim: common.Dimension) -> itir.FunCall:
return call("index")(itir.AxisLiteral(value=dim.value, kind=dim.kind))


def map_(op):
"""Create a `map_` call."""
return call(call("map_")(op))
def map_list(op):
    """Create a `map_list` call: map_list(op)(lst1, lst2, ...)."""
    # First build the inner `map_list(op)` node, then wrap it so that it can be
    # applied to the list arguments.
    mapped_op = call("map_list")(op)
    return call(mapped_op)


def tree_map_tuple(op):
    """Create a `tree_map_tuple` call: tree_map_tuple(op)(tup1, tup2, ...)."""
    # First build the inner `tree_map_tuple(op)` node, then wrap it so that it
    # can be applied to the tuple arguments.
    mapped_op = call("tree_map_tuple")(op)
    return call(mapped_op)


def map_tuple(op):
    """Create a `map_tuple` call: map_tuple(op)(tup)."""
    # First build the inner `map_tuple(op)` node, then wrap it so that it can be
    # applied to the tuple argument.
    mapped_op = call("map_tuple")(op)
    return call(mapped_op)


def reduce(op, expr):
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/collapse_list_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node:
if cpm.is_call_to(node.args[1], "make_const_list"):
return node.args[1].args[0]
if cpm.is_applied_map(node.args[1]):
# list_get(0, map_(λ(val_) → foo(val_, int64))(·__sym_1))
# list_get(0, map_list(λ(val_) → foo(val_, int64))(·__sym_1))
# -> (λ(val_) → foo(val_, int64))(list_get(0, ·__sym_1))
lsts = node.args[1].args
assert len(node.args[1].fun.args) == 1 # a single lambda in the map
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _is_collectable_expr(node: itir.Node) -> bool:
if isinstance(node, itir.FunCall):
# do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be
# visited, to ensure symbol dependencies are recognized correctly.
# do also not collect reduce, map_ and neighbors nodes if they are left in the IR at this point, this may lead to
# do also not collect reduce, map_list and neighbors nodes if they are left in the IR at this point, this may lead to
# conceptual problems (other parts of the tool chain rely on the arguments being present directly
# on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend
# (single pass eager depth first visit approach), see also https://github.com/GridTools/gt4py/issues/1795
Expand All @@ -104,7 +104,7 @@ def _is_collectable_expr(node: itir.Node) -> bool:
# do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement
# instead of an as_fieldop
if cpm.is_call_to(
node, ("lift", "shift", "neighbors", "reduce", "map_", "index")
node, ("lift", "shift", "neighbors", "reduce", "map_list", "index")
) or cpm.is_applied_lift(node):
return False
return True
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/fuse_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
@dataclasses.dataclass(frozen=True)
class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
"""
Fuses nested `map_`s.
Fuses nested `map_list`s.

Preconditions:
- `FunctionDefinitions` are inlined
Expand All @@ -29,7 +29,7 @@ class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrai
to
map(λ(a, b, c) → f(a, g(b, c)))(a, b, c)

reduce(λ(x, y) → f(x, y), init)(map_(g(z, w))(a, b))
reduce(λ(x, y) → f(x, y), init)(map_list(g(z, w))(a, b))
to
reduce(λ(x, y, z) → f(x, g(y, z)), init)(a, b)
"""
Expand Down Expand Up @@ -93,7 +93,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
new_op = ir.Lambda(params=new_params, expr=new_body)
if cpm.is_applied_map(node):
return ir.FunCall(
fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args
fun=ir.FunCall(fun=ir.SymRef(id="map_list"), args=[new_op]), args=new_args
)
else: # is_applied_reduce(node)
return ir.FunCall(
Expand Down
32 changes: 32 additions & 0 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
prune_empty_concat_where,
remove_broadcast,
symbol_ref_utils,
unroll_tuple_maps,
)
from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet
from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple
Expand Down Expand Up @@ -169,6 +170,23 @@ def apply_common_transforms(
ir = inline_lifts.InlineLifts().visit(ir)

ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program
# `UnrollTupleMaps` requires fully-inferred tuple types (relies on `reinfer` to see
# nested `TupleType` chains). `expand_tuple_args` runs full type inference, so this is
# the earliest safe position.
ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids)
# `UnrollTupleMaps` collapses `tuple_get(i, make_tuple(...))` patterns on the fly
# for trivial arguments, so no additional `CollapseTuple` cleanup loop is needed.
# A single `CollapseTuple` pass still handles any residual patterns produced when
# arguments had to be let-bound (non-trivial sub-expressions).
ir = CollapseTuple.apply(
ir,
enabled_transformations=(
CollapseTuple.Transformation.PROPAGATE_TUPLE_GET
| CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE
),
uids=uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program
Comment thread
SF-N marked this conversation as resolved.
Outdated
ir = dead_code_elimination.dead_code_elimination(
ir, uids=uids, offset_provider_type=offset_provider_type
) # domain inference does not support dead-code
Expand Down Expand Up @@ -282,6 +300,19 @@ def apply_fieldview_transforms(
ir = inline_fundefs.prune_unreferenced_fundefs(ir)
# required for dead-code-elimination and `prune_empty_concat_where` pass
ir = concat_where.expand_tuple_args(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program
# `UnrollTupleMaps` requires fully-inferred tuple types; `expand_tuple_args` runs full
# type inference, so this is the earliest safe position.
ir = unroll_tuple_maps.UnrollTupleMaps.apply(ir, uids=uids)
# See note in `apply_common_transforms` about why a single `CollapseTuple` pass suffices.
ir = CollapseTuple.apply(
ir,
enabled_transformations=(
CollapseTuple.Transformation.PROPAGATE_TUPLE_GET
| CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE
),
uids=uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program
Comment thread
SF-N marked this conversation as resolved.
Outdated
ir = dead_code_elimination.dead_code_elimination(
ir, offset_provider_type=offset_provider_type, uids=uids
)
Expand All @@ -291,6 +322,7 @@ def apply_fieldview_transforms(

ir = infer_domain_ops.InferDomainOps.apply(ir)
ir = concat_where.canonicalize_domain_argument(ir)

ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program

ir = infer_domain.infer_program(
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/trace_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def applied_as_fieldop(*args):
"scan": _scan,
"reduce": _reduce,
"neighbors": _neighbors,
"map_": _map,
"map_list": _map,
"if_": _if,
"make_tuple": _make_tuple,
}
Expand Down
113 changes: 113 additions & 0 deletions src/gt4py/next/iterator/transforms/unroll_tuple_maps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import dataclasses
import functools

from gt4py import eve
from gt4py.next import utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.type_system import inference as itir_inference
from gt4py.next.type_system import type_specifications as ts


def _collapsing_tuple_get(expr: itir.Expr, i: int) -> itir.Expr:
    """Select element `i` of the tuple expression `expr`.

    When `expr` is itself a `make_tuple(...)` call, the selected argument is
    returned directly, so no `tuple_get(i, make_tuple(...))` node is created.
    Otherwise a regular `tuple_get(i, expr)` is built.

    Note: the parameters are ordered `(expr, i)` so that the function can be
    passed to `functools.reduce` to follow a path of indices.
    """
    if not cpm.is_call_to(expr, "make_tuple"):
        return im.tuple_get(i, expr)
    # `expr` is a literal tuple construction: pick the element without a lookup node.
    return expr.args[i]


def _tree_map_tuple_body(
    f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType]
) -> itir.Expr:
    """Build the unrolled form of `tree_map_tuple(f)(t1, ..., tN)`.

    The traversal is driven by the tuple *types*: `utils.tree_map` walks
    `tup_types` and calls the leaf callback with a path of indices, while
    collections are rebuilt with `make_tuple`. For every leaf, `f` is applied to
    the element found at that path in each of `tup_exprs`.
    """

    def _call_on_leaf(*args):
        # The last positional argument is the path (requested via
        # `with_path_arg=True`); the leading ones are unused here.
        *_leaf_types, path = args
        leaf_args = [
            functools.reduce(_collapsing_tuple_get, path, tup_expr) for tup_expr in tup_exprs
        ]
        return im.call(f)(*leaf_args)

    unroll = utils.tree_map(
        collection_type=ts.TupleType,
        result_collection_constructor=lambda _, elts: im.make_tuple(*elts),
        with_path_arg=True,
    )(_call_on_leaf)

    return unroll(*tup_types)


def _map_tuple_body(
    f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType]
) -> itir.Expr:
    """Build the unrolled form of `map_tuple(f)(t)`.

    Only the top-level elements of the tuple are mapped; nested tuples are
    passed to `f` as a whole. Exactly one tuple expression and one tuple type
    are expected (the single-element unpacking fails otherwise).
    """
    (only_expr,) = tup_exprs
    (only_type,) = tup_types
    mapped_elements = [
        im.call(f)(_collapsing_tuple_get(only_expr, idx))
        for idx, _ in enumerate(only_type.types)
    ]
    return im.make_tuple(*mapped_elements)


# Dispatch table: name of a tuple-map builtin -> function building its unrolled
# `make_tuple` body. All entries share the signature `(f, tup_exprs, tup_types)`.
_UNROLLERS = {
    "tree_map_tuple": _tree_map_tuple_body,
    "map_tuple": _map_tuple_body,
}


@dataclasses.dataclass
class UnrollTupleMaps(eve.NodeTranslator):
    """Replace applied tuple-map builtins by explicit `make_tuple` expressions.

    Calls of the form `tree_map_tuple(f)(t1, ..., tN)` and `map_tuple(f)(t)` are
    rewritten using the matching builder registered in `_UNROLLERS`.
    """

    # NOTE(review): presumably instructs the eve translator to keep the `domain`
    # annex entry on rewritten nodes -- confirm against `eve.NodeTranslator`.
    PRESERVED_ANNEX_ATTRS = ("domain",)

    # Source of fresh names for the let-bindings introduced by this pass.
    uids: utils.IDGeneratorPool

    @classmethod
    def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool):
        """Run the pass on `program`, drawing let-binding names from `uids`."""
        transformer = cls(uids=uids)
        return transformer.visit(program)

    def visit_FunCall(self, node: itir.FunCall):
        """Unroll `node` if it is an applied tuple-map builtin, else return it as visited."""
        # Children first, so nested tuple maps are already unrolled.
        node = self.generic_visit(node)

        matched_builtin = None
        for candidate in _UNROLLERS:
            if cpm.is_call_to(node.fun, candidate):
                matched_builtin = candidate
                break
        if matched_builtin is None:
            return node

        assert isinstance(node.fun, itir.FunCall)
        mapped_fun = node.fun.args[0]

        # The unrollers are driven by the tuple types, so every argument has to
        # carry an inferred `TupleType`.
        arg_types: list[ts.TupleType] = []
        for arg in node.args:
            itir_inference.reinfer(arg)
            assert isinstance(arg.type, ts.TupleType)
            arg_types.append(arg.type)

        # Arguments that can be duplicated without cost or side effects (symbol
        # references, literals, `make_tuple` calls) are used in the body as-is;
        # this way no `tuple_get(i, make_tuple(...))` patterns are left behind
        # for a later cleanup pass (CollapseTuple). Everything else is bound
        # once in a `let`, such that expensive sub-expressions are not duplicated.
        inlined_args: list[itir.Expr] = []
        bindings: list[tuple[str, itir.Expr]] = []
        for arg in node.args:
            is_trivial = isinstance(arg, (itir.SymRef, itir.Literal)) or cpm.is_call_to(
                arg, "make_tuple"
            )
            if not is_trivial:
                fresh_name = next(self.uids["_utm"])
                bindings.append((fresh_name, arg))
                arg = im.ref(fresh_name)
            inlined_args.append(arg)

        unrolled = _UNROLLERS[matched_builtin](mapped_fun, inlined_args, arg_types)
        if bindings:
            unrolled = im.let(*bindings)(unrolled)

        # Re-run type inference on the freshly built expression.
        itir_inference.reinfer(unrolled)
        return unrolled
Loading