Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
1b4707c
Tracer prototype
SF-N Feb 20, 2026
902f8a3
Merge branch 'main' into tracer_support
SF-N Apr 27, 2026
02f881f
Introduce GTIR tree_map builtin and transform to make_tuple, also sup…
SF-N Apr 27, 2026
0ec4692
Run pre-commit and fix some tests
SF-N Apr 27, 2026
ab84ecc
Run CollapseTuple after UnrollTreeMap
SF-N Apr 28, 2026
36d6956
Merge branch 'main' into tracer_support_tree_map
SF-N Apr 28, 2026
152300e
Address review comments
SF-N Apr 28, 2026
d459b0e
Address further review comments
SF-N Apr 28, 2026
97af81e
Apply review comments
SF-N Apr 29, 2026
8d75708
Merge branch 'main' into tracer_support_tree_map
SF-N Apr 29, 2026
067bc29
Merge branch 'main' into tracer_support_tree_map
SF-N May 4, 2026
32e5b2d
Rename map_ -> map_list
SF-N May 28, 2026
a7175d7
Run pre-commit
SF-N May 28, 2026
4f89818
Merge branch 'main' into tracer_support_tree_map
SF-N May 28, 2026
2779fd0
Refactor tree_map_tuple and add map_tuple with unrolling support
SF-N May 28, 2026
80f3273
Rename
SF-N May 28, 2026
454e15f
Minor fix
SF-N May 28, 2026
8a1febd
Merge branch 'main' into tracer_support_tree_map
SF-N Jun 2, 2026
31b969a
Remove unnecessary CollapseTuple loop
SF-N Jun 2, 2026
c7fc102
Reposition UnrollTupleMaps and simplify CollapseTuple usage
SF-N Jun 2, 2026
7993b9c
Merge branch 'main' into tracer_support_tree_map
SF-N Jun 2, 2026
b7f8ba9
Refactor tree_map unrolling
SF-N Jun 16, 2026
3d38868
Cleanup
SF-N Jun 16, 2026
7d5c86c
Revert "Cleanup"
SF-N Jun 17, 2026
747f36e
Revert "Refactor tree_map unrolling"
SF-N Jun 17, 2026
b767700
Cleanup
SF-N Jun 17, 2026
e91f1f1
Merge branch 'origin-main' into tracer_support_tree_map
SF-N Jun 17, 2026
7b270a3
Address review comment
SF-N Jun 18, 2026
d3d4e46
Remove CollapseTuple pass after UnrollTupleMaps
SF-N Jun 19, 2026
d0272df
Remove program wrapper in tests
SF-N Jun 19, 2026
56f234e
Merge branch 'tracer_support_tree_map' of github.com:SF-N/gt4py into …
SF-N Jun 19, 2026
b7bb0b2
Fix test
SF-N Jun 19, 2026
158d540
Merge branch 'main' into tracer_support_tree_map
SF-N Jun 24, 2026
7d8f56c
Also allow itir.Expr in UnrollTupleMaps and run tye_inference when ne…
SF-N Jun 24, 2026
7808b0f
Merge branch 'tracer_support_tree_map' of github.com:SF-N/gt4py into …
SF-N Jun 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,23 +412,15 @@ def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall:
return self._lower_and_map("if_", *node.args)

cond_ = self.visit(node.args[0])
true_ = self.visit(node.args[1])
false_ = self.visit(node.args[2])
cond_symref_name = f"__cond_{cond_.fingerprint()}"

def create_if(
true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec]
) -> itir.FunCall:
return _map(
"if_",
(im.ref(cond_symref_name), true_, false_),
(node.args[0].type, *arg_types),
result = im.tree_map(
im.lambda_("__a", "__b")(
im.op_as_fieldop("if_")(im.ref(cond_symref_name), im.ref("__a"), im.ref("__b"))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tuple-where over a per-neighbour (local) condition produces invalid IR.

Hardcoding the per-leaf op to op_as_fieldop("if_") drops the local-field handling that the non-tuple path keeps: _lower_and_map("if_", …) → _map wraps the op in map_list(if_) and promotes the condition with make_const_list when the leaves are local fields. The single, uniform tree_map_tuple leaf can't do that.

This is reachable from the frontend and breaks on roundtrip, gtfn and dace whenever the where condition is itself a local (per-neighbour) field — the mask only has to be a FieldType, and a local field qualifies:

@gtx.field_operator
def tup(a: EdgeF, b: EdgeF, c: EdgeF, d: EdgeF) -> tuple[VertF, VertF]:
    cond = a(V2E) > c(V2E)                               # per-neighbour (local) bool field
    t = where(cond, (a(V2E), b(V2E)), (c(V2E), d(V2E)))
    return (neighbor_sum(t[0], axis=V2EDim), neighbor_sum(t[1], axis=V2EDim))

The non-tuple equivalent (neighbor_sum(where(a(V2E)>c(V2E), a(V2E), c(V2E)), …)) works on all three backends; the tuple version fails on all three with AssertionError in the if_ type synthesizer (isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL), because after UnrollTupleMaps the leaf is ⇑(λ(c,x,y) → if ·c then ·x else ·y) applied to a list-typed predicate — no map_list.

Broadcast (non-local) conditions happen to pass, since wholesale if_(scalar, list, list) equals element-wise selection and the consuming reduce absorbs the list — only a varying condition exposes it.

This is the same per-leaf type-dependent reason given for not migrating _visit_astype. It also isn't fixable by "always emit map_list": with a non-local cond and mixed leaves (one branch local, one not), no single uniform leaf lambda is correct. Options: keep where on process_elements, or wrap if_ in map_list post-unroll when the predicate is a list.

)

result = lowering_utils.process_elements(
create_if,
(self.visit(node.args[1]), self.visit(node.args[2])),
node.type,
arg_types=(node.args[1].type, node.args[2].type),
)
)(true_, false_)

return im.let(cond_symref_name, cond_)(result)

Expand Down
8 changes: 7 additions & 1 deletion src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def map_(*args):
raise BackendNotSelectedError()


# Placeholder for the `tree_map` builtin: the undecorated body below only
# raises `BackendNotSelectedError`.
# NOTE(review): presumably `builtin_dispatch` forwards calls to the
# implementation registered by the active backend — confirm against its definition.
@builtin_dispatch
def tree_map(*args):
raise BackendNotSelectedError()


# Placeholder for the `make_const_list` builtin: the undecorated body below
# only raises `BackendNotSelectedError`.
# NOTE(review): presumably `builtin_dispatch` forwards calls to the
# implementation registered by the active backend — confirm against its definition.
@builtin_dispatch
def make_const_list(*args):
raise BackendNotSelectedError()
Expand Down Expand Up @@ -498,7 +503,8 @@ def get_domain_range(*args):
"lift",
"make_const_list",
"make_tuple",
"map_",
"tree_map",
"map_", # TODO: rename to map_list
"named_range",
"neighbors",
"reduce",
Expand Down
5 changes: 5 additions & 0 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,11 @@ def map_(op):
return call(call("map_")(op))


def tree_map(op):
    """Build a `tree_map` call maker.

    The returned object is applied to the tuple arguments, i.e.
    ``tree_map(op)(tup1, tup2, ...)``.
    """
    mapped_op = call("tree_map")(op)
    return call(mapped_op)


def reduce(op, expr):
    """Build a `reduce` call maker from the operation `op` and the expression `expr`."""
    reducer = call("reduce")(op, expr)
    return call(reducer)
Expand Down
38 changes: 38 additions & 0 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
prune_empty_concat_where,
remove_broadcast,
symbol_ref_utils,
unroll_tree_map,
)
from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet
from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple
Expand Down Expand Up @@ -176,6 +177,26 @@ def apply_common_transforms(
) # domain inference does not support dynamic offsets yet
ir = infer_domain_ops.InferDomainOps.apply(ir)
ir = concat_where.canonicalize_domain_argument(ir)
ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids)

# After UnrollTreeMap, collapse `tuple_get(i, let(...)(make_tuple(...)))` patterns so that
# domain inference does not encounter `as_fieldop` nodes inside dead tuple elements
# (which would receive NEVER domain). Do multiple iterations for nested `let`s.
for _ in range(10):
collapsed = ir
ir = CollapseTuple.apply(
ir,
enabled_transformations=(
CollapseTuple.Transformation.PROPAGATE_TUPLE_GET
| CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE
),
uids=uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program
if ir == collapsed:
break
Comment thread
SF-N marked this conversation as resolved.
Outdated
else:
raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.")

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this, test_reduction_expression_with_where_and_tuples fails with ValueError: 'target_domain' cannot be 'NEVER' unless "allow_uninferred=True".

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this is probably also the test case where the loop is required. I'll check whether another configuration of the pass avoids the loop.

Comment thread
SF-N marked this conversation as resolved.
Outdated

ir = infer_domain.infer_program(
ir,
Expand Down Expand Up @@ -290,6 +311,23 @@ def apply_fieldview_transforms(

ir = infer_domain_ops.InferDomainOps.apply(ir)
ir = concat_where.canonicalize_domain_argument(ir)
ir = unroll_tree_map.UnrollTreeMap.apply(ir, uids=uids)
for _ in range(10):
prev = ir
ir = CollapseTuple.apply(
ir,
enabled_transformations=(
CollapseTuple.Transformation.PROPAGATE_TUPLE_GET
| CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE
),
uids=uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program
if ir == prev:
break
Comment thread
SF-N marked this conversation as resolved.
Outdated
else:
raise RuntimeError("'CollapseTuple' did not converge after `UnrollTreeMap`.")

ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program

ir = infer_domain.infer_program(
Expand Down
85 changes: 85 additions & 0 deletions src/gt4py/next/iterator/transforms/unroll_tree_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import dataclasses

from gt4py import eve
from gt4py.next import utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.type_system import inference as itir_inference
from gt4py.next.type_system import type_specifications as ts


def _unroll(
    f: itir.Expr,
    tup_types: list[ts.TupleType],
    tup_exprs: list[itir.Expr],
) -> itir.Expr:
    """Recursively expand ``tree_map(f)(tup0, tup1, ...)`` into ``make_tuple`` / ``tuple_get``.

    Args:
        f: Expression that is called on the leaves; it is applied once per leaf
            position with one argument per tuple.
        tup_types: Tuple type of every argument at the current nesting level.
        tup_exprs: Expression of every argument at the current nesting level,
            in the same order as ``tup_types``.

    Returns:
        A ``make_tuple`` expression with one element per tuple position. An
        element is a nested ``make_tuple`` where all arguments hold a tuple at
        that position, and a call of ``f`` on the ``tuple_get`` of every
        argument where all of them hold a leaf.

    Raises:
        ValueError: If no tuple argument is given, if the tuples differ in
            length at some level, or if tuples and leaves are mixed at the
            same position.
    """
    if not tup_types:
        # Raised explicitly (instead of `assert`) so the check is not stripped
        # under `python -O`, where indexing below would fail with `IndexError`.
        raise ValueError("tree_map requires at least one tuple argument.")

    n = len(tup_types[0].types)
    if any(len(t.types) != n for t in tup_types[1:]):
        raise ValueError(
            "All tree_map arguments must have the same tuple structure at each level, "
            f"got {[len(t.types) for t in tup_types]}."
        )

    elements: list[itir.Expr] = []
    for i in range(n):
        child_types = [t.types[i] for t in tup_types]
        child_exprs = [im.tuple_get(i, e) for e in tup_exprs]

        # `child_types` is never empty here, so "all tuples" and "all leaves"
        # are mutually exclusive and can be told apart by a single filter.
        nested_types = [ct for ct in child_types if isinstance(ct, ts.TupleType)]
        if len(nested_types) == len(child_types):
            elements.append(_unroll(f, nested_types, child_exprs))
        elif not nested_types:
            elements.append(im.call(f)(*child_exprs))
        else:
            raise ValueError(
                "All tree_map arguments must have the same tree structure "
                "(all leaves must be reached simultaneously)."
            )

    return im.make_tuple(*elements)


@dataclasses.dataclass
class UnrollTreeMap(eve.NodeTranslator):
    """Replace every ``tree_map(f)(tup0, tup1, ...)`` call by its unrolled form.

    Each tuple argument is bound to a fresh name via ``let`` and the body is the
    ``make_tuple`` / ``tuple_get`` expansion produced by ``_unroll``.
    """

    # NOTE(review): presumably tells `eve.NodeTranslator` to carry the `domain`
    # annex entry over to rewritten nodes — confirm against eve.
    PRESERVED_ANNEX_ATTRS = ("domain",)

    # Pool from which the names of the `let`-bound tuple arguments are drawn.
    uids: utils.IDGeneratorPool

    @classmethod
    def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool):
        """Run the pass on `program`, drawing fresh names from `uids`."""
        translator = cls(uids=uids)
        return translator.visit(program)

    def visit_FunCall(self, node: itir.FunCall):
        """Unroll `node` if it is an applied `tree_map`; children are visited first."""
        node = self.generic_visit(node)
        if not cpm.is_call_to(node.fun, "tree_map"):
            return node

        mapped_fun = node.fun.args[0]
        arguments = node.args

        arg_types: list[ts.TupleType] = []
        for arg in arguments:
            # The unrolling is driven by the tuple structure, so the type of
            # every argument must be available (and a tuple) at this point.
            itir_inference.reinfer(arg)
            assert isinstance(arg.type, ts.TupleType)
            arg_types.append(arg.type)

        ref_names = [next(self.uids["_utm"]) for _ in arguments]
        unrolled = _unroll(mapped_fun, arg_types, [im.ref(name) for name in ref_names])

        # Wrap from the last argument inwards so that the first argument ends
        # up in the outermost `let`.
        bindings = list(zip(ref_names, arguments))
        while bindings:
            name, value = bindings.pop()
            unrolled = im.let(name, value)(unrolled)

        itir_inference.reinfer(unrolled)
        return unrolled
Comment thread
SF-N marked this conversation as resolved.
Outdated
41 changes: 41 additions & 0 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,47 @@ def applied_map(
return applied_map


@_register_builtin_type_synthesizer(fun_names=["tree_map"])
def _tree_map(op: TypeSynthesizer) -> TypeSynthesizer:
    """Type synthesizer of ``tree_map(op)``.

    The returned synthesizer takes the tuple argument types and yields a tuple
    type of the same nesting, whose leaves are the result of ``op`` on the
    corresponding leaf types of all arguments.
    """

    @type_synthesizer
    def applied_map(
        *args: ts.TupleType, offset_provider_type: common.OffsetProviderType
    ) -> ts.TupleType:
        if not args:
            raise TypeError("tree_map requires at least one argument.")
        if any(not isinstance(arg, ts.TupleType) for arg in args):
            raise TypeError(
                "tree_map requires all top-level arguments to be TupleType, "
                f"got {[type(a).__name__ for a in args]}."
            )

        def _synthesize(*level_types: ts.TypeSpec) -> ts.TypeSpec:
            # Split by kind once: either every entry is a tuple (descend),
            # none is (apply `op`), or the trees disagree (error).
            tuple_types = [t for t in level_types if isinstance(t, ts.TupleType)]

            if len(tuple_types) == len(level_types):
                arities = [len(t.types) for t in tuple_types]
                if any(arity != arities[0] for arity in arities[1:]):
                    raise TypeError(
                        "All tree_map arguments must have the same tuple structure at each level, "
                        f"got {arities}."
                    )
                return ts.TupleType(
                    types=[
                        _synthesize(*(t.types[i] for t in tuple_types))
                        for i in range(arities[0])
                    ]
                )

            if not tuple_types:
                return op(*level_types, offset_provider_type=offset_provider_type)  # type: ignore[return-value]

            raise TypeError(
                "All tree_map arguments must have the same tree structure "
                "(all leaves must be reached simultaneously)."
            )

        return _synthesize(*args)  # type: ignore[return-value]

    return applied_map


@_register_builtin_type_synthesizer
def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer:
@type_synthesizer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,9 @@ def foo(
lowered
) # we generate a let for the condition which is removed by inlining for easier testing

reference = im.make_tuple(
im.op_as_fieldop("if_")("a", im.tuple_get(0, "b"), im.tuple_get(0, "c")),
im.op_as_fieldop("if_")("a", im.tuple_get(1, "b"), im.tuple_get(1, "c")),
)
reference = im.tree_map(
im.lambda_("__a", "__b")(im.op_as_fieldop("if_")("a", im.ref("__a"), im.ref("__b")))
)("b", "c")

assert lowered_inlined.expr == reference

Expand Down
17 changes: 17 additions & 0 deletions tests/next_tests/unit_tests/iterator_tests/test_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,23 @@ def expression_test_cases():
),
ts.ListType(element_type=int_type, offset_type=V2EDim),
),
# tree_map
(
im.tree_map(im.ref("plus"))(
im.ref("t1", ts.TupleType(types=[int_type, int_type])),
im.ref("t2", ts.TupleType(types=[int_type, int_type])),
),
ts.TupleType(types=[int_type, int_type]),
),
(
im.tree_map(im.ref("not_"))(
im.ref(
"t",
ts.TupleType(types=[bool_type, ts.TupleType(types=[bool_type, bool_type])]),
),
),
ts.TupleType(types=[bool_type, ts.TupleType(types=[bool_type, bool_type])]),
),
# reduce
(im.reduce("plus", 0)(im.ref("l", int_list_type)), int_type),
(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.next import common, utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms.unroll_tree_map import UnrollTreeMap
from gt4py.next.type_system import type_specifications as ts

IDim = common.Dimension("IDim")
T = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
i_field = ts.FieldType(dims=[IDim], dtype=T)
i_tuple_field = ts.TupleType(types=[i_field, i_field])
i_nested_tuple_field = ts.TupleType(types=[i_tuple_field, i_field])

i_domain = im.call("cartesian_domain")(im.named_range(itir.AxisLiteral(value="IDim"), 0, 1))


def _make_program(
    params: list[itir.Sym], expr: itir.Expr, out_type: ts.TypeSpec = i_field
) -> itir.Program:
    """Wrap `expr` in a program named ``testee`` that writes it to ``out`` on `i_domain`.

    The program parameters are `params` followed by an ``out`` symbol of type `out_type`.
    """
    write_out = itir.SetAt(
        expr=expr,
        domain=i_domain,
        target=im.ref("out", out_type),
    )
    all_params = [*params, im.sym("out", out_type)]
    return itir.Program(
        id="testee",
        function_definitions=[],
        params=all_params,
        declarations=[],
        body=[write_out],
    )


def _neg():
    """Return a one-argument lambda applying ``neg`` as a field operation."""
    negated = im.op_as_fieldop("neg")("__a")
    return im.lambda_("__a")(negated)


def _plus():
    """Return a two-argument lambda applying ``plus`` as a field operation."""
    summed = im.op_as_fieldop("plus")("__a", "__b")
    return im.lambda_("__a", "__b")(summed)


def test_multi_arg():
    """A two-argument `tree_map` over flat tuples unrolls into one call per element."""
    uids = utils.IDGeneratorPool()
    params = [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)]

    testee_expr = im.tree_map(_plus())(im.ref("a", i_tuple_field), im.ref("b", i_tuple_field))
    program = _make_program(params, testee_expr, out_type=i_tuple_field)

    result = UnrollTreeMap.apply(program, uids=uids)

    unrolled = im.make_tuple(
        im.call(_plus())(im.tuple_get(0, "_utm_0"), im.tuple_get(0, "_utm_1")),
        im.call(_plus())(im.tuple_get(1, "_utm_0"), im.tuple_get(1, "_utm_1")),
    )
    # The first argument is bound by the outermost `let`.
    expected_expr = im.let("_utm_0", "a")(im.let("_utm_1", "b")(unrolled))
    expected = _make_program(
        [im.sym("a", i_tuple_field), im.sym("b", i_tuple_field)],
        expected_expr,
        out_type=i_tuple_field,
    )
    assert result == expected


def test_nested():
    """A `tree_map` over a nested tuple unrolls into nested `make_tuple` calls."""
    uids = utils.IDGeneratorPool()

    testee_expr = im.tree_map(_neg())(im.ref("t", i_nested_tuple_field))
    program = _make_program(
        [im.sym("t", i_nested_tuple_field)], testee_expr, out_type=i_nested_tuple_field
    )

    result = UnrollTreeMap.apply(program, uids=uids)

    inner = im.make_tuple(
        im.call(_neg())(im.tuple_get(0, im.tuple_get(0, "_utm_0"))),
        im.call(_neg())(im.tuple_get(1, im.tuple_get(0, "_utm_0"))),
    )
    outer = im.make_tuple(inner, im.call(_neg())(im.tuple_get(1, "_utm_0")))
    expected = _make_program(
        [im.sym("t", i_nested_tuple_field)],
        im.let("_utm_0", "t")(outer),
        out_type=i_nested_tuple_field,
    )
    assert result == expected