Skip to content
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
1b4707c
Tracer prototype
SF-N Feb 20, 2026
902f8a3
Merge branch 'main' into tracer_support
SF-N Apr 27, 2026
02f881f
Introduce GTIR tree_map builtin and transform to make_tuple, also sup…
SF-N Apr 27, 2026
0ec4692
Run pre-commit and fix some tests
SF-N Apr 27, 2026
ab84ecc
Run CollapseTuple after UnrollTreeMap
SF-N Apr 28, 2026
36d6956
Merge branch 'main' into tracer_support_tree_map
SF-N Apr 28, 2026
152300e
Address review comments
SF-N Apr 28, 2026
d459b0e
Address further review comments
SF-N Apr 28, 2026
97af81e
Apply review comments
SF-N Apr 29, 2026
8d75708
Merge branch 'main' into tracer_support_tree_map
SF-N Apr 29, 2026
067bc29
Merge branch 'main' into tracer_support_tree_map
SF-N May 4, 2026
32e5b2d
Rename map_ -> map_list
SF-N May 28, 2026
a7175d7
Run pre-commit
SF-N May 28, 2026
4f89818
Merge branch 'main' into tracer_support_tree_map
SF-N May 28, 2026
2779fd0
Refactor tree_map_tuple and add map_tuple with unrolling support
SF-N May 28, 2026
80f3273
Rename
SF-N May 28, 2026
454e15f
Minor fix
SF-N May 28, 2026
8a1febd
Merge branch 'main' into tracer_support_tree_map
SF-N Jun 2, 2026
31b969a
Remove unnecessary CollapseTuple loop
SF-N Jun 2, 2026
c7fc102
Reposition UnrollTupleMaps and simplify CollapseTuple usage
SF-N Jun 2, 2026
7993b9c
Merge branch 'main' into tracer_support_tree_map
SF-N Jun 2, 2026
b7f8ba9
Refactor tree_map unrolling
SF-N Jun 16, 2026
3d38868
Cleanup
SF-N Jun 16, 2026
7d5c86c
Revert "Cleanup"
SF-N Jun 17, 2026
747f36e
Revert "Refactor tree_map unrolling"
SF-N Jun 17, 2026
b767700
Cleanup
SF-N Jun 17, 2026
e91f1f1
Merge branch 'origin-main' into tracer_support_tree_map
SF-N Jun 17, 2026
7b270a3
Address review comment
SF-N Jun 18, 2026
d3d4e46
Remove CollapseTuple pass after UnrollTupleMaps
SF-N Jun 19, 2026
d0272df
Remove program wrapper in tests
SF-N Jun 19, 2026
56f234e
Merge branch 'tracer_support_tree_map' of github.com:SF-N/gt4py into …
SF-N Jun 19, 2026
b7bb0b2
Fix test
SF-N Jun 19, 2026
158d540
Merge branch 'main' into tracer_support_tree_map
SF-N Jun 24, 2026
7d8f56c
Also allow itir.Expr in UnrollTupleMaps and run tye_inference when ne…
SF-N Jun 24, 2026
7808b0f
Merge branch 'tracer_support_tree_map' of github.com:SF-N/gt4py into …
SF-N Jun 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def _map(
original_arg_types: tuple[ts.TypeSpec, ...],
) -> itir.FunCall:
"""
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists.
Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_list`ing lists.
"""
if all(
isinstance(t, (ts.ScalarType, ts.DimensionType, ts.DomainType))
Expand All @@ -547,7 +547,7 @@ def _map(
promote_to_list(arg_type)(larg)
for arg_type, larg in zip(original_arg_types, lowered_args)
)
op = im.map_(op)
op = im.map_list(op)

return im.op_as_fieldop(op)(*lowered_args)

Expand Down
136 changes: 134 additions & 2 deletions src/gt4py/next/ffront/lowering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import functools
from collections.abc import Iterable
from typing import Callable, Optional, TypeVar

from gt4py.next import utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.type_system import type_specifications as ts
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.type_system import type_info, type_specifications as ts


# TODO(tehrengruber): The code quality of this function is poor. We should rewrite it.
Expand Down Expand Up @@ -84,3 +86,133 @@ def _process_elements_impl(
result = process_func(*_current_el_exprs)

return result


def _collapsing_tuple_get(expr: itir.Expr, i: int) -> itir.Expr:
"""Like `im.tuple_get`, but collapses immediately when `expr` is a `make_tuple` call.

Note: argument order is `(expr, i)` to allow use as a `functools.reduce` reducer.
"""
if cpm.is_call_to(expr, "make_tuple"):
return expr.args[i]
return im.tuple_get(i, expr)


def _tree_map_tuple_body(
f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType]
) -> itir.Expr:
"""Recursively unroll `tree_map_tuple(f)(t1, ..., tN)` into `make_tuple` calls."""

@utils.tree_map(
collection_type=ts.TupleType,
result_collection_constructor=lambda _, elts: im.make_tuple(*elts),
with_path_arg=True,
)
def mapper(*args: ts.TypeSpec | tuple[int, ...]) -> itir.Expr:
*_el_types, path = args
assert isinstance(path, tuple), "Expected path to be tuple[int, ...]"
return im.call(f)(
*(functools.reduce(_collapsing_tuple_get, path, tup_expr) for tup_expr in tup_exprs)
)

return mapper(*tup_types)


def _map_tuple_body(
f: itir.Expr, tup_exprs: list[itir.Expr], tup_types: list[ts.TupleType]
) -> itir.Expr:
"""Unroll `map_tuple(f)(t)` over top-level elements only (no recursion)."""
(tup_expr,) = tup_exprs
(tup_type,) = tup_types
return im.make_tuple(
*(im.call(f)(_collapsing_tuple_get(tup_expr, i)) for i in range(len(tup_type.types)))
)


_UNROLLERS = {
"tree_map_tuple": _tree_map_tuple_body,
"map_tuple": _map_tuple_body,
}


def _unroll_tuple_map(
builtin_name: str,
f: itir.Expr,
tup_exprs: Iterable[itir.Expr],
tup_types: Iterable[ts.TypeSpec],
*,
uids: utils.IDGeneratorPool,
) -> itir.Expr:
tup_exprs = list(tup_exprs)
tup_types_list = list(tup_types)
for tup_type in tup_types_list:
if not isinstance(tup_type, ts.TupleType):
raise TypeError(
f"'{builtin_name}' requires all arguments to be tuples, got '{tup_type}'."
)
tup_types_validated: list[ts.TupleType] = tup_types_list # type: ignore[assignment]

if not type_info.tuple_structures_match(*tup_types_validated):
raise TypeError(
f"'{builtin_name}' requires all arguments to share the same (nested) tuple "
f"structure, got {[str(t) for t in tup_types_validated]}."
)

# For trivial args (those that can be duplicated without cost or side effects),
# we substitute them directly into the body. This avoids leaving behind
# `tuple_get(i, make_tuple(...))` patterns that would otherwise require a
# separate cleanup pass (`CollapseTuple`). For non-trivial args we still
# introduce a `let` binding to avoid duplicating expensive sub-expressions.
substituted_exprs: list[itir.Expr] = []
let_bindings: list[tuple[str, itir.Expr]] = []
for tup in tup_exprs:
if isinstance(tup, (itir.SymRef, itir.Literal)) or cpm.is_call_to(tup, "make_tuple"):
substituted_exprs.append(tup)
else:
ref_name = next(uids["__utm"])
let_bindings.append((ref_name, tup))
substituted_exprs.append(im.ref(ref_name))

body = _UNROLLERS[builtin_name](f, substituted_exprs, tup_types_validated)
return im.let(*let_bindings)(body) if let_bindings else body


def unroll_tree_map_tuple(
f: itir.Expr,
tup_exprs: Iterable[itir.Expr],
tup_types: Iterable[ts.TypeSpec],
*,
uids: utils.IDGeneratorPool,
) -> itir.Expr:
"""
Lower ``tree_map_tuple(f)(t1, ..., tN)`` to explicit ``make_tuple`` calls, recursing into
nested tuples and applying ``f`` to each leaf.

Args:
f: The function to apply at each leaf.
tup_exprs: The (already lowered) tuple argument expressions.
tup_types: The type of each argument in ``tup_exprs``; all must be ``TupleType`` and
share the same (nested) structure.
uids: Used to generate fresh names for `let`-bindings of non-trivial arguments.
"""
return _unroll_tuple_map("tree_map_tuple", f, tup_exprs, tup_types, uids=uids)


def unroll_map_tuple(
f: itir.Expr,
tup_expr: itir.Expr,
tup_type: ts.TypeSpec,
*,
uids: utils.IDGeneratorPool,
) -> itir.Expr:
"""
Lower ``map_tuple(f)(t)`` to an explicit ``make_tuple`` call, applying ``f`` to each
top-level element only (no recursion).

Args:
f: The function to apply to each top-level element.
tup_expr: The (already lowered) tuple argument expression.
tup_type: The type of ``tup_expr``; must be a ``TupleType``.
uids: Used to generate a fresh name for a `let`-binding of a non-trivial argument.
"""
return _unroll_tuple_map("map_tuple", f, (tup_expr,), (tup_type,), uids=uids)
16 changes: 14 additions & 2 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,17 @@ def neighbors(*args):


@builtin_dispatch
def map_(*args):
def map_list(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def tree_map_tuple(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def map_tuple(*args):
raise BackendNotSelectedError()


Expand Down Expand Up @@ -498,7 +508,9 @@ def get_domain_range(*args):
"lift",
"make_const_list",
"make_tuple",
"map_",
"tree_map_tuple",
"map_tuple",
"map_list",
"named_range",
"neighbors",
"reduce",
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1455,8 +1455,8 @@ def _get_offset(*lists: _List | _ConstList) -> Optional[runtime.Offset]:
raise AssertionError("All lists must have the same offset.")


@builtins.map_.register(EMBEDDED)
def map_(op):
@builtins.map_list.register(EMBEDDED)
def map_list(op):
def impl_(*lists):
offset = _get_offset(*lists)
if offset is None:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def is_applied_map(arg: itir.Node) -> TypeGuard[_FunCallToFunCallToRef]:
isinstance(arg, itir.FunCall)
and isinstance(arg.fun, itir.FunCall)
and isinstance(arg.fun.fun, itir.SymRef)
and arg.fun.fun.id == "map_"
and arg.fun.fun.id == "map_list"
)


Expand Down
16 changes: 13 additions & 3 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,9 +624,19 @@ def index(dim: common.Dimension) -> itir.FunCall:
return call("index")(itir.AxisLiteral(value=dim.value, kind=dim.kind))


def map_(op):
"""Create a `map_` call."""
return call(call("map_")(op))
def map_list(op):
"""Create a `map_list` call."""
return call(call("map_list")(op))


def tree_map_tuple(op):
"""Create a `tree_map_tuple` call: tree_map_tuple(op)(tup1, tup2, ...)."""
return call(call("tree_map_tuple")(op))


def map_tuple(op):
"""Create a `map_tuple` call: map_tuple(op)(tup)."""
return call(call("map_tuple")(op))


def reduce(op, expr):
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/collapse_list_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node:
if cpm.is_call_to(node.args[1], "make_const_list"):
return node.args[1].args[0]
if cpm.is_applied_map(node.args[1]):
# list_get(0, map_(λ(val_) → foo(val_, int64))(·__sym_1))
# list_get(0, map_list(λ(val_) → foo(val_, int64))(·__sym_1))
# -> (λ(val_) → foo(val_, int64))(list_get(0, ·__sym_1))
lsts = node.args[1].args
assert len(node.args[1].fun.args) == 1 # a single lambda in the map
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _is_collectable_expr(node: itir.Node) -> bool:
if isinstance(node, itir.FunCall):
# do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be
# visited, to ensure symbol dependencies are recognized correctly.
# do also not collect reduce, map_ and neighbors nodes if they are left in the IR at this point, this may lead to
# do also not collect reduce, map_list and neighbors nodes if they are left in the IR at this point, this may lead to
# conceptual problems (other parts of the tool chain rely on the arguments being present directly
# on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend
# backend (single pass eager depth first visit approach), see also https://github.com/GridTools/gt4py/issues/1795
Expand All @@ -104,7 +104,7 @@ def _is_collectable_expr(node: itir.Node) -> bool:
# do also not collect index nodes because otherwise the right hand side of SetAts becomes a let statement
# instead of an as_fieldop
if cpm.is_call_to(
node, ("lift", "shift", "neighbors", "reduce", "map_", "index")
node, ("lift", "shift", "neighbors", "reduce", "map_list", "index")
) or cpm.is_applied_lift(node):
return False
return True
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/fuse_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
@dataclasses.dataclass(frozen=True)
class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
"""
Fuses nested `map_`s.
Fuses nested `map_list`s.

Preconditions:
- `FunctionDefinitions` are inlined
Expand All @@ -29,7 +29,7 @@ class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrai
to
map(λ(a, b, c) → f(a, g(b, c)))(a, b, c)

reduce(λ(x, y) → f(x, y), init)(map_(g(z, w))(a, b))
reduce(λ(x, y) → f(x, y), init)(map_list(g(z, w))(a, b))
to
reduce(λ(x, y, z) → f(x, g(y, z)), init)(a, b)
"""
Expand Down Expand Up @@ -93,7 +93,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
new_op = ir.Lambda(params=new_params, expr=new_body)
if cpm.is_applied_map(node):
return ir.FunCall(
fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args
fun=ir.FunCall(fun=ir.SymRef(id="map_list"), args=[new_op]), args=new_args
)
else: # is_applied_reduce(node)
return ir.FunCall(
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/transforms/trace_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def applied_as_fieldop(*args):
"scan": _scan,
"reduce": _reduce,
"neighbors": _neighbors,
"map_": _map,
"map_list": _map,
"if_": _if,
"make_tuple": _make_tuple,
}
Expand Down
52 changes: 49 additions & 3 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from gt4py.next.iterator import builtins, ir as itir
from gt4py.next.iterator.type_system import type_specifications as it_ts
from gt4py.next.type_system import type_info, type_specifications as ts
from gt4py.next.utils import tree_map


def _type_synth_arg_cache_key(type_or_synth: TypeOrTypeSynthesizer) -> int:
Expand Down Expand Up @@ -203,7 +202,7 @@ def if_(
pred: ts.ScalarType | ts.DeferredType, true_branch: ts.DataType, false_branch: ts.DataType
) -> ts.DataType:
if isinstance(true_branch, ts.TupleType) and isinstance(false_branch, ts.TupleType):
return tree_map(
return utils.tree_map(
collection_type=ts.TupleType,
result_collection_constructor=lambda _, elts: ts.TupleType(types=[*elts]),
)(functools.partial(if_, pred))(true_branch, false_branch)
Expand Down Expand Up @@ -615,7 +614,7 @@ def apply_scan(


@_register_builtin_type_synthesizer
def map_(op: TypeSynthesizer) -> TypeSynthesizer:
def map_list(op: TypeSynthesizer) -> TypeSynthesizer:
@type_synthesizer
def applied_map(
*args: ts.ListType, offset_provider_type: common.OffsetProviderType
Expand All @@ -633,6 +632,53 @@ def applied_map(
return applied_map


def _tuple_map_synthesizer(
builtin_name: str, *, recursive: bool
) -> Callable[..., TypeOrTypeSynthesizer]:
"""Shared implementation for `tree_map_tuple` (recursive) and `map_tuple` (top-level)."""

def factory(op: TypeSynthesizer) -> TypeSynthesizer:
@type_synthesizer
def applied_map(
*args: ts.TupleType, offset_provider_type: common.OffsetProviderType
) -> ts.TupleType:
if not args:
raise TypeError(f"'{builtin_name}' requires at least one argument.")
if not recursive and len(args) != 1:
raise TypeError(f"'{builtin_name}' requires exactly one argument, got {len(args)}.")
if not all(isinstance(a, ts.TupleType) for a in args):
raise TypeError(
f"'{builtin_name}' requires all top-level arguments to be TupleType, "
f"got {[type(a).__name__ for a in args]}."
)

def leaf_op(*leaf_types: ts.TypeSpec) -> ts.TypeSpec:
return op(*leaf_types, offset_provider_type=offset_provider_type) # type: ignore[return-value]

if recursive:
return utils.tree_map( # type: ignore[return-value]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mismatched tuple structure raises a bare AssertionError instead of a clear TypeError.

The check above guarantees all args are TupleType, but not that they share arity/nesting. When they differ, utils.tree_map trips its internal assert ... len(args[0]) == len(arg) (or the all-collection assert during recursion) and raises an AssertionError with no message:

im.tree_map_tuple(im.ref("plus"))(
    im.ref("t1", ts.TupleType(types=[int_t, int_t, int_t])),
    im.ref("t2", ts.TupleType(types=[int_t, int_t])),
)   # -> AssertionError (empty message); also for mismatched nesting

Not reachable from the frontend today — where pre-validates branch structure with a DSLError — but it'll bite the upcoming map_tuple/tracer producers and anyone building IR directly. A TypeError here (and in the matching spot in UnrollTupleMaps) would be friendlier.

leaf_op,
collection_type=ts.TupleType,
result_collection_constructor=lambda _, elts: ts.TupleType(types=[*elts]),
)(*args)

# Non-recursive: apply `op` once per top-level element.
(arg,) = args
return ts.TupleType(types=[leaf_op(el) for el in arg.types])

return applied_map

return factory


tree_map_tuple = _register_builtin_type_synthesizer(
_tuple_map_synthesizer("tree_map_tuple", recursive=True), fun_names=["tree_map_tuple"]
)
map_tuple = _register_builtin_type_synthesizer(
_tuple_map_synthesizer("map_tuple", recursive=False), fun_names=["map_tuple"]
)


@_register_builtin_type_synthesizer
def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer:
@type_synthesizer
Expand Down
Loading