From cd90f5b11f66d0c53cb6d8be1887d8638044d10d Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 14 May 2026 12:00:57 -0400 Subject: [PATCH 001/101] Fix `ReplaceAxisSymbol` and keep it to Taskslets -> `ReplaceAxisSymbolInTasklet` Move re-usable functions into `tree_common_op` --- .../dace/stree/optimizations/axis_merge.py | 59 ++++++------------- .../stree/optimizations/tree_common_op.py | 14 ++++- 2 files changed, 32 insertions(+), 41 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 0f97468c..94920954 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import itertools import dace from dace.properties import CodeBlock @@ -14,6 +15,8 @@ detect_cycle, list_index, swap_node_position_in_tree, + is_axis_map, + is_axis_for ) from ndsl.logging import ndsl_log @@ -22,16 +25,6 @@ PUSH_IFSCOPE_DOWNWARD = False # Crashing the overall stree - bad algorithmics -def _is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: - """Returns true if node is a map over the given axis.""" - map_parameter = node.node.map.params - return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) - - -def _is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: - return node.loop.loop_variable.startswith(axis.as_str()) - - def _both_same_single_axis_maps( first: tn.MapScope, second: tn.MapScope, axis: AxisIterator ) -> bool: @@ -39,8 +32,8 @@ def _both_same_single_axis_maps( ( len(first.node.map.params) == 1 and len(second.node.map.params) == 1 ) # Single axis - and _is_axis_map(first, axis) # Correct axis in first map - and _is_axis_map(second, axis) # Correct axis in second map + and is_axis_map(first, axis) # Correct axis in first map + and is_axis_map(second, axis) # Correct axis in second map ) @@ -109,26 +102,10 @@ def _last_node(nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> b return list_index(nodes, node) >= len(nodes) - 1 -class ReplaceAxisSymbol(tn.ScheduleNodeVisitor): +class ReplaceAxisSymbolInTasklet(tn.ScheduleNodeVisitor): def __init__(self, axis: AxisIterator) -> None: self._axis = axis - def visit_MapScope( - self, - map_scope: tn.MapScope, - axis_replacements: dict[str, str] | None = None, - ) -> None: - if axis_replacements is None: - axis_replacements = {} - - for index, param in enumerate(map_scope.node.params): - if param in axis_replacements: - map_scope.node.params[index] = axis_replacements[param] - - # visit children - for child in map_scope.children: - self.visit(child, axis_replacements=axis_replacements) - def visit_TaskletNode( self, node: tn.TaskletNode, @@ -138,11 +115,13 @@ def visit_TaskletNode( # Noop if there are no replacements to do. return - for memlets in node.in_memlets.values(): - memlets.replace(axis_replacements) - for memlets in node.out_memlets.values(): - memlets.replace(axis_replacements) - + # Dev NOTE: We directly replace the memlet.subset because the `memlet.replace` + # function sometimes doesn't work + for memlet in itertools.chain(node.in_memlets.values(), node.out_memlets.values()): + if memlet.subset is not None: + memlet.subset.replace(axis_replacements) + if memlet.other_subset is not None: + memlet.other_subset.replace(axis_replacements) class CartesianAxisMerge(tn.ScheduleNodeTransformer): """Merge a cartesian axis if they are contiguous in code-flow. @@ -197,7 +176,7 @@ def _merge_node( def _for_merge(self, the_for_scope: tn.ForScope) -> int: merged = 0 - if _is_axis_for(the_for_scope, self.axis): + if is_axis_for(the_for_scope, self.axis): # TODO: if the for scope is on a cartesian axis it can be # merged with other for scope going in the same direction pass @@ -206,7 +185,7 @@ def _for_merge(self, the_for_scope: tn.ForScope) -> int: if ( len(the_for_scope.children) == 1 and isinstance(the_for_scope.children[0], tn.MapScope) - and _is_axis_map(the_for_scope.children[0], self.axis) + and is_axis_map(the_for_scope.children[0], self.axis) ): swap_node_position_in_tree(the_for_scope, the_for_scope.children[0]) merged += 1 @@ -327,7 +306,7 @@ def _map_overcompute_merge( # End of nodes OR # Not the right axis # --> recurse - if _last_node(nodes, the_map) or not _is_axis_map(the_map, self.axis): + if _last_node(nodes, the_map) or not is_axis_map(the_map, self.axis): merged = 0 for child in the_map.children: merged += self._merge_node(child, the_map.children) @@ -384,9 +363,9 @@ def _map_overcompute_merge( # K-maps use unique iterators (i.e. every k-map iterates over `k__[0-9]*`). # After merge, we need to replace the axis symbols of the second map's children # with the axis symbol of the first map. - if next_node.node.map.params[0] != the_map.node.map.params[0]: - replacements = {next_node.node.map.params[0]: the_map.node.map.params[0]} - ReplaceAxisSymbol(self.axis).visit( + if second_map.node.map.params[0] != first_map.node.map.params[0]: + replacements = {second_map.node.map.params[0]: first_map.node.map.params[0]} + ReplaceAxisSymbolInTasklet(self.axis).visit( first_map, axis_replacements=replacements ) diff --git a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py index 1253ba81..e243ede3 100644 --- a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py +++ b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py @@ -1,7 +1,7 @@ from typing import Collection import dace.sdfg.analysis.schedule_tree.treenodes as tn - +from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator def swap_node_position_in_tree( top_node: tn.ScheduleTreeScope, child_node: tn.ScheduleTreeScope @@ -51,3 +51,15 @@ def list_index( """Check if node is in list with "is" operator.""" # compare with "is" to get memory comparison. ".index()" uses value comparison return next(index for index, element in enumerate(collection) if element is node) + + +def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: + """Returns true if node is a Map over the given axis.""" + map_parameter = node.node.map.params + return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) + + +def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: + """Returns true if node is a For over the given axis.""" + return node.loop.loop_variable.startswith(axis.as_str()) + From 2c8f74e6d17ff634f1f7ffc2329a3c66d866749f Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 14 May 2026 12:02:56 -0400 Subject: [PATCH 002/101] Add `TreeOptimizationStatistics` to capture the results of the opt at a glance --- .../dace/stree/optimizations/axis_merge.py | 9 +- .../dace/stree/optimizations/statistics.py | 94 +++++++++++++++++++ .../stree/optimizations/tree_common_op.py | 3 +- ndsl/dsl/dace/stree/pipeline.py | 8 ++ 4 files changed, 110 insertions(+), 4 deletions(-) create mode 100644 ndsl/dsl/dace/stree/optimizations/statistics.py diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 94920954..abacd0b4 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -13,10 +13,10 @@ ) from ndsl.dsl.dace.stree.optimizations.tree_common_op import ( detect_cycle, + is_axis_for, + is_axis_map, list_index, swap_node_position_in_tree, - is_axis_map, - is_axis_for ) from ndsl.logging import ndsl_log @@ -117,12 +117,15 @@ def visit_TaskletNode( # Dev NOTE: We directly replace the memlet.subset because the `memlet.replace` # function sometimes doesn't work - for memlet in itertools.chain(node.in_memlets.values(), node.out_memlets.values()): + for memlet in itertools.chain( + node.in_memlets.values(), node.out_memlets.values() + ): if memlet.subset is not None: memlet.subset.replace(axis_replacements) if memlet.other_subset is not None: memlet.other_subset.replace(axis_replacements) + class CartesianAxisMerge(tn.ScheduleNodeTransformer): """Merge a cartesian axis if they are contiguous in code-flow. diff --git a/ndsl/dsl/dace/stree/optimizations/statistics.py b/ndsl/dsl/dace/stree/optimizations/statistics.py new file mode 100644 index 00000000..1a3ce593 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/statistics.py @@ -0,0 +1,94 @@ +import dataclasses + +import dace +import dace.sdfg.analysis.schedule_tree.treenodes as stree + +from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator +from ndsl.dsl.dace.stree.optimizations.tree_common_op import is_axis_for, is_axis_map + + +class CountCartesianLoops(stree.ScheduleNodeVisitor): + def __init__(self) -> None: + super().__init__() + self._maps = [0, 0, 0] + self._fors = [0, 0, 0] + + def visit_MapScope(self, node: stree.MapScope) -> None: + for axis in AxisIterator: + if is_axis_map(node, axis): + self._maps[axis.as_cartesian_index()] += 1 + + self.visit(node.children) + + def visit_ForScope(self, node: stree.ForScope) -> None: + for axis in AxisIterator: + if is_axis_for(node, axis): + self._fors[axis.as_cartesian_index()] += 1 + + self.visit(node.children) + + +class CountTransient(stree.ScheduleNodeVisitor): + def __init__(self) -> None: + super().__init__() + self._counts = [0, 0, 0, 0] + + def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + for data in node.containers.values(): + non_atomic_dims_count = sum(1 for x in data.shape if x != 1) + if isinstance(data, dace.data.Array) and data.transient: + if non_atomic_dims_count == 1: + self._counts[0] += 1 + elif non_atomic_dims_count == 2: + self._counts[1] += 1 + elif non_atomic_dims_count == 3: + self._counts[2] += 1 + else: + self._counts[3] += 1 + + +class TreeOptimizationStatistics: + """Capture basic statistics on the schedule tree optimization actions""" + + @dataclasses.dataclass + class Record: + """Private record of a state of a tree""" + + cartesian_maps: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0]) + cartesian_fors: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0]) + transients: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0, 0]) + + def __init__(self) -> None: + self._original_record = TreeOptimizationStatistics.Record() + self._optimized_record = TreeOptimizationStatistics.Record() + + def _record( + self, + record: Record, + tree_root: stree.ScheduleTreeRoot, + ) -> None: + """Record the state of a tree""" + c = CountCartesianLoops() + c.visit(tree_root) + record.cartesian_fors = c._fors + record.cartesian_maps = c._maps + + c = CountTransient() + c.visit(tree_root) + record.transients = c._counts + + def original(self, tree_root: stree.ScheduleTreeRoot) -> None: + """Record the original state of the tree, before optimization""" + self._record(self._original_record, tree_root) + + def optimized(self, tree_root: stree.ScheduleTreeRoot) -> None: + """Record the state of the tree after optimization""" + self._record(self._optimized_record, tree_root) + + def report(self) -> str: + """Craft a concize string reporting on the statistics""" + msg = "Tree optimization:\n" + msg += f" Cartesian maps [I, J, K]: {self._original_record.cartesian_maps} -> {self._optimized_record.cartesian_maps}\n" + msg += f" Cartesian fors [I, J, K]: {self._original_record.cartesian_fors} -> {self._optimized_record.cartesian_fors}\n" + msg += f" Transients [1D, 2D, 3D, 4D+]: {self._original_record.transients} -> {self._optimized_record.transients}\n" + return msg diff --git a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py index e243ede3..4748a901 100644 --- a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py +++ b/ndsl/dsl/dace/stree/optimizations/tree_common_op.py @@ -1,8 +1,10 @@ from typing import Collection import dace.sdfg.analysis.schedule_tree.treenodes as tn + from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator + def swap_node_position_in_tree( top_node: tn.ScheduleTreeScope, child_node: tn.ScheduleTreeScope ) -> None: @@ -62,4 +64,3 @@ def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: """Returns true if node is a For over the given axis.""" return node.loop.loop_variable.startswith(axis.as_str()) - diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index e4307ddf..44e933bc 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -3,6 +3,7 @@ import dace.sdfg.analysis.schedule_tree.treenodes as stree from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge +from ndsl.dsl.dace.stree.optimizations.statistics import TreeOptimizationStatistics from ndsl.logging import ndsl_log_on_rank_0 @@ -30,6 +31,9 @@ def run( stree: stree.ScheduleTreeRoot, verbose: bool = False, ) -> stree.ScheduleTreeRoot: + tree_stats = TreeOptimizationStatistics() + tree_stats.original(stree) + for i, p in enumerate(self.passes): if verbose: path = self.cache_directory / f"pass{i}_{p}.txt" @@ -41,6 +45,10 @@ def run( with open(path, "w+") as f: f.write(stree.as_string()) + tree_stats.optimized(stree) + if verbose: + ndsl_log_on_rank_0.info(tree_stats.report()) + return stree From 8b49e3b550ceed51c4e7063932803df1e0e91d96 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 14 May 2026 15:28:09 -0400 Subject: [PATCH 003/101] Add a master `CartesianMerge` bringing everything axis merge, refactor around --- ndsl/dsl/dace/orchestration.py | 50 +++----------- ndsl/dsl/dace/stree/optimizations/__init__.py | 3 +- .../dace/stree/optimizations/axis_merge.py | 65 +++++++++++-------- .../replace_symbol_in_tasklet.py | 29 +++++++++ 4 files changed, 79 insertions(+), 68 deletions(-) create mode 100644 ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index a81b3744..933a379b 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -24,7 +24,6 @@ import ndsl.dsl.dace.replacements # noqa # We load in the DaCe replacements from ndsl.comm.mpi import MPI -from ndsl.config import BackendLoopOrder from ndsl.dsl.dace.build import get_sdfg_path, write_build_info from ndsl.dsl.dace.dace_config import ( DEACTIVATE_DISTRIBUTED_DACE_COMPILE, @@ -40,8 +39,7 @@ ) from ndsl.dsl.dace.stree import CPUPipeline from ndsl.dsl.dace.stree.optimizations import ( - AxisIterator, - CartesianAxisMerge, + CartesianMerge, CartesianRefineTransients, CleanUpScheduleTree, ) @@ -198,44 +196,14 @@ def _build_sdfg( f.write(stree.as_string()) with DaCeProgress(config, "Schedule Tree: optimization"): - passes = [] - if backend_name.loop_order == BackendLoopOrder.IJK: - passes.extend( - [ - CleanUpScheduleTree(), - CartesianAxisMerge(AxisIterator._I), - CartesianAxisMerge(AxisIterator._J), - CartesianAxisMerge(AxisIterator._K), - CartesianRefineTransients(backend_name), - ] - ) - elif backend_name.loop_order == BackendLoopOrder.KJI: - passes.extend( - [ - CleanUpScheduleTree(), - CartesianAxisMerge(AxisIterator._K), - CartesianAxisMerge(AxisIterator._J), - CartesianAxisMerge(AxisIterator._I), - CartesianRefineTransients(backend_name), - ] - ) - elif backend_name.loop_order == BackendLoopOrder.KIJ: - passes.extend( - [ - CleanUpScheduleTree(), - CartesianAxisMerge(AxisIterator._K), - CartesianAxisMerge(AxisIterator._I), - CartesianAxisMerge(AxisIterator._J), - CartesianRefineTransients(backend_name), - ] - ) - else: - raise NotImplementedError( - f"Loop order {backend_name.loop_order} has no schedule tree pipeline" - ) - CPUPipeline(passes=passes, cache_directory=Path(sdfg.build_folder)).run( - stree, verbose=config.verbose_schedule_tree_optimizations - ) + CPUPipeline( + passes=[ + CleanUpScheduleTree(), + CartesianMerge(backend_name), + CartesianRefineTransients(backend_name), + ], + cache_directory=Path(sdfg.build_folder), + ).run(stree, verbose=config.verbose_schedule_tree_optimizations) if config.verbose_orchestration: with open( os.path.abspath(f"{sdfg.build_folder}/03-post_opt.stree.txt"), diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index 73497f93..b08d6839 100644 --- a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -1,4 +1,4 @@ -from .axis_merge import AxisIterator, CartesianAxisMerge +from .axis_merge import AxisIterator, CartesianAxisMerge, CartesianMerge from .clean_tree import CleanUpScheduleTree from .refine_transients import CartesianRefineTransients @@ -6,6 +6,7 @@ __all__ = [ "AxisIterator", "CartesianAxisMerge", + "CartesianMerge", "CartesianRefineTransients", "CleanUpScheduleTree", ] diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index abacd0b4..2cbf8995 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -1,16 +1,19 @@ from __future__ import annotations import copy -import itertools import dace from dace.properties import CodeBlock from dace.sdfg.analysis.schedule_tree import treenodes as tn +from ndsl.config import Backend, BackendLoopOrder from ndsl.dsl.dace.stree.optimizations.memlet_helpers import ( AxisIterator, no_data_dependencies_on_cartesian_axis, ) +from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( + ReplaceAxisSymbolInTasklet, +) from ndsl.dsl.dace.stree.optimizations.tree_common_op import ( detect_cycle, is_axis_for, @@ -102,30 +105,6 @@ def _last_node(nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> b return list_index(nodes, node) >= len(nodes) - 1 -class ReplaceAxisSymbolInTasklet(tn.ScheduleNodeVisitor): - def __init__(self, axis: AxisIterator) -> None: - self._axis = axis - - def visit_TaskletNode( - self, - node: tn.TaskletNode, - axis_replacements: dict[str, str] | None = None, - ) -> None: - if not axis_replacements: - # Noop if there are no replacements to do. - return - - # Dev NOTE: We directly replace the memlet.subset because the `memlet.replace` - # function sometimes doesn't work - for memlet in itertools.chain( - node.in_memlets.values(), node.out_memlets.values() - ): - if memlet.subset is not None: - memlet.subset.replace(axis_replacements) - if memlet.other_subset is not None: - memlet.other_subset.replace(axis_replacements) - - class CartesianAxisMerge(tn.ScheduleNodeTransformer): """Merge a cartesian axis if they are contiguous in code-flow. @@ -368,7 +347,7 @@ def _map_overcompute_merge( # with the axis symbol of the first map. if second_map.node.map.params[0] != first_map.node.map.params[0]: replacements = {second_map.node.map.params[0]: first_map.node.map.params[0]} - ReplaceAxisSymbolInTasklet(self.axis).visit( + ReplaceAxisSymbolInTasklet().visit( first_map, axis_replacements=replacements ) @@ -447,3 +426,37 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: ndsl_log.debug( f"🚀 {self}: {overall_merged} maps merged in {passes_apply} passes" ) + + +class CartesianMerge(tn.ScheduleNodeTransformer): + """Merge Cartesian axis loops""" + + def __init__(self, backend: Backend, *, eager: bool = True) -> None: + self._backend = backend + self.eager = eager + + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: + if self._backend.loop_order == BackendLoopOrder.IJK: + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + elif self._backend.loop_order == BackendLoopOrder.IKJ: + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + elif self._backend.loop_order == BackendLoopOrder.JIK: + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + elif self._backend.loop_order == BackendLoopOrder.JKI: + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) + elif self._backend.loop_order == BackendLoopOrder.KIJ: + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + elif self._backend.loop_order == BackendLoopOrder.KJI: + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py new file mode 100644 index 00000000..41c03ada --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import itertools + +from dace.sdfg.analysis.schedule_tree import treenodes as tn + + +class ReplaceAxisSymbolInTasklet(tn.ScheduleNodeVisitor): + def __init__(self) -> None: + pass + + def visit_TaskletNode( + self, + node: tn.TaskletNode, + axis_replacements: dict[str, str] | None = None, + ) -> None: + if not axis_replacements: + # Noop if there are no replacements to do. + return + + # Dev NOTE: We directly replace the memlet.subset because the `memlet.replace` + # function sometimes doesn't work + for memlet in itertools.chain( + node.in_memlets.values(), node.out_memlets.values() + ): + if memlet.subset is not None: + memlet.subset.replace(axis_replacements) + if memlet.other_subset is not None: + memlet.other_subset.replace(axis_replacements) From c8d05af66d6ae94c5a31a0e235de3d07f36a3c4a Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 15 May 2026 08:55:36 -0400 Subject: [PATCH 004/101] Move helpers into `common` and break them by type Move pipeline defaults inside the Pipeline itself and have orchestration call default Mockup of passes required for merging to behave --- ndsl/dsl/dace/orchestration.py | 11 +-- ndsl/dsl/dace/stree/optimizations/__init__.py | 13 ++- .../dace/stree/optimizations/axis_merge.py | 28 +++---- .../stree/optimizations/cartesian_merge.py | 52 ++++++++++++ .../stree/optimizations/common/__init__.py | 22 ++++++ .../dace/stree/optimizations/common/loops.py | 14 ++++ .../{memlet_helpers.py => common/memlet.py} | 0 .../{tree_common_op.py => common/topology.py} | 17 ++-- .../optimizations/offgrid_conditionals.py | 79 +++++++++++++++++++ .../stree/optimizations/refine_transients.py | 2 +- .../dace/stree/optimizations/remove_loops.py | 22 ++++++ .../dace/stree/optimizations/statistics.py | 7 +- ndsl/dsl/dace/stree/pipeline.py | 20 ++++- 13 files changed, 241 insertions(+), 46 deletions(-) create mode 100644 ndsl/dsl/dace/stree/optimizations/cartesian_merge.py create mode 100644 ndsl/dsl/dace/stree/optimizations/common/__init__.py create mode 100644 ndsl/dsl/dace/stree/optimizations/common/loops.py rename ndsl/dsl/dace/stree/optimizations/{memlet_helpers.py => common/memlet.py} (100%) rename ndsl/dsl/dace/stree/optimizations/{tree_common_op.py => common/topology.py} (77%) create mode 100644 ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py create mode 100644 ndsl/dsl/dace/stree/optimizations/remove_loops.py diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 933a379b..e31298d7 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -38,11 +38,6 @@ sdfg_nan_checker, ) from ndsl.dsl.dace.stree import CPUPipeline -from ndsl.dsl.dace.stree.optimizations import ( - CartesianMerge, - CartesianRefineTransients, - CleanUpScheduleTree, -) from ndsl.dsl.dace.utils import ( DaCeProgress, memory_static_analysis, @@ -197,11 +192,7 @@ def _build_sdfg( with DaCeProgress(config, "Schedule Tree: optimization"): CPUPipeline( - passes=[ - CleanUpScheduleTree(), - CartesianMerge(backend_name), - CartesianRefineTransients(backend_name), - ], + backend=backend_name, cache_directory=Path(sdfg.build_folder), ).run(stree, verbose=config.verbose_schedule_tree_optimizations) if config.verbose_orchestration: diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index b08d6839..21dcaa72 100644 --- a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -1,6 +1,13 @@ -from .axis_merge import AxisIterator, CartesianAxisMerge, CartesianMerge +from .axis_merge import AxisIterator, CartesianAxisMerge +from .cartesian_merge import CartesianMerge from .clean_tree import CleanUpScheduleTree +from .offgrid_conditionals import ( + ExtractOffgridConditionals, + InlineOffgridConditionals, + MergeConditionals, +) from .refine_transients import CartesianRefineTransients +from .remove_loops import InlineVertical2DWrite __all__ = [ @@ -9,4 +16,8 @@ "CartesianMerge", "CartesianRefineTransients", "CleanUpScheduleTree", + "InlineVertical2DWrite", + "ExtractOffgridConditionals", + "InlineOffgridConditionals", + "MergeConditionals", ] diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 2cbf8995..9ac339e3 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -7,20 +7,20 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl.config import Backend, BackendLoopOrder -from ndsl.dsl.dace.stree.optimizations.memlet_helpers import ( +from ndsl.dsl.dace.stree.optimizations.common import ( AxisIterator, - no_data_dependencies_on_cartesian_axis, -) -from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( - ReplaceAxisSymbolInTasklet, -) -from ndsl.dsl.dace.stree.optimizations.tree_common_op import ( detect_cycle, + get_next_node, is_axis_for, is_axis_map, + last_node, list_index, + no_data_dependencies_on_cartesian_axis, swap_node_position_in_tree, ) +from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( + ReplaceAxisSymbolInTasklet, +) from ndsl.logging import ndsl_log @@ -95,16 +95,6 @@ def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope: return node -def _get_next_node( - nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode -) -> tn.ScheduleTreeNode: - return nodes[list_index(nodes, node) + 1] - - -def _last_node(nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> bool: - return list_index(nodes, node) >= len(nodes) - 1 - - class CartesianAxisMerge(tn.ScheduleNodeTransformer): """Merge a cartesian axis if they are contiguous in code-flow. @@ -288,13 +278,13 @@ def _map_overcompute_merge( # End of nodes OR # Not the right axis # --> recurse - if _last_node(nodes, the_map) or not is_axis_map(the_map, self.axis): + if last_node(nodes, the_map) or not is_axis_map(the_map, self.axis): merged = 0 for child in the_map.children: merged += self._merge_node(child, the_map.children) return merged - next_node = _get_next_node(nodes, the_map) + next_node = get_next_node(nodes, the_map) # Next node is not a MapScope - no merge if not isinstance(next_node, tn.MapScope): diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py new file mode 100644 index 00000000..779d2900 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from dace.sdfg.analysis.schedule_tree import treenodes as tn + +from ndsl.config import Backend, BackendLoopOrder +from ndsl.dsl.dace.stree.optimizations import ( + CartesianAxisMerge, + ExtractOffgridConditionals, + InlineOffgridConditionals, + MergeConditionals, +) +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator + + +class CartesianMerge(tn.ScheduleNodeTransformer): + """Merge Cartesian computation blocks""" + + def __init__(self, backend: Backend, *, eager: bool = True) -> None: + self._backend = backend + self.eager = eager + + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: + InlineOffgridConditionals().visit(node) + MergeConditionals().visit(node) + + if self._backend.loop_order == BackendLoopOrder.IJK: + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + elif self._backend.loop_order == BackendLoopOrder.IKJ: + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + elif self._backend.loop_order == BackendLoopOrder.JIK: + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + elif self._backend.loop_order == BackendLoopOrder.JKI: + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) + elif self._backend.loop_order == BackendLoopOrder.KIJ: + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + elif self._backend.loop_order == BackendLoopOrder.KJI: + CartesianAxisMerge(AxisIterator._K).visit(node) + CartesianAxisMerge(AxisIterator._J).visit(node) + CartesianAxisMerge(AxisIterator._I).visit(node) + + ExtractOffgridConditionals().visit(node) + MergeConditionals().visit(node) diff --git a/ndsl/dsl/dace/stree/optimizations/common/__init__.py b/ndsl/dsl/dace/stree/optimizations/common/__init__.py new file mode 100644 index 00000000..a4a64bc4 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/common/__init__.py @@ -0,0 +1,22 @@ +from .memlet import AxisIterator, no_data_dependencies_on_cartesian_axis # isort: skip +from .loops import is_axis_for, is_axis_map +from .topology import ( + detect_cycle, + get_next_node, + last_node, + list_index, + swap_node_position_in_tree, +) + + +__all__ = [ + "AxisIterator", + "no_data_dependencies_on_cartesian_axis", + "is_axis_map", + "is_axis_for", + "get_next_node", + "last_node", + "swap_node_position_in_tree", + "detect_cycle", + "list_index", +] diff --git a/ndsl/dsl/dace/stree/optimizations/common/loops.py b/ndsl/dsl/dace/stree/optimizations/common/loops.py new file mode 100644 index 00000000..83a91280 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/common/loops.py @@ -0,0 +1,14 @@ +import dace.sdfg.analysis.schedule_tree.treenodes as tn + +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator + + +def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: + """Returns true if node is a Map over the given axis.""" + map_parameter = node.node.map.params + return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) + + +def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: + """Returns true if node is a For over the given axis.""" + return node.loop.loop_variable.startswith(axis.as_str()) diff --git a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py similarity index 100% rename from ndsl/dsl/dace/stree/optimizations/memlet_helpers.py rename to ndsl/dsl/dace/stree/optimizations/common/memlet.py diff --git a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py b/ndsl/dsl/dace/stree/optimizations/common/topology.py similarity index 77% rename from ndsl/dsl/dace/stree/optimizations/tree_common_op.py rename to ndsl/dsl/dace/stree/optimizations/common/topology.py index 4748a901..27edf8fe 100644 --- a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py +++ b/ndsl/dsl/dace/stree/optimizations/common/topology.py @@ -2,8 +2,6 @@ import dace.sdfg.analysis.schedule_tree.treenodes as tn -from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator - def swap_node_position_in_tree( top_node: tn.ScheduleTreeScope, child_node: tn.ScheduleTreeScope @@ -55,12 +53,13 @@ def list_index( return next(index for index, element in enumerate(collection) if element is node) -def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: - """Returns true if node is a Map over the given axis.""" - map_parameter = node.node.map.params - return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) +def get_next_node( + nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode +) -> tn.ScheduleTreeNode: + """Get next node in the children from given node""" + return nodes[list_index(nodes, node) + 1] -def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: - """Returns true if node is a For over the given axis.""" - return node.loop.loop_variable.startswith(axis.as_str()) +def last_node(nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> bool: + """Test for last node of list""" + return list_index(nodes, node) >= len(nodes) - 1 diff --git a/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py new file mode 100644 index 00000000..93d6ab20 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from dace.sdfg.analysis.schedule_tree import treenodes as tn + + +class InlineOffgridConditionals(tn.ScheduleNodeTransformer): + """Push offgrid conditional inside their cartesian block, + duplicating the conditional if needed + + Turning: + ``` + if a_flag == 0 + map i, j, k + [ops...] + map i, j, k + [ops...] + ``` + into + ``` + map i,j, k + if a_flag == 0 + [ops...] + map i,j, k + if a_flag == 0 + [ops...] + ``` + """ + + def __init__(self) -> None: + pass + + def __str__(self) -> str: + return "InlineOffgridConditionals" + + +class ExtractOffgridConditionals(tn.ScheduleNodeTransformer): + """Push offgrid conditional outside of their cartesian block + + Reverse transform from InlineOffgridConditionals + """ + + def __init__(self) -> None: + pass + + def __str__(self) -> str: + return "ExtractOffgridConditionals" + + +class MergeConditionals(tn.ScheduleNodeTransformer): + """Merge consecutive and equal conditionals + + Turning: + ``` + if a_flag == 0 + map i, j, k + [ops...] + if a_flag == 0 + map i, j, k + [ops...] + ``` + into + ``` + if a_flag == 0 + map i, j, k + [ops...] + map i, j, k + [ops...] + ``` + + Outside of user code, vombination of ExtractOffgridConditionals, + InlineOffgridConditionals and CartesianMapMerge can lead to this + pattern. + """ + + def __init__(self) -> None: + pass + + def __str__(self) -> str: + return "MergeConditionals" diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 7b788e32..bb066d4c 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -4,7 +4,7 @@ import dace.sdfg.analysis.schedule_tree.treenodes as stree from ndsl.config import Backend, BackendFramework -from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator from ndsl.logging import ndsl_log diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py new file mode 100644 index 00000000..f3ba21a3 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -0,0 +1,22 @@ +from dace.sdfg.analysis.schedule_tree import treenodes as tn + + +class InlineVertical2DWrite(tn.ScheduleNodeTransformer): + """Inline K index value for 2D write vertical while removing for loop. + + Transforming: + ``` + for __k = 0; __k < 1; __k = __k + 1: + map __j, __i: + field[__i, __j] = tasklet(field_in[__i, __j, __k]) + ``` + + Into + ``` + map __j, __i: + field[__i, __j] = tasklet(field_in[__i, __j, 0]) + ``` + """ + + def __init__(self) -> None: + super().__init__() diff --git a/ndsl/dsl/dace/stree/optimizations/statistics.py b/ndsl/dsl/dace/stree/optimizations/statistics.py index 1a3ce593..ebef36fe 100644 --- a/ndsl/dsl/dace/stree/optimizations/statistics.py +++ b/ndsl/dsl/dace/stree/optimizations/statistics.py @@ -3,8 +3,11 @@ import dace import dace.sdfg.analysis.schedule_tree.treenodes as stree -from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator -from ndsl.dsl.dace.stree.optimizations.tree_common_op import is_axis_for, is_axis_map +from ndsl.dsl.dace.stree.optimizations.common import ( + AxisIterator, + is_axis_for, + is_axis_map, +) class CountCartesianLoops(stree.ScheduleNodeVisitor): diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 44e933bc..ad52ca9c 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -2,7 +2,13 @@ import dace.sdfg.analysis.schedule_tree.treenodes as stree -from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge +from ndsl import Backend +from ndsl.dsl.dace.stree.optimizations import ( + CartesianMerge, + CartesianRefineTransients, + CleanUpScheduleTree, + InlineVertical2DWrite, +) from ndsl.dsl.dace.stree.optimizations.statistics import TreeOptimizationStatistics from ndsl.logging import ndsl_log_on_rank_0 @@ -55,14 +61,20 @@ def run( class CPUPipeline(StreePipeline): def __init__( self, + backend: Backend, *, passes: list[stree.ScheduleNodeTransformer] | None = None, cache_directory: Path | None = None, ) -> None: + if passes is None: + passes = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(backend), + CartesianRefineTransients(backend), + ] super().__init__( - passes=( - passes if passes is not None else [CartesianAxisMerge(AxisIterator._K)] - ), + passes=passes, cache_directory=cache_directory, ) From 20665a8aedc97c5cabc034d97e5a6fdb4ebbeb35 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 15 May 2026 09:04:22 -0400 Subject: [PATCH 005/101] Fix imports --- ndsl/dsl/dace/stree/optimizations/cartesian_merge.py | 6 +++--- ndsl/dsl/dace/stree/optimizations/remove_loops.py | 3 +++ .../dace/stree/optimizations/replace_symbol_in_tasklet.py | 3 +++ ndsl/dsl/dace/stree/optimizations/specialize_maps.py | 3 +++ 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py index 779d2900..d8ab5043 100644 --- a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -3,13 +3,13 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl.config import Backend, BackendLoopOrder -from ndsl.dsl.dace.stree.optimizations import ( - CartesianAxisMerge, +from ndsl.dsl.dace.stree.optimizations.axis_merge import CartesianAxisMerge +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator +from ndsl.dsl.dace.stree.optimizations.offgrid_conditionals import ( ExtractOffgridConditionals, InlineOffgridConditionals, MergeConditionals, ) -from ndsl.dsl.dace.stree.optimizations.common import AxisIterator class CartesianMerge(tn.ScheduleNodeTransformer): diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index f3ba21a3..0bffe7ab 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -20,3 +20,6 @@ class InlineVertical2DWrite(tn.ScheduleNodeTransformer): def __init__(self) -> None: super().__init__() + + def __str__(self) -> str: + return "InlineVertical2DWrite" diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py index 41c03ada..afb150e0 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py @@ -27,3 +27,6 @@ def visit_TaskletNode( memlet.subset.replace(axis_replacements) if memlet.other_subset is not None: memlet.other_subset.replace(axis_replacements) + + def __str__(self) -> str: + return "ReplaceAxisSymbolInTasklet" diff --git a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py index 2583ec2d..e2409e1a 100644 --- a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py +++ b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py @@ -19,3 +19,6 @@ def visit_MapScope(self, node: stree.MapScope) -> None: node.node.map.range = sbs.Range(dims) self.visit(node.children) + + def __str__(self) -> str: + return "SpecializeCartesianMaps" From c8a225e2492427634a9742df593f6a3ae65a1f0b Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 15 May 2026 15:23:44 -0400 Subject: [PATCH 006/101] `InlineVertical2DWrite` + utest --- ndsl/dsl/dace/stree/optimizations/__init__.py | 3 +- .../stree/optimizations/cartesian_merge.py | 3 + .../stree/optimizations/common/__init__.py | 2 + .../dace/stree/optimizations/common/memlet.py | 3 + .../stree/optimizations/common/topology.py | 16 +++ .../stree/optimizations/refine_transients.py | 6 +- .../dace/stree/optimizations/remove_loops.py | 50 +++++++ .../stree/optimizations/test_remove_loops.py | 132 ++++++++++++++++++ 8 files changed, 208 insertions(+), 7 deletions(-) create mode 100644 tests/dsl/dace/stree/optimizations/test_remove_loops.py diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index 21dcaa72..8cd77f55 100644 --- a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -1,4 +1,4 @@ -from .axis_merge import AxisIterator, CartesianAxisMerge +from .axis_merge import CartesianAxisMerge from .cartesian_merge import CartesianMerge from .clean_tree import CleanUpScheduleTree from .offgrid_conditionals import ( @@ -11,7 +11,6 @@ __all__ = [ - "AxisIterator", "CartesianAxisMerge", "CartesianMerge", "CartesianRefineTransients", diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py index d8ab5043..398a2103 100644 --- a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -19,6 +19,9 @@ def __init__(self, backend: Backend, *, eager: bool = True) -> None: self._backend = backend self.eager = eager + def __str__(self) -> str: + return "CartesianMerge" + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: InlineOffgridConditionals().visit(node) MergeConditionals().visit(node) diff --git a/ndsl/dsl/dace/stree/optimizations/common/__init__.py b/ndsl/dsl/dace/stree/optimizations/common/__init__.py index a4a64bc4..c76887fb 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/common/__init__.py @@ -5,6 +5,7 @@ get_next_node, last_node, list_index, + reparent_scope_node, swap_node_position_in_tree, ) @@ -19,4 +20,5 @@ "swap_node_position_in_tree", "detect_cycle", "list_index", + "reparent_scope_node", ] diff --git a/ndsl/dsl/dace/stree/optimizations/common/memlet.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py index 75f68143..b1540c98 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/memlet.py +++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py @@ -17,6 +17,9 @@ def as_str(self) -> str: def as_cartesian_index(self) -> int: return self.value[1] + def is_equal(self, other: str) -> bool: + return other.startswith(self.as_str()) + def no_data_dependencies_on_cartesian_axis( first: stree.MapScope, diff --git a/ndsl/dsl/dace/stree/optimizations/common/topology.py b/ndsl/dsl/dace/stree/optimizations/common/topology.py index 27edf8fe..fe878522 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/topology.py +++ b/ndsl/dsl/dace/stree/optimizations/common/topology.py @@ -3,6 +3,22 @@ import dace.sdfg.analysis.schedule_tree.treenodes as tn +def reparent_scope_node( + original_parent: tn.ScheduleTreeScope, + new_parent: tn.ScheduleTreeNode, + prepend: bool = True, +) -> None: + """Re-parent children between two scope nodes""" + + for child in original_parent.children: + child.parent = new_parent + + if prepend: + new_parent.children = [*original_parent.children, *new_parent.children] + else: + new_parent.children = [*new_parent.children, *original_parent.children] + + def swap_node_position_in_tree( top_node: tn.ScheduleTreeScope, child_node: tn.ScheduleTreeScope ) -> None: diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index bb066d4c..8da15f8f 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -38,11 +38,7 @@ def _reduce_cartesian_axis_size_to_1( # Assume 3D cartesian! if len(transient_data.shape) < 3: - warnings.warn( - f"Potential non-3D array: {transient_data}, skipping.", - UserWarning, - stacklevel=2, - ) + ndsl_log.debug(f"Potential non-3D array: {transient_data}, skipping.") return False read_write_range: dace.subsets.Range = dace.subsets.union( diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index 0bffe7ab..58b20d23 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -1,5 +1,13 @@ +import ast + from dace.sdfg.analysis.schedule_tree import treenodes as tn +from ndsl import ndsl_log +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator, reparent_scope_node +from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( + ReplaceAxisSymbolInTasklet, +) + class InlineVertical2DWrite(tn.ScheduleNodeTransformer): """Inline K index value for 2D write vertical while removing for loop. @@ -20,6 +28,48 @@ class InlineVertical2DWrite(tn.ScheduleNodeTransformer): def __init__(self) -> None: super().__init__() + self._for_scope_removed = 0 def __str__(self) -> str: return "InlineVertical2DWrite" + + def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeNode: + if ( + AxisIterator._K.is_equal(the_for.loop.loop_variable) + and the_for.loop.executions == 1 + and the_for.parent + ): + # Retrieve init value by executing the code and replace usage of it + # If the code cannot be executed (no-literal variable part of the op, etc.) + # we will _not_ inline + try: + exec(ast.unparse(the_for.loop.init_statement.code[0])) + except Exception as _: + return the_for + init_value = locals()[the_for.loop.loop_variable] + ReplaceAxisSymbolInTasklet().visit( + the_for, axis_replacements={the_for.loop.loop_variable: str(init_value)} + ) + + # Prepend children of the ForScope to parent + # the_for.parent.children = [*the_for.children, *the_for.parent.children] + reparent_scope_node(the_for, the_for.parent) + + # Remove ForScope + the_for.parent.children.remove(the_for) + self._for_scope_removed += 1 + assert len(the_for.children) > 0 + return the_for.parent.children[0] + + return the_for + + def visit_ScheduleTreeRoot( + self, the_root: tn.ScheduleTreeRoot + ) -> tn.ScheduleTreeRoot: + + for child in the_root.children: + self.visit(child) + + ndsl_log.debug(f"🚀 {self}: {self._for_scope_removed} inlined") + + return the_root diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py new file mode 100644 index 00000000..07611767 --- /dev/null +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -0,0 +1,132 @@ +from typing import TypeAlias + +import pytest +from dace import nodes +from dace.sdfg.state import LoopRegion + +from ndsl import QuantityFactory, StencilFactory, orchestrate +from ndsl.boilerplate import get_factories_single_tile +from ndsl.config import Backend +from ndsl.constants import I_DIM, J_DIM, K_DIM, Float +from ndsl.dsl.gt4py import FORWARD, computation, interval +from ndsl.dsl.typing import FloatField, FloatFieldIJ +from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge + + +def stencil_simple_2D_write(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> None: + with computation(FORWARD), interval(0, 1): + out_fieldIJ = in_field + + +def stencil_2D_write_at_K(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> None: + with computation(FORWARD), interval(-1, None): + out_fieldIJ = in_field + + +class OrchestratedCode: + def __init__( + self, + stencil_factory: StencilFactory, + quantity_factory: QuantityFactory, + ) -> None: + orchestratable_methods = [ + "write_at_0", + "write_at_top", + ] + for method in orchestratable_methods: + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + method_to_orchestrate=method, + ) + + self.stencil_simple_2D_write = stencil_factory.from_dims_halo( + func=stencil_simple_2D_write, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) + self.stencil_2D_write_at_K = stencil_factory.from_dims_halo( + func=stencil_2D_write_at_K, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) + + def write_at_0( + self, + in_field: FloatField, + out_field: FloatFieldIJ, + ) -> None: + self.stencil_simple_2D_write(in_field, out_field) + + def write_at_top( + self, + in_field: FloatField, + out_field: FloatFieldIJ, + ) -> None: + self.stencil_2D_write_at_K(in_field, out_field) + + +Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] + + +class TestStree2DWriteInline: + @pytest.fixture(params=[Backend("orch:dace:cpu:IJK"), Backend("orch:dace:cpu:KJI")]) + def factories(self, request) -> Factories: + domain = (3, 3, 4) + return get_factories_single_tile( + domain[0], domain[1], domain[2], 0, backend=request.param + ) + + @pytest.fixture + def code(self, factories: Factories) -> OrchestratedCode: + return OrchestratedCode(*factories) + + def test_common_2D_write( + self, code: OrchestratedCode, factories: Factories + ) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") + in_qty.field[:, :, 0] = Float(32.0) + + with StreeOptimization(): + code.write_at_0(in_qty, out_qty) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert len(all_maps) == 2 + assert len(all_loop_region) == 0 + assert (out_qty.field[:] == Float(32.0)).all() + + def test_2D_write_K_top(self, code: OrchestratedCode, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") + in_qty.field[:, :, -1] = Float(32.0) + + with StreeOptimization(): + code.write_at_top(in_qty, out_qty) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert len(all_maps) == 2 + assert len(all_loop_region) == 0 + assert (out_qty.field[:] == Float(32.0)).all() From 73f5609d355b9adf78eb175d1002decee2df97d2 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 15 May 2026 15:59:20 -0400 Subject: [PATCH 007/101] Fix InlineVertical2DWrite --- .../dace/stree/optimizations/remove_loops.py | 16 ++++--- .../stree/optimizations/test_remove_loops.py | 45 +++++++++++++++++-- 2 files changed, 50 insertions(+), 11 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index 58b20d23..8eff3e51 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -34,19 +34,21 @@ def __str__(self) -> str: return "InlineVertical2DWrite" def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeNode: - if ( - AxisIterator._K.is_equal(the_for.loop.loop_variable) - and the_for.loop.executions == 1 - and the_for.parent - ): - # Retrieve init value by executing the code and replace usage of it + if AxisIterator._K.is_equal(the_for.loop.loop_variable) and the_for.parent: + # Retrieve init/bound value by executing the code and replace usage of it # If the code cannot be executed (no-literal variable part of the op, etc.) # we will _not_ inline try: exec(ast.unparse(the_for.loop.init_statement.code[0])) + init_value = locals()[the_for.loop.loop_variable] + bound_value = eval( + ast.unparse(the_for.loop.loop_condition.code[0].value.comparators) + ) except Exception as _: return the_for - init_value = locals()[the_for.loop.loop_variable] + if abs(bound_value - init_value) != 1: + return the_for + ReplaceAxisSymbolInTasklet().visit( the_for, axis_replacements={the_for.loop.loop_variable: str(init_value)} ) diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py index 07611767..012e88b9 100644 --- a/tests/dsl/dace/stree/optimizations/test_remove_loops.py +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -23,16 +23,18 @@ def stencil_2D_write_at_K(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> No out_fieldIJ = in_field +def stencil_forward_at_K(in_field: FloatField, out_field: FloatField) -> None: + with computation(FORWARD), interval(...): + out_field = in_field + + class OrchestratedCode: def __init__( self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory, ) -> None: - orchestratable_methods = [ - "write_at_0", - "write_at_top", - ] + orchestratable_methods = ["write_at_0", "write_at_top", "do_not_inline"] for method in orchestratable_methods: orchestrate( obj=self, @@ -48,6 +50,10 @@ def __init__( func=stencil_2D_write_at_K, compute_dims=[I_DIM, J_DIM, K_DIM], ) + self.stencil_do_not_inline = stencil_factory.from_dims_halo( + func=stencil_forward_at_K, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) def write_at_0( self, @@ -63,6 +69,13 @@ def write_at_top( ) -> None: self.stencil_2D_write_at_K(in_field, out_field) + def do_not_inline( + self, + in_field: FloatField, + out_field: FloatField, + ) -> None: + self.stencil_do_not_inline(in_field, out_field) + Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] @@ -130,3 +143,27 @@ def test_2D_write_K_top(self, code: OrchestratedCode, factories: Factories) -> N assert len(all_maps) == 2 assert len(all_loop_region) == 0 assert (out_qty.field[:] == Float(32.0)).all() + + def test_do_not_inline(self, code: OrchestratedCode, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + code.do_not_inline(in_qty, out_qty) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert len(all_maps) == 2 + assert len(all_loop_region) == 1 + assert (out_qty.field[:] == Float(1)).all() From fc1ecb10838849d4179b387e06849b04d303df08 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 18 May 2026 11:42:44 +0200 Subject: [PATCH 008/101] cleanup --- ndsl/dsl/dace/stree/optimizations/axis_merge.py | 4 +--- .../dsl/dace/stree/optimizations/cartesian_merge.py | 2 -- ndsl/dsl/dace/stree/optimizations/clean_tree.py | 5 ++--- .../stree/optimizations/offgrid_conditionals.py | 13 +------------ .../dace/stree/optimizations/refine_transients.py | 4 ++-- .../optimizations/replace_symbol_in_tasklet.py | 5 ----- 6 files changed, 6 insertions(+), 27 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 9ac339e3..0c924123 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -1,11 +1,10 @@ -from __future__ import annotations - import copy import dace from dace.properties import CodeBlock from dace.sdfg.analysis.schedule_tree import treenodes as tn +from ndsl import ndsl_log from ndsl.config import Backend, BackendLoopOrder from ndsl.dsl.dace.stree.optimizations.common import ( AxisIterator, @@ -21,7 +20,6 @@ from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( ReplaceAxisSymbolInTasklet, ) -from ndsl.logging import ndsl_log # Buggy passes that should work diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py index 398a2103..0521309f 100644 --- a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl.config import Backend, BackendLoopOrder diff --git a/ndsl/dsl/dace/stree/optimizations/clean_tree.py b/ndsl/dsl/dace/stree/optimizations/clean_tree.py index 0da456de..8f04882a 100644 --- a/ndsl/dsl/dace/stree/optimizations/clean_tree.py +++ b/ndsl/dsl/dace/stree/optimizations/clean_tree.py @@ -1,14 +1,13 @@ -from __future__ import annotations - from dace.sdfg.analysis.schedule_tree import treenodes as tn -from ndsl.logging import ndsl_log +from ndsl import ndsl_log class CleanUpScheduleTree(tn.ScheduleNodeTransformer): """Remove `StateBoundary` nodes from children of ScheduleTreeScopes.""" def __init__(self) -> None: + super().__init__() self._removed_state_boundaries = 0 def __str__(self) -> str: diff --git a/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py index 93d6ab20..d677c203 100644 --- a/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py +++ b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from dace.sdfg.analysis.schedule_tree import treenodes as tn @@ -26,9 +24,6 @@ class InlineOffgridConditionals(tn.ScheduleNodeTransformer): ``` """ - def __init__(self) -> None: - pass - def __str__(self) -> str: return "InlineOffgridConditionals" @@ -39,9 +34,6 @@ class ExtractOffgridConditionals(tn.ScheduleNodeTransformer): Reverse transform from InlineOffgridConditionals """ - def __init__(self) -> None: - pass - def __str__(self) -> str: return "ExtractOffgridConditionals" @@ -67,13 +59,10 @@ class MergeConditionals(tn.ScheduleNodeTransformer): [ops...] ``` - Outside of user code, vombination of ExtractOffgridConditionals, + Outside of user code, combination of ExtractOffgridConditionals, InlineOffgridConditionals and CartesianMapMerge can lead to this pattern. """ - def __init__(self) -> None: - pass - def __str__(self) -> str: return "MergeConditionals" diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 8da15f8f..cd8e2703 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -34,7 +34,7 @@ def _reduce_cartesian_axis_size_to_1( are atomic""" # Dev Note: Better dataflow analysis would look at exactly - # what's goin on here! + # what's going on here! # Assume 3D cartesian! if len(transient_data.shape) < 3: @@ -206,7 +206,7 @@ class CartesianRefineTransients(stree.ScheduleNodeTransformer): cartesian axis) it will reduce that axis to 1 if all access are atomic (exactly _one_ element of the array is ever worked on in a single loop) - It will refuse to merge if the transient is used in multiple loops of for - a given axis - irrigardless of it's access pattern (e.g. even if it could be + a given axis - regardless of it's access pattern (e.g. even if it could be refine because it's always written first.) It should but cannot do/will bug if: diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py index afb150e0..398ce203 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py @@ -1,14 +1,9 @@ -from __future__ import annotations - import itertools from dace.sdfg.analysis.schedule_tree import treenodes as tn class ReplaceAxisSymbolInTasklet(tn.ScheduleNodeVisitor): - def __init__(self) -> None: - pass - def visit_TaskletNode( self, node: tn.TaskletNode, From d7e40aa81096c87efdfc04998bd3136a55c5c031 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 18 May 2026 11:43:35 +0200 Subject: [PATCH 009/101] fix symbol replacement Use symbols in the replacement directory. Update DaCe to a version that doesn't re-initialize the symbols. And fix the test failure in python 3.13. --- external/dace | 2 +- ndsl/dsl/dace/stree/optimizations/axis_merge.py | 6 +++++- .../dace/stree/optimizations/remove_loops.py | 17 ++++++++++++++--- .../optimizations/replace_symbol_in_tasklet.py | 5 +---- .../stree/optimizations/test_remove_loops.py | 5 +++-- 5 files changed, 24 insertions(+), 11 deletions(-) diff --git a/external/dace b/external/dace index d5fbadb6..f271b30b 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit d5fbadb626389e425fac5ed93d2a880811eca41f +Subproject commit f271b30bb983559306342ce2ff98c69e6662bb32 diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 0c924123..227f5614 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -334,7 +334,11 @@ def _map_overcompute_merge( # After merge, we need to replace the axis symbols of the second map's children # with the axis symbol of the first map. if second_map.node.map.params[0] != first_map.node.map.params[0]: - replacements = {second_map.node.map.params[0]: first_map.node.map.params[0]} + replacements = { + dace.symbol(second_map.node.map.params[0]): dace.symbol( + first_map.node.map.params[0] + ) + } ReplaceAxisSymbolInTasklet().visit( first_map, axis_replacements=replacements ) diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index 8eff3e51..43e9c15e 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -1,5 +1,7 @@ import ast +from typing import Any +import dace from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl import ndsl_log @@ -39,8 +41,14 @@ def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeN # If the code cannot be executed (no-literal variable part of the op, etc.) # we will _not_ inline try: - exec(ast.unparse(the_for.loop.init_statement.code[0])) - init_value = locals()[the_for.loop.loop_variable] + exec_locals: dict[str, Any] = {} + exec_globals: dict[str, Any] = {} + exec( + ast.unparse(the_for.loop.init_statement.code[0]), + exec_globals, + exec_locals, + ) + init_value = exec_locals[the_for.loop.loop_variable] bound_value = eval( ast.unparse(the_for.loop.loop_condition.code[0].value.comparators) ) @@ -50,7 +58,10 @@ def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeN return the_for ReplaceAxisSymbolInTasklet().visit( - the_for, axis_replacements={the_for.loop.loop_variable: str(init_value)} + the_for, + axis_replacements={ + dace.symbol(the_for.loop.loop_variable): str(init_value) + }, ) # Prepend children of the ForScope to parent diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py index 398ce203..1020affe 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py @@ -18,10 +18,7 @@ def visit_TaskletNode( for memlet in itertools.chain( node.in_memlets.values(), node.out_memlets.values() ): - if memlet.subset is not None: - memlet.subset.replace(axis_replacements) - if memlet.other_subset is not None: - memlet.other_subset.replace(axis_replacements) + memlet.replace(axis_replacements) def __str__(self) -> str: return "ReplaceAxisSymbolInTasklet" diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py index 012e88b9..06cbe9fe 100644 --- a/tests/dsl/dace/stree/optimizations/test_remove_loops.py +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -81,11 +81,12 @@ def do_not_inline( class TestStree2DWriteInline: - @pytest.fixture(params=[Backend("orch:dace:cpu:IJK"), Backend("orch:dace:cpu:KJI")]) + @pytest.fixture(params=["orch:dace:cpu:IJK", "orch:dace:cpu:KJI"]) def factories(self, request) -> Factories: + domain = (3, 3, 4) return get_factories_single_tile( - domain[0], domain[1], domain[2], 0, backend=request.param + domain[0], domain[1], domain[2], 0, backend=Backend(request.param) ) @pytest.fixture From 55ad8fa8db6b8acc050b3d1c82f862e96918fd27 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 18 May 2026 15:11:43 +0200 Subject: [PATCH 010/101] update gt4py (log10 and friends) --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index c7d162cc..9fbba0a0 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit c7d162cccc35cb2d1aaa79f5ad12222f617803ac +Subproject commit 9fbba0a07232cd8765123bdd226ea3c26cf768a8 From 1c9bb5f567831e824d86c778c7392d2cfaff8b95 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 18 May 2026 15:07:44 +0200 Subject: [PATCH 011/101] more cleanup (all minor nothing fancy) --- .../dace/stree/optimizations/axis_merge.py | 43 ++----------------- .../stree/optimizations/cartesian_merge.py | 4 +- .../stree/optimizations/common/topology.py | 1 + .../replace_symbol_in_tasklet.py | 2 - 4 files changed, 6 insertions(+), 44 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 227f5614..e699dddc 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -5,7 +5,6 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl import ndsl_log -from ndsl.config import Backend, BackendLoopOrder from ndsl.dsl.dace.stree.optimizations.common import ( AxisIterator, detect_cycle, @@ -98,22 +97,20 @@ class CartesianAxisMerge(tn.ScheduleNodeTransformer): Can do: - merge a given axis with the next maps at the same recursion level - - can overcompute (eager) to allow for more merging at the cost of an if + - does overcompute to allow for more merging at the cost of an if It expects: - All Maps and ForLoop are on a single axis - but doesn't check for it. Args: axis: AxisIterator to be merged - eager: overcompute with a conditional guard """ - def __init__(self, axis: AxisIterator, *, eager: bool = True) -> None: + def __init__(self, axis: AxisIterator) -> None: self.axis = axis - self.eager = eager def __str__(self) -> str: - return f"CartesianAxisMerge_{self.axis.name}_{'eager' if self.eager else ''}" + return f"CartesianAxisMerge_{self.axis.name}" def _merge_node( self, node: tn.ScheduleTreeNode, nodes: list[tn.ScheduleTreeNode] @@ -418,37 +415,3 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: ndsl_log.debug( f"🚀 {self}: {overall_merged} maps merged in {passes_apply} passes" ) - - -class CartesianMerge(tn.ScheduleNodeTransformer): - """Merge Cartesian axis loops""" - - def __init__(self, backend: Backend, *, eager: bool = True) -> None: - self._backend = backend - self.eager = eager - - def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: - if self._backend.loop_order == BackendLoopOrder.IJK: - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - elif self._backend.loop_order == BackendLoopOrder.IKJ: - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - elif self._backend.loop_order == BackendLoopOrder.JIK: - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - elif self._backend.loop_order == BackendLoopOrder.JKI: - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) - elif self._backend.loop_order == BackendLoopOrder.KIJ: - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - elif self._backend.loop_order == BackendLoopOrder.KJI: - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py index 0521309f..d52403f7 100644 --- a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -13,9 +13,9 @@ class CartesianMerge(tn.ScheduleNodeTransformer): """Merge Cartesian computation blocks""" - def __init__(self, backend: Backend, *, eager: bool = True) -> None: + def __init__(self, backend: Backend) -> None: + super().__init__() self._backend = backend - self.eager = eager def __str__(self) -> str: return "CartesianMerge" diff --git a/ndsl/dsl/dace/stree/optimizations/common/topology.py b/ndsl/dsl/dace/stree/optimizations/common/topology.py index fe878522..e81df22a 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/topology.py +++ b/ndsl/dsl/dace/stree/optimizations/common/topology.py @@ -6,6 +6,7 @@ def reparent_scope_node( original_parent: tn.ScheduleTreeScope, new_parent: tn.ScheduleTreeNode, + *, prepend: bool = True, ) -> None: """Re-parent children between two scope nodes""" diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py index 1020affe..7dcb7bae 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py @@ -13,8 +13,6 @@ def visit_TaskletNode( # Noop if there are no replacements to do. return - # Dev NOTE: We directly replace the memlet.subset because the `memlet.replace` - # function sometimes doesn't work for memlet in itertools.chain( node.in_memlets.values(), node.out_memlets.values() ): From 36204b00eebab3ee9ca539dbb01d2c2034a0d2ad Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 20 May 2026 15:35:28 +0200 Subject: [PATCH 012/101] Add support for InlineOffgridConditionals --- external/dace | 2 +- .../dace/stree/optimizations/axis_merge.py | 24 +-- .../stree/optimizations/cartesian_merge.py | 48 +++--- .../dace/stree/optimizations/clean_tree.py | 4 +- .../optimizations/offgrid_conditionals.py | 119 +++++++++---- ndsl/dsl/dace/stree/pipeline.py | 1 + .../test_offgrid_conditionals.py | 158 ++++++++++++++++++ 7 files changed, 288 insertions(+), 68 deletions(-) create mode 100644 tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py diff --git a/external/dace b/external/dace index f271b30b..ec81b1a0 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit f271b30bb983559306342ce2ff98c69e6662bb32 +Subproject commit ec81b1a0c2a872da8dd315378ff6a9ac67d5458b diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index e699dddc..3ad5b377 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -75,20 +75,20 @@ def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope: all_children_are_maps = all( [isinstance(child, tn.MapScope) for child in node.children] ) - if not all_children_are_maps: - if self._merged_range != self._original_range: - if_scope = tn.IfScope( - condition=self._execution_condition(), - children=node.children, - parent=node, - ) - # Re-parent to IF - for child in node.children: - child.parent = if_scope - node.children = [if_scope] + if all_children_are_maps: + node.children = self.visit(node.children) return node - node.children = self.visit(node.children) + if self._merged_range != self._original_range: + if_scope = tn.IfScope( + condition=self._execution_condition(), + children=node.children, + parent=node, + ) + # Re-parent to IF + for child in node.children: + child.parent = if_scope + node.children = [if_scope] return node diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py index d52403f7..1dd64458 100644 --- a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -21,33 +21,31 @@ def __str__(self) -> str: return "CartesianMerge" def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: - InlineOffgridConditionals().visit(node) + for axis in self._backend_order(): + InlineOffgridConditionals(axis).visit(node) MergeConditionals().visit(node) - if self._backend.loop_order == BackendLoopOrder.IJK: - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - elif self._backend.loop_order == BackendLoopOrder.IKJ: - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - elif self._backend.loop_order == BackendLoopOrder.JIK: - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - elif self._backend.loop_order == BackendLoopOrder.JKI: - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) - elif self._backend.loop_order == BackendLoopOrder.KIJ: - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - elif self._backend.loop_order == BackendLoopOrder.KJI: - CartesianAxisMerge(AxisIterator._K).visit(node) - CartesianAxisMerge(AxisIterator._J).visit(node) - CartesianAxisMerge(AxisIterator._I).visit(node) + for axis in self._backend_order(): + CartesianAxisMerge(axis).visit(node) ExtractOffgridConditionals().visit(node) MergeConditionals().visit(node) + + def _backend_order(self) -> tuple[AxisIterator, AxisIterator, AxisIterator]: + if self._backend.loop_order == BackendLoopOrder.IJK: + return (AxisIterator._I, AxisIterator._J, AxisIterator._K) + + if self._backend.loop_order == BackendLoopOrder.IKJ: + return (AxisIterator._I, AxisIterator._K, AxisIterator._J) + + if self._backend.loop_order == BackendLoopOrder.JIK: + return (AxisIterator._J, AxisIterator._I, AxisIterator._K) + + if self._backend.loop_order == BackendLoopOrder.JKI: + return (AxisIterator._J, AxisIterator._K, AxisIterator._I) + + if self._backend.loop_order == BackendLoopOrder.KIJ: + return (AxisIterator._K, AxisIterator._I, AxisIterator._J) + + assert self._backend.loop_order == BackendLoopOrder.KJI + return (AxisIterator._K, AxisIterator._J, AxisIterator._I) diff --git a/ndsl/dsl/dace/stree/optimizations/clean_tree.py b/ndsl/dsl/dace/stree/optimizations/clean_tree.py index 8f04882a..7d3b5558 100644 --- a/ndsl/dsl/dace/stree/optimizations/clean_tree.py +++ b/ndsl/dsl/dace/stree/optimizations/clean_tree.py @@ -49,12 +49,13 @@ def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope: def visit_IfScope(self, node: tn.IfScope) -> tn.IfScope: self._remove_state_boundaries_from_children(node) + for child in node.children: self.visit(child) return node - def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> tn.ScheduleTreeRoot: self._removed_state_boundaries = 0 self._remove_state_boundaries_from_children(node) @@ -63,3 +64,4 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: self.visit(child) ndsl_log.debug(f"{self}: removed {self._removed_state_boundaries} nodes") + return node diff --git a/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py index d677c203..de4c21a2 100644 --- a/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py +++ b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py @@ -1,37 +1,97 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn +from ndsl import ndsl_log +from ndsl.dsl.dace.stree.optimizations.common import ( + AxisIterator, + get_next_node, + is_axis_map, + last_node, + list_index, +) -class InlineOffgridConditionals(tn.ScheduleNodeTransformer): - """Push offgrid conditional inside their cartesian block, - duplicating the conditional if needed + +class InlineOffgridConditionals(tn.ScheduleNodeVisitor): + """ + Push offgrid conditional inside their cartesian block, duplicating the + conditional if needed. Turning: ``` - if a_flag == 0 - map i, j, k - [ops...] - map i, j, k - [ops...] + if a_flag == 0: + map i, j, k: + ... + map i, j, k: + ... ``` into ``` - map i,j, k - if a_flag == 0 - [ops...] - map i,j, k - if a_flag == 0 - [ops...] + map i, j, k: + if a_flag == 0: + ... + map i, j, k: + if a_flag == 0: + ... ``` """ + _axis: AxisIterator + + def __init__(self, axis: AxisIterator) -> None: + super().__init__() + self._axis = axis + def __str__(self) -> str: - return "InlineOffgridConditionals" + return f"InlineOffgridConditionals_{self._axis}" + + def visit_IfScope(self, node: tn.IfScope) -> None: + assert node.parent is not None # just to keep pyright happy + + # For now, skip in case there's an `elif` or `else` following. + if not last_node(node.parent.children, node): + next_node = get_next_node(node.parent.children, node) + if isinstance(next_node, (tn.ElifScope, tn.ElseScope)): + ndsl_log.debug( + "Can't handle conditionals with `elif` and `else` blocks yet :(" + ) + return + + if not all( + [ + isinstance(child, tn.MapScope) and is_axis_map(child, self._axis) + for child in node.children + ] + ): + return + + # If all children are maps over the correct axis, move the if inside. + new_nodes: list[tn.MapScope] = [] + + for child in node.children: + assert isinstance( + child, tn.MapScope + ) # otherwise the condition above is wrong + + if_scope = tn.IfScope( + condition=node.condition, children=child.children, parent=child + ) + + for map_child in child.children: + map_child.parent = if_scope # re-parent to new if_scope + + child.children = [if_scope] + child.parent = node.parent # re-parent to parent of old if_scope + new_nodes.append(child) + + insert_at = list_index(node.parent.children, node) + node.parent.children[insert_at:insert_at] = new_nodes + node.parent.children.remove(node) class ExtractOffgridConditionals(tn.ScheduleNodeTransformer): - """Push offgrid conditional outside of their cartesian block + """ + Push offgrid conditional outside of their cartesian block. - Reverse transform from InlineOffgridConditionals + This is the inverse transform of InlineOffgridConditionals. """ def __str__(self) -> str: @@ -39,24 +99,25 @@ def __str__(self) -> str: class MergeConditionals(tn.ScheduleNodeTransformer): - """Merge consecutive and equal conditionals + """ + Merge consecutive and equal conditionals. Turning: ``` - if a_flag == 0 - map i, j, k - [ops...] - if a_flag == 0 - map i, j, k - [ops...] + if a_flag == 0: + map i, j, k: + ... + if a_flag == 0: + map i, j, k: + ... ``` into ``` - if a_flag == 0 - map i, j, k - [ops...] - map i, j, k - [ops...] + if a_flag == 0: + map i, j, k: + ... + map i, j, k: + ... ``` Outside of user code, combination of ExtractOffgridConditionals, diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index ad52ca9c..13e30974 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -52,6 +52,7 @@ def run( f.write(stree.as_string()) tree_stats.optimized(stree) + if verbose: ndsl_log_on_rank_0.info(tree_stats.report()) diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py new file mode 100644 index 00000000..7e4ceb4f --- /dev/null +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -0,0 +1,158 @@ +from typing import TypeAlias + +import pytest +from dace import nodes + +from ndsl import ( + Backend, + NDSLRuntime, + QuantityFactory, + StencilFactory, + orchestrate, + stencils, +) +from ndsl.boilerplate import get_factories_single_tile +from ndsl.constants import I_DIM, J_DIM, K_DIM +from ndsl.dsl.typing import FloatField +from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge + + +class OrchestratedCode(NDSLRuntime): + def __init__(self, stencil_factory: StencilFactory) -> None: + super().__init__(stencil_factory) + + methods_to_orchestrate = [ + "happy_case", + "happy_case_2", + "blocked_by_else", + "blocked_by_other_nodes", + ] + + for method in methods_to_orchestrate: + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + method_to_orchestrate=method, + ) + + self._copy_stencil = stencil_factory.from_dims_halo( + func=stencils.copy, compute_dims=[I_DIM, J_DIM, K_DIM] + ) + + def happy_case(self, in_field: FloatField, out_field: FloatField) -> None: + if in_field[0, 0, 0] > 0: + self._copy_stencil(in_field, out_field) + self._copy_stencil(in_field, out_field) + + def happy_case_2(self, in_field: FloatField, out_field: FloatField) -> None: + if not in_field[0, 0, 0] > 0: + self._copy_stencil(in_field, out_field) + self._copy_stencil(in_field, out_field) + + def blocked_by_else(self, in_field: FloatField, out_field: FloatField) -> None: + self._copy_stencil(in_field, out_field) + + if in_field[0, 0, 0] > 0: + self._copy_stencil(in_field, out_field) + else: + self._copy_stencil(out_field, in_field) + + def blocked_by_other_nodes( + self, in_field: FloatField, out_field: FloatField + ) -> None: + if in_field[0, 0, 0] > 0: + in_field[:] = 42.0 + self._copy_stencil(in_field, out_field) + self._copy_stencil(in_field, out_field) + + +Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] + + +class TestStreeInlineOffgridConditionals: + @pytest.fixture(params=["orch:dace:cpu:IJK", "orch:dace:cpu:KJI"]) + def factories(self, request) -> Factories: + domain = (3, 3, 4) + return get_factories_single_tile( + domain[0], domain[1], domain[2], 0, backend=Backend(request.param) + ) + + def test_happy_case(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + + code = OrchestratedCode(stencil_factory) + in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + code.happy_case(in_quantity, out_quantity) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + assert precompiled_sdfg.sdfg + + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + assert len(all_maps) == 3 + + def test_happy_case_2(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + + code = OrchestratedCode(stencil_factory) + in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + code.happy_case_2(in_quantity, out_quantity) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + assert precompiled_sdfg.sdfg + + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + assert len(all_maps) == 3 + + def test_blocked_by_else(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + + code = OrchestratedCode(stencil_factory) + in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + code.blocked_by_else(in_quantity, out_quantity) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + assert precompiled_sdfg.sdfg + + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + assert len(all_maps) == 9 + + def test_blocked_by_other_nodes(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + + code = OrchestratedCode(stencil_factory) + in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + with StreeOptimization(): + code.blocked_by_other_nodes(in_quantity, out_quantity) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + assert precompiled_sdfg.sdfg + + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + assert len(all_maps) == 6 From 689ab89a4b046af4860f386cbb892d959d30a9eb Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 20 May 2026 15:51:21 +0200 Subject: [PATCH 013/101] fixup: temp fix for test of InlineOffgridConditionals --- .../stree/optimizations/test_offgrid_conditionals.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py index 7e4ceb4f..77ba3591 100644 --- a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -155,4 +155,13 @@ def test_blocked_by_other_nodes(self, factories: Factories) -> None: for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 6 + + # ⚠️ Dev note: + # This should be just `assert len(all_maps) == 6`, but currently, the K-loops + # can't merge because the K-iterators are different. To be fixed (and simplified + # here) with a subsequent commit. + assert ( + len(all_maps) == 6 + if stencil_factory.backend == Backend("orch:dace:cpu:IJK") + else 9 + ) From c263116d813bf5c7eaa95b08194e11faa46f08f9 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 20 May 2026 15:54:39 +0200 Subject: [PATCH 014/101] cleanup: remove old "push if down" codepath This has been replaced with `InlineOffgridConditionals` pass --- .../dace/stree/optimizations/axis_merge.py | 80 ------------------- 1 file changed, 80 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 3ad5b377..44f6577f 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -21,10 +21,6 @@ ) -# Buggy passes that should work -PUSH_IFSCOPE_DOWNWARD = False # Crashing the overall stree - bad algorithmics - - def _both_same_single_axis_maps( first: tn.MapScope, second: tn.MapScope, axis: AxisIterator ) -> bool: @@ -125,9 +121,6 @@ def _merge_node( if isinstance(node, tn.MapScope): return self._map_overcompute_merge(node, nodes) - if PUSH_IFSCOPE_DOWNWARD and isinstance(node, tn.IfScope): - return self._push_ifelse_down(node, nodes) - if isinstance(node, tn.ForScope): return self._for_merge(node) @@ -194,79 +187,6 @@ def _push_tasklet_down( return merged - def _push_ifelse_down( - self, the_if: tn.IfScope, nodes: list[tn.ScheduleTreeNode] - ) -> int: - merged = 0 - - # Recurse down if/else/elif - if_index = list_index(nodes, the_if) - if len(the_if.children) != 0: - merged += self._merge_node(the_if.children[0], the_if.children) - for else_index in range(if_index + 1, len(nodes)): - else_node = nodes[else_index] - if else_index < len(nodes) and ( - isinstance(else_node, tn.ElseScope) - or isinstance(else_node, tn.ElifScope) - ): - merged += self._merge_node(else_node, else_node.children) - else: - break - - # Look at swapping if/else/elif first map w/ control flow - - # Gather all first maps - if they do not exists, get out - all_maps = [] - if isinstance(the_if.children[0], tn.MapScope): - all_maps.append(the_if.children[0]) - else: - return merged - for else_index in range(if_index + 1, len(nodes)): - else_node = nodes[else_index] - if else_index < len(nodes) and ( - isinstance(else_node, tn.ElseScope) - or isinstance(else_node, tn.ElifScope) - ): - if isinstance(else_node.children[0], tn.MapScope): - all_maps.append(else_node.children[0]) - else: - return merged - - else: - break - - # Check for mergeability - if len(all_maps) > 1: - the_map = all_maps[0] - for _map in all_maps[1:]: - if not _can_merge_axis_maps(the_map, _map, self.axis): - return merged - - # We are good to go - swap it all - inner_if_map = the_if.children[0] - - # Swap IF & maps - if_index = list_index(nodes, the_if) - swap_node_position_in_tree(the_if, inner_if_map) - - # Swap ELIF/ELSE & maps - for else_index in range(if_index + 1, len(nodes)): - if else_index < len(nodes) and ( - isinstance(nodes[else_index], tn.ElseScope) - or isinstance(nodes[else_index], tn.ElifScope) - ): - swap_node_position_in_tree( - nodes[else_index], nodes[else_index].children[0] - ) - else: - break - - # Merge the Maps - assert isinstance(nodes[if_index], tn.MapScope) - merged += self._map_overcompute_merge(nodes[if_index], nodes) - - return merged - def _map_overcompute_merge( self, the_map: tn.MapScope, nodes: list[tn.ScheduleTreeNode] ) -> int: From 7d6ecc1756a4ca00417ee146430b05fd30e1191a Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 20 May 2026 12:27:44 -0400 Subject: [PATCH 015/101] Normalize cartesian index during data depedancy check --- .../dace/stree/optimizations/common/memlet.py | 26 ++++++++++++--- tests/dsl/dace/stree/common/__init__.py | 0 tests/dsl/dace/stree/common/test_memlet.py | 32 +++++++++++++++++++ 3 files changed, 53 insertions(+), 5 deletions(-) create mode 100644 tests/dsl/dace/stree/common/__init__.py create mode 100644 tests/dsl/dace/stree/common/test_memlet.py diff --git a/ndsl/dsl/dace/stree/optimizations/common/memlet.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py index b1540c98..97d99b68 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/memlet.py +++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py @@ -2,6 +2,7 @@ import dace.sdfg.analysis.schedule_tree.treenodes as stree from dace.memlet import Memlet +from dace.symbolic import symbol from ndsl.logging import ndsl_log @@ -21,6 +22,15 @@ def is_equal(self, other: str) -> bool: return other.startswith(self.as_str()) +def normalize_cartesian_indexation(index: symbol, axis: AxisIterator) -> symbol: + """Return a normalize indexation symbol for cartesian indexation.""" + rename_maps = {} + for symb in index.free_symbols: + if symb.name.startswith(axis.as_str()): + rename_maps[symb] = symbol(axis.as_str()) + return index.subs(rename_maps) + + def no_data_dependencies_on_cartesian_axis( first: stree.MapScope, second: stree.MapScope, @@ -36,20 +46,26 @@ def no_data_dependencies_on_cartesian_axis( for write in write_collector.out_memlets: # TODO: this can be optimized to allow non-overlapping intervals and such in the future - if write.subset.dims() <= axis.as_cartesian_index(): + axis_index = axis.as_cartesian_index() + + if write.subset.dims() <= axis_index: # Dimension does not exist continue - previous_axis_index = write.subset[axis.as_cartesian_index()][0] + previous_axis_index = normalize_cartesian_indexation( + write.subset[axis_index][0], axis + ) for read in read_collector.in_memlets: if write.data == read.data: - if previous_axis_index != read.subset[axis.as_cartesian_index()][0]: + if previous_axis_index != normalize_cartesian_indexation( + read.subset[axis_index][0], axis + ): ndsl_log.debug( f"[{axis.name} Merge] Found read after write conflict " f"for {write.data} " f"w/ different offset to {axis.name} (" - f"write at {write.subset[axis.as_cartesian_index()][0]}, " - f"read at {read.subset[axis.as_cartesian_index()][0]})" + f"write at {write.subset[axis_index][0]}, " + f"read at {read.subset[axis_index][0]})" ) return False return True diff --git a/tests/dsl/dace/stree/common/__init__.py b/tests/dsl/dace/stree/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dsl/dace/stree/common/test_memlet.py b/tests/dsl/dace/stree/common/test_memlet.py new file mode 100644 index 00000000..44fe15e1 --- /dev/null +++ b/tests/dsl/dace/stree/common/test_memlet.py @@ -0,0 +1,32 @@ +from dace.symbolic import symbol + +from ndsl.dsl.dace.stree.optimizations.common import AxisIterator +from ndsl.dsl.dace.stree.optimizations.common.memlet import ( + normalize_cartesian_indexation, +) + + +def test_normalize_cartesian_index(): + # Case of __k_id(node) - original case + original_symbol = symbol("__k_12345678789") + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + + assert norm_symbol == symbol("__k") + + # Case of offset + original_symbol = 1 + symbol("__k_12345678789") + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + + assert norm_symbol == symbol("__k") + 1 + + # Case of no-op (with offset) + original_symbol = 1 + symbol("__k") + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + + assert norm_symbol == symbol("__k") + 1 + + # Case of index named with _k - so not a cartesian axis + original_symbol = 1 + symbol("_kindex") + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + + assert norm_symbol == symbol("_kindex") + 1 From de03d3480038d990f2bcbee2a5d62ec72efa12e0 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 20 May 2026 16:18:25 -0400 Subject: [PATCH 016/101] Update tests --- tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py index 77ba3591..5699407e 100644 --- a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -161,7 +161,7 @@ def test_blocked_by_other_nodes(self, factories: Factories) -> None: # can't merge because the K-iterators are different. To be fixed (and simplified # here) with a subsequent commit. assert ( - len(all_maps) == 6 + len(all_maps) == 5 if stencil_factory.backend == Backend("orch:dace:cpu:IJK") else 9 ) From ff5722770caf6a919e09c969e3f085fa32263f59 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 20 May 2026 17:12:39 -0400 Subject: [PATCH 017/101] ReplaceAxisSymbolInTasklet -> ReplaceAxisSymbol + debug of it's usage --- .../dace/stree/optimizations/axis_merge.py | 9 +++---- .../dace/stree/optimizations/remove_loops.py | 11 +++----- .../replace_symbol_in_tasklet.py | 26 ++++++++++++++----- 3 files changed, 26 insertions(+), 20 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 44f6577f..3f10f122 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -17,7 +17,7 @@ swap_node_position_in_tree, ) from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( - ReplaceAxisSymbolInTasklet, + ReplaceAxisSymbol, ) @@ -211,7 +211,6 @@ def _map_overcompute_merge( # Over compute to merge: # - force-merge by expanding the ranges - # - then, guard children to only run in their respective range first_range = the_map.node.map.range second_range = next_node.node.map.range merged_range = dace.subsets.Range( @@ -224,7 +223,7 @@ def _map_overcompute_merge( ] ) - # push IfScope down if children are just maps + # - then, guard children to only run in their respective range axis_as_str = the_map.node.params[0] first_map = InsertOvercomputationGuard( axis_as_str, merged_range=merged_range, original_range=first_range @@ -256,9 +255,7 @@ def _map_overcompute_merge( first_map.node.map.params[0] ) } - ReplaceAxisSymbolInTasklet().visit( - first_map, axis_replacements=replacements - ) + ReplaceAxisSymbol(replacements).visit(first_map) # delete now-merged second_map del nodes[list_index(nodes, next_node)] diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index 43e9c15e..41c7aee9 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -7,7 +7,7 @@ from ndsl import ndsl_log from ndsl.dsl.dace.stree.optimizations.common import AxisIterator, reparent_scope_node from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( - ReplaceAxisSymbolInTasklet, + ReplaceAxisSymbol, ) @@ -57,12 +57,9 @@ def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeN if abs(bound_value - init_value) != 1: return the_for - ReplaceAxisSymbolInTasklet().visit( - the_for, - axis_replacements={ - dace.symbol(the_for.loop.loop_variable): str(init_value) - }, - ) + ReplaceAxisSymbol( + {dace.symbol(the_for.loop.loop_variable): str(init_value)} + ).visit(the_for) # Prepend children of the ForScope to parent # the_for.parent.children = [*the_for.children, *the_for.parent.children] diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py index 7dcb7bae..a64c6e67 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py @@ -1,22 +1,34 @@ import itertools from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.symbolic import symbol -class ReplaceAxisSymbolInTasklet(tn.ScheduleNodeVisitor): +class ReplaceAxisSymbol(tn.ScheduleNodeVisitor): + def __init__(self, axis_replacements: dict[str | symbol, str | symbol]) -> None: + self._axis_replacements = axis_replacements + def visit_TaskletNode( self, node: tn.TaskletNode, - axis_replacements: dict[str, str] | None = None, ) -> None: - if not axis_replacements: - # Noop if there are no replacements to do. - return - for memlet in itertools.chain( node.in_memlets.values(), node.out_memlets.values() ): - memlet.replace(axis_replacements) + memlet.replace(self._axis_replacements) + + def visit_IfScope( + self, + node: tn.IfScope, + ) -> None: + if self._axis_replacements: + for old, new in self._axis_replacements.items(): + node.condition.as_string = node.condition.as_string.replace( + str(old), str(new) + ) + + for child in node.children: + self.visit(child) def __str__(self) -> str: return "ReplaceAxisSymbolInTasklet" From 94b2e9928d5e89b5e8804cdae044783cecb5d9bd Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 21 May 2026 09:59:03 +0200 Subject: [PATCH 018/101] fix unit test by hardinging detection of "our" loops --- .../dace/stree/optimizations/common/loops.py | 13 +++- .../dace/stree/optimizations/remove_loops.py | 7 +- .../replace_symbol_in_tasklet.py | 21 ++---- tests/dsl/dace/stree/common/test_loops.py | 69 +++++++++++++++++++ .../test_offgrid_conditionals.py | 6 +- 5 files changed, 92 insertions(+), 24 deletions(-) create mode 100644 tests/dsl/dace/stree/common/test_loops.py diff --git a/ndsl/dsl/dace/stree/optimizations/common/loops.py b/ndsl/dsl/dace/stree/optimizations/common/loops.py index 83a91280..35e33b8c 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/loops.py +++ b/ndsl/dsl/dace/stree/optimizations/common/loops.py @@ -6,9 +6,18 @@ def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: """Returns true if node is a Map over the given axis.""" map_parameter = node.node.map.params - return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str()) + if len(map_parameter) != 1: + return False + + if axis == AxisIterator._K: + return map_parameter[0].startswith(axis.as_str()) + + return map_parameter[0] == axis.as_str() def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: """Returns true if node is a For over the given axis.""" - return node.loop.loop_variable.startswith(axis.as_str()) + if axis == AxisIterator._K: + return node.loop.loop_variable.startswith(axis.as_str()) + + return node.loop.loop_variable == axis.as_str() diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index 41c7aee9..54ac6d5d 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -30,7 +30,7 @@ class InlineVertical2DWrite(tn.ScheduleNodeTransformer): def __init__(self) -> None: super().__init__() - self._for_scope_removed = 0 + self._for_scopes_removed = 0 def __str__(self) -> str: return "InlineVertical2DWrite" @@ -67,7 +67,7 @@ def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeN # Remove ForScope the_for.parent.children.remove(the_for) - self._for_scope_removed += 1 + self._for_scopes_removed += 1 assert len(the_for.children) > 0 return the_for.parent.children[0] @@ -76,10 +76,11 @@ def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeN def visit_ScheduleTreeRoot( self, the_root: tn.ScheduleTreeRoot ) -> tn.ScheduleTreeRoot: + self._for_scopes_removed = 0 for child in the_root.children: self.visit(child) - ndsl_log.debug(f"🚀 {self}: {self._for_scope_removed} inlined") + ndsl_log.debug(f"🚀 {self}: {self._for_scopes_removed} inlined") return the_root diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py index a64c6e67..dbc26eb0 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py @@ -8,27 +8,20 @@ class ReplaceAxisSymbol(tn.ScheduleNodeVisitor): def __init__(self, axis_replacements: dict[str | symbol, str | symbol]) -> None: self._axis_replacements = axis_replacements - def visit_TaskletNode( - self, - node: tn.TaskletNode, - ) -> None: + def visit_TaskletNode(self, node: tn.TaskletNode) -> None: for memlet in itertools.chain( node.in_memlets.values(), node.out_memlets.values() ): memlet.replace(self._axis_replacements) - def visit_IfScope( - self, - node: tn.IfScope, - ) -> None: - if self._axis_replacements: - for old, new in self._axis_replacements.items(): - node.condition.as_string = node.condition.as_string.replace( - str(old), str(new) - ) + def visit_IfScope(self, node: tn.IfScope) -> None: + for old, new in self._axis_replacements.items(): + node.condition.as_string = node.condition.as_string.replace( + str(old), str(new) + ) for child in node.children: self.visit(child) def __str__(self) -> str: - return "ReplaceAxisSymbolInTasklet" + return "ReplaceAxisSymbol" diff --git a/tests/dsl/dace/stree/common/test_loops.py b/tests/dsl/dace/stree/common/test_loops.py new file mode 100644 index 00000000..3020f551 --- /dev/null +++ b/tests/dsl/dace/stree/common/test_loops.py @@ -0,0 +1,69 @@ +from dace.sdfg import nodes +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.state import LoopRegion + +from ndsl.dsl.dace.stree.optimizations.common import ( + AxisIterator, + is_axis_for, + is_axis_map, +) + + +def test_is_axis_map_multiple_params() -> None: + node = tn.MapScope( + node=nodes.MapEntry( + nodes.Map("map_ij", ["__i", "__j"], [(0, 3, 1), (0, 4, 1)]) + ), + children=[], + ) + assert not is_axis_map(node, AxisIterator._I) + assert not is_axis_map(node, AxisIterator._J) + + +def test_is_axis_map_I() -> None: + node = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", ["__i"], [(0, 3, 1)])), children=[] + ) + assert is_axis_map(node, AxisIterator._I) + + +def test_is_axis_map_not_I() -> None: + node = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_other_i", ["__i0"], [(0, 3, 1)])), + children=[], + ) + assert not is_axis_map(node, AxisIterator._I) + + +def test_is_axis_map_K() -> None: + node = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_k", ["__k_1234"], [(0, 3, 1)])), children=[] + ) + assert is_axis_map(node, AxisIterator._K) + + +def test_is_axis_map_wrong_iterator() -> None: + node = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", ["__i"], [(0, 3, 1)])), children=[] + ) + assert not is_axis_map(node, AxisIterator._J) + + +def test_is_axis_for_k() -> None: + node = tn.ForScope(loop=LoopRegion("for_k", loop_var="__k_1234"), children=[]) + assert is_axis_for(node, AxisIterator._K) + + +def test_is_axis_for_wrong_iterator() -> None: + node = tn.ForScope(loop=LoopRegion("for_k", loop_var="__k_1234"), children=[]) + assert not is_axis_for(node, AxisIterator._I) + + +def test_is_axis_for_i() -> None: + node = tn.ForScope(loop=LoopRegion("for_i", loop_var="__i"), children=[]) + assert is_axis_for(node, AxisIterator._I) + + +def test_is_axis_for_not_i() -> None: + node = tn.ForScope(loop=LoopRegion("for_i", loop_var="__i0"), children=[]) + assert not is_axis_for(node, AxisIterator._I) diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py index 5699407e..f897173f 100644 --- a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -160,8 +160,4 @@ def test_blocked_by_other_nodes(self, factories: Factories) -> None: # This should be just `assert len(all_maps) == 6`, but currently, the K-loops # can't merge because the K-iterators are different. To be fixed (and simplified # here) with a subsequent commit. - assert ( - len(all_maps) == 5 - if stencil_factory.backend == Backend("orch:dace:cpu:IJK") - else 9 - ) + assert len(all_maps) == 9 From 3a5057719d69780e20671cc380ea72324056419e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 21 May 2026 11:56:27 +0200 Subject: [PATCH 019/101] unrelated cleanup: fix/assert type issues --- ndsl/dsl/dace/stree/optimizations/common/topology.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ndsl/dsl/dace/stree/optimizations/common/topology.py b/ndsl/dsl/dace/stree/optimizations/common/topology.py index e81df22a..fa06f3db 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/topology.py +++ b/ndsl/dsl/dace/stree/optimizations/common/topology.py @@ -5,7 +5,7 @@ def reparent_scope_node( original_parent: tn.ScheduleTreeScope, - new_parent: tn.ScheduleTreeNode, + new_parent: tn.ScheduleTreeScope, *, prepend: bool = True, ) -> None: @@ -26,6 +26,7 @@ def swap_node_position_in_tree( """Top node becomes child, child becomes top node.""" # Ensue parent/children relationship is valid tn.validate_children_and_parents_align(top_node) + assert top_node.parent is not None # Take refs before swap top_children = top_node.parent.children From d6824f3c3f63c722cab4e876d25b7fff7912809b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 21 May 2026 15:36:09 +0200 Subject: [PATCH 020/101] Changes to `InlineVertical2DWrite` --- .../dace/stree/optimizations/common/loops.py | 10 +- .../dace/stree/optimizations/common/memlet.py | 5 +- .../dace/stree/optimizations/remove_loops.py | 93 +++++++------ .../stree/optimizations/test_remove_loops.py | 129 +++++++++++++++--- 4 files changed, 166 insertions(+), 71 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/common/loops.py b/ndsl/dsl/dace/stree/optimizations/common/loops.py index 35e33b8c..5d414915 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/loops.py +++ b/ndsl/dsl/dace/stree/optimizations/common/loops.py @@ -9,15 +9,9 @@ def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: if len(map_parameter) != 1: return False - if axis == AxisIterator._K: - return map_parameter[0].startswith(axis.as_str()) - - return map_parameter[0] == axis.as_str() + return axis.is_equal(map_parameter[0]) def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: """Returns true if node is a For over the given axis.""" - if axis == AxisIterator._K: - return node.loop.loop_variable.startswith(axis.as_str()) - - return node.loop.loop_variable == axis.as_str() + return axis.is_equal(node.loop.loop_variable) diff --git a/ndsl/dsl/dace/stree/optimizations/common/memlet.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py index 97d99b68..61f64225 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/memlet.py +++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py @@ -19,7 +19,10 @@ def as_cartesian_index(self) -> int: return self.value[1] def is_equal(self, other: str) -> bool: - return other.startswith(self.as_str()) + if self == AxisIterator._K: + return other.startswith(self.as_str()) + + return other == self.as_str() def normalize_cartesian_indexation(index: symbol, axis: AxisIterator) -> symbol: diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index 54ac6d5d..c02f8af5 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -5,13 +5,17 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl import ndsl_log -from ndsl.dsl.dace.stree.optimizations.common import AxisIterator, reparent_scope_node +from ndsl.dsl.dace.stree.optimizations.common import ( + AxisIterator, + is_axis_for, + list_index, +) from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( ReplaceAxisSymbol, ) -class InlineVertical2DWrite(tn.ScheduleNodeTransformer): +class InlineVertical2DWrite(tn.ScheduleNodeVisitor): """Inline K index value for 2D write vertical while removing for loop. Transforming: @@ -35,52 +39,51 @@ def __init__(self) -> None: def __str__(self) -> str: return "InlineVertical2DWrite" - def visit_ForScope(self, the_for: tn.ForScope) -> tn.ForScope | tn.ScheduleTreeNode: - if AxisIterator._K.is_equal(the_for.loop.loop_variable) and the_for.parent: - # Retrieve init/bound value by executing the code and replace usage of it - # If the code cannot be executed (no-literal variable part of the op, etc.) - # we will _not_ inline - try: - exec_locals: dict[str, Any] = {} - exec_globals: dict[str, Any] = {} - exec( - ast.unparse(the_for.loop.init_statement.code[0]), - exec_globals, - exec_locals, - ) - init_value = exec_locals[the_for.loop.loop_variable] - bound_value = eval( - ast.unparse(the_for.loop.loop_condition.code[0].value.comparators) - ) - except Exception as _: - return the_for - if abs(bound_value - init_value) != 1: - return the_for - - ReplaceAxisSymbol( - {dace.symbol(the_for.loop.loop_variable): str(init_value)} - ).visit(the_for) - - # Prepend children of the ForScope to parent - # the_for.parent.children = [*the_for.children, *the_for.parent.children] - reparent_scope_node(the_for, the_for.parent) - - # Remove ForScope - the_for.parent.children.remove(the_for) - self._for_scopes_removed += 1 - assert len(the_for.children) > 0 - return the_for.parent.children[0] - - return the_for - - def visit_ScheduleTreeRoot( - self, the_root: tn.ScheduleTreeRoot - ) -> tn.ScheduleTreeRoot: + def visit_ForScope(self, the_for: tn.ForScope) -> None: + if not is_axis_for(the_for, AxisIterator._K): + return + + assert the_for.parent is not None # just to keep pyright happy + + # Retrieve init/bound value by executing the code and replace usage of it + # If the code cannot be executed (no-literal variable part of the op, etc.) + # we will _not_ inline + try: + exec_locals: dict[str, Any] = {} + exec_globals: dict[str, Any] = {} + exec( + ast.unparse(the_for.loop.init_statement.code[0]), + exec_globals, + exec_locals, + ) + init_value = exec_locals[the_for.loop.loop_variable] + bound_value = eval( + ast.unparse(the_for.loop.loop_condition.code[0].value.comparators) + ) + except Exception as _: + return + if abs(bound_value - init_value) != 1: + return + + ReplaceAxisSymbol( + {dace.symbol(the_for.loop.loop_variable): str(init_value)} + ).visit(the_for) + + # Insert children of the ForScope to parent + insert_at = list_index(the_for.parent.children, the_for) + for child in the_for.children: + child.parent = the_for.parent + the_for.parent.children[insert_at:insert_at] = the_for.children + + # Remove ForScope + the_for.parent.children.remove(the_for) + self._for_scopes_removed += 1 + assert len(the_for.children) > 0 + + def visit_ScheduleTreeRoot(self, the_root: tn.ScheduleTreeRoot) -> None: self._for_scopes_removed = 0 for child in the_root.children: self.visit(child) ndsl_log.debug(f"🚀 {self}: {self._for_scopes_removed} inlined") - - return the_root diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py index 06cbe9fe..4af0474a 100644 --- a/tests/dsl/dace/stree/optimizations/test_remove_loops.py +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -6,10 +6,11 @@ from ndsl import QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile -from ndsl.config import Backend +from ndsl.config import Backend, BackendLoopOrder from ndsl.constants import I_DIM, J_DIM, K_DIM, Float from ndsl.dsl.gt4py import FORWARD, computation, interval from ndsl.dsl.typing import FloatField, FloatFieldIJ +from ndsl.stencils import copy from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge @@ -18,6 +19,14 @@ def stencil_simple_2D_write(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> out_fieldIJ = in_field +def stencil_multiple_2D_write( + in_field: FloatField, out_fieldIJ: FloatFieldIJ, out_fieldIJ_2: FloatFieldIJ +) -> None: + with computation(FORWARD), interval(0, 1): + out_fieldIJ = in_field + out_fieldIJ_2 = in_field + 1.0 + + def stencil_2D_write_at_K(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> None: with computation(FORWARD), interval(-1, None): out_fieldIJ = in_field @@ -29,13 +38,15 @@ def stencil_forward_at_K(in_field: FloatField, out_field: FloatField) -> None: class OrchestratedCode: - def __init__( - self, - stencil_factory: StencilFactory, - quantity_factory: QuantityFactory, - ) -> None: - orchestratable_methods = ["write_at_0", "write_at_top", "do_not_inline"] - for method in orchestratable_methods: + def __init__(self, stencil_factory: StencilFactory) -> None: + methods_to_orchestrate = [ + "write_at_0", + "write_at_top", + "do_not_inline", + "combined_stencils", + "multiple_statements", + ] + for method in methods_to_orchestrate: orchestrate( obj=self, config=stencil_factory.config.dace_config, @@ -54,6 +65,14 @@ def __init__( func=stencil_forward_at_K, compute_dims=[I_DIM, J_DIM, K_DIM], ) + self.stencil_copy = stencil_factory.from_dims_halo( + func=copy, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) + self.stencil_multiple_2D_write = stencil_factory.from_dims_halo( + func=stencil_multiple_2D_write, + compute_dims=[I_DIM, J_DIM, K_DIM], + ) def write_at_0( self, @@ -76,6 +95,18 @@ def do_not_inline( ) -> None: self.stencil_do_not_inline(in_field, out_field) + def combined_stencils( + self, field: FloatField, field2: FloatField, fieldIJ: FloatFieldIJ + ) -> None: + self.stencil_copy(field, field2) + self.stencil_simple_2D_write(field2, fieldIJ) + + def multiple_statements( + self, in_field: FloatField, out_field: FloatFieldIJ, out_field2: FloatFieldIJ + ) -> None: + self.stencil_copy(in_field, in_field) + self.stencil_multiple_2D_write(in_field, out_field, out_field2) + Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] @@ -89,14 +120,10 @@ def factories(self, request) -> Factories: domain[0], domain[1], domain[2], 0, backend=Backend(request.param) ) - @pytest.fixture - def code(self, factories: Factories) -> OrchestratedCode: - return OrchestratedCode(*factories) - - def test_common_2D_write( - self, code: OrchestratedCode, factories: Factories - ) -> None: + def test_common_2D_write(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") in_qty.field[:, :, 0] = Float(32.0) @@ -120,8 +147,10 @@ def test_common_2D_write( assert len(all_loop_region) == 0 assert (out_qty.field[:] == Float(32.0)).all() - def test_2D_write_K_top(self, code: OrchestratedCode, factories: Factories) -> None: + def test_2D_write_K_top(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") in_qty.field[:, :, -1] = Float(32.0) @@ -145,8 +174,10 @@ def test_2D_write_K_top(self, code: OrchestratedCode, factories: Factories) -> N assert len(all_loop_region) == 0 assert (out_qty.field[:] == Float(32.0)).all() - def test_do_not_inline(self, code: OrchestratedCode, factories: Factories) -> None: + def test_do_not_inline(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") @@ -168,3 +199,67 @@ def test_do_not_inline(self, code: OrchestratedCode, factories: Factories) -> No assert len(all_maps) == 2 assert len(all_loop_region) == 1 assert (out_qty.field[:] == Float(1)).all() + + def test_combined_stencils(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + field = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + field_2 = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "") + + with StreeOptimization(): + code.combined_stencils(field, field_2, field_IJ) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert ( + len(all_maps) == 3 + if stencil_factory.backend.loop_order == BackendLoopOrder.IJK + else 5 + ) + assert len(all_loop_region) == 0 + assert (field_IJ.field[:] == Float(1)).all() + + def test_multiple_statements(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + field = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "") + field_IJ_2 = quantity_factory.zeros([I_DIM, J_DIM], "") + + field.field[:, :, 0] = Float(42.0) + with StreeOptimization(): + code.multiple_statements(field, field_IJ, field_IJ_2) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + all_maps = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, nodes.MapEntry) + ] + all_loop_region = [ + (me, state) + for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(me, LoopRegion) + ] + + assert ( + len(all_maps) == 3 + if stencil_factory.backend.loop_order == BackendLoopOrder.IJK + else 5 + ) + assert len(all_loop_region) == 0 + assert (field_IJ.field[:] == Float(42.0)).all() + assert (field_IJ_2.field[:] == Float(43.0)).all() From 454fb44867fa21a95018344be8e0136ef4c5419e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 22 May 2026 14:13:11 +0200 Subject: [PATCH 021/101] dace update: connect source/sink nodes with empty memlets --- external/dace | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/dace b/external/dace index ec81b1a0..44657753 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit ec81b1a0c2a872da8dd315378ff6a9ac67d5458b +Subproject commit 44657753cef3c0ce3ef9deef9d0c81e0e7314b1e From 9ba2664b62c1cdfc23790cd175bb35260ec54dc9 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 22 May 2026 16:55:01 +0200 Subject: [PATCH 022/101] dace update: support for self-assigning copy nodes --- external/dace | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/dace b/external/dace index 44657753..99a9360d 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit 44657753cef3c0ce3ef9deef9d0c81e0e7314b1e +Subproject commit 99a9360d35f458c328b204860d59a365522484ab From f8798a08a469de196643f06992859fc464b270d3 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 22 May 2026 14:20:40 -0400 Subject: [PATCH 023/101] GPU tree orchestration pipeline - Local are no longer transient on GPU - RefineTransients is deactivated --- ndsl/dsl/dace/orchestration.py | 35 +++++++++++++++++++++++++-------- ndsl/dsl/dace/stree/pipeline.py | 21 +++++++++++++++++--- ndsl/quantity/local.py | 3 ++- 3 files changed, 47 insertions(+), 12 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index e31298d7..5b1de8e8 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -6,19 +6,20 @@ from pathlib import Path from typing import Any -from dace import SDFG, CompiledSDFG +from dace import SDFG, CompiledSDFG, DeviceType from dace import compiletime as DaceCompiletime from dace import dtypes from dace import method as dace_method from dace import nodes from dace import program as dace_program from dace.dtypes import DeviceType as DaceDeviceType +from dace.dtypes import ScheduleType from dace.dtypes import StorageType as DaceStorageType from dace.frontend.python.common import SDFGConvertible from dace.frontend.python.parser import DaceProgram from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.transformation.auto.auto_optimize import make_transients_persistent -from dace.transformation.dataflow import MapExpansion +from dace.transformation.dataflow import MapCollapse, MapExpansion from dace.transformation.helpers import get_parent_map from gt4py import storage as gt_storage @@ -37,7 +38,7 @@ negative_qtracers_checker, sdfg_nan_checker, ) -from ndsl.dsl.dace.stree import CPUPipeline +from ndsl.dsl.dace.stree import CPUPipeline, GPUPipeline from ndsl.dsl.dace.utils import ( DaCeProgress, memory_static_analysis, @@ -181,7 +182,18 @@ def _build_sdfg( # Here be 🐉 - but tests exists in test_optimization.py with DaCeProgress(config, "Schedule Tree: generate from SDFG"): # Break all loops into uni-dimensional loops to simplify optimizations - sdfg.apply_transformations_repeated(MapExpansion, validate=True) + sdfg.apply_transformations_repeated( + MapExpansion, + options={ + "inner_schedule": ( + ScheduleType.GPU_Device + if device_type is DeviceType.GPU + else ScheduleType.Default + ) + }, + validate=True, + print_report=True, + ) stree = sdfg.as_schedule_tree() if config.verbose_orchestration: with open( @@ -191,10 +203,16 @@ def _build_sdfg( f.write(stree.as_string()) with DaCeProgress(config, "Schedule Tree: optimization"): - CPUPipeline( - backend=backend_name, - cache_directory=Path(sdfg.build_folder), - ).run(stree, verbose=config.verbose_schedule_tree_optimizations) + if device_type == device_type.CPU: + CPUPipeline( + backend=backend_name, + cache_directory=Path(sdfg.build_folder), + ).run(stree, verbose=config.verbose_schedule_tree_optimizations) + elif device_type == DeviceType.GPU: + GPUPipeline( + backend=backend_name, + cache_directory=Path(sdfg.build_folder), + ).run(stree, verbose=config.verbose_schedule_tree_optimizations) if config.verbose_orchestration: with open( os.path.abspath(f"{sdfg.build_folder}/03-post_opt.stree.txt"), @@ -209,6 +227,7 @@ def _build_sdfg( os.path.abspath(f"{sdfg.build_folder}/04-from_stree.sdfgz"), compress=True, ) + sdfg.apply_transformations_repeated(MapCollapse) # Make the transients array persistents if config.is_gpu_backend(): diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 13e30974..9833b2eb 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -17,7 +17,7 @@ class StreePipeline: def __init__( self, *, - passes: list[stree.ScheduleNodeTransformer], + passes: list[stree.ScheduleNodeVisitor], cache_directory: Path | None = None, ) -> None: if cache_directory is None: @@ -64,7 +64,7 @@ def __init__( self, backend: Backend, *, - passes: list[stree.ScheduleNodeTransformer] | None = None, + passes: list[stree.ScheduleNodeVisitor] | None = None, cache_directory: Path | None = None, ) -> None: if passes is None: @@ -83,9 +83,24 @@ def __init__( class GPUPipeline(StreePipeline): def __init__( self, - passes: list[stree.ScheduleNodeTransformer] | None = None, + backend: Backend, + *, + passes: list[stree.ScheduleNodeVisitor] | None = None, cache_directory: Path | None = None, ) -> None: + if passes is None: + passes = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(backend), + # 🐞 Transient refine can't be used + # because of bugs transients showing in code generation + # CartesianRefineTransients(backend), + ] + super().__init__( + passes=passes, + cache_directory=cache_directory, + ) super().__init__( passes=passes if passes is not None else [], cache_directory=cache_directory, diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py index f69480a7..37aee3eb 100644 --- a/ndsl/quantity/local.py +++ b/ndsl/quantity/local.py @@ -31,6 +31,7 @@ def __init__( # Initialize memory to obviously wrong value - Local should _not_ be expected # to be zero'ed. data[:] = 123456789 + self._on_gpu = backend.is_gpu_backend() super().__init__( data, @@ -45,5 +46,5 @@ def __init__( def __descriptor__(self) -> Any: """Locals uses `Quantity.__descriptor__` and flag itself as transient.""" data = dace.data.create_datadescriptor(self._data) - data.transient = True + data.transient = True if not self._on_gpu else False return data From 89294d2b5e520771a7f8af1d708ac36c181208f7 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 22 May 2026 14:23:32 -0400 Subject: [PATCH 024/101] Add scalarized array to tree statistics --- .../dsl/dace/stree/optimizations/statistics.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/statistics.py b/ndsl/dsl/dace/stree/optimizations/statistics.py index ebef36fe..6e5fe3af 100644 --- a/ndsl/dsl/dace/stree/optimizations/statistics.py +++ b/ndsl/dsl/dace/stree/optimizations/statistics.py @@ -34,20 +34,22 @@ def visit_ForScope(self, node: stree.ForScope) -> None: class CountTransient(stree.ScheduleNodeVisitor): def __init__(self) -> None: super().__init__() - self._counts = [0, 0, 0, 0] + self._counts = [0, 0, 0, 0, 0] def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: for data in node.containers.values(): non_atomic_dims_count = sum(1 for x in data.shape if x != 1) if isinstance(data, dace.data.Array) and data.transient: - if non_atomic_dims_count == 1: + if non_atomic_dims_count == 0: self._counts[0] += 1 - elif non_atomic_dims_count == 2: + elif non_atomic_dims_count == 1: self._counts[1] += 1 - elif non_atomic_dims_count == 3: + elif non_atomic_dims_count == 2: self._counts[2] += 1 - else: + elif non_atomic_dims_count == 3: self._counts[3] += 1 + else: + self._counts[4] += 1 class TreeOptimizationStatistics: @@ -59,7 +61,9 @@ class Record: cartesian_maps: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0]) cartesian_fors: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0]) - transients: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0, 0]) + transients: list[int] = dataclasses.field( + default_factory=lambda: [0, 0, 0, 0, 0] + ) def __init__(self) -> None: self._original_record = TreeOptimizationStatistics.Record() @@ -93,5 +97,5 @@ def report(self) -> str: msg = "Tree optimization:\n" msg += f" Cartesian maps [I, J, K]: {self._original_record.cartesian_maps} -> {self._optimized_record.cartesian_maps}\n" msg += f" Cartesian fors [I, J, K]: {self._original_record.cartesian_fors} -> {self._optimized_record.cartesian_fors}\n" - msg += f" Transients [1D, 2D, 3D, 4D+]: {self._original_record.transients} -> {self._optimized_record.transients}\n" + msg += f" Transients [Scalarized Array, 1D, 2D, 3D, 4D+]: {self._original_record.transients} -> {self._optimized_record.transients}\n" return msg From 456b5fbda02c22ea5d8735ee00ee630dbd460cb5 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 22 May 2026 17:49:35 -0400 Subject: [PATCH 025/101] Replace `AxisSymbol` in "masklet as well + rename file --- ndsl/dsl/dace/stree/optimizations/axis_merge.py | 2 +- ndsl/dsl/dace/stree/optimizations/remove_loops.py | 2 +- ...replace_symbol_in_tasklet.py => replace_axis_symbol.py} | 7 +++++++ 3 files changed, 9 insertions(+), 2 deletions(-) rename ndsl/dsl/dace/stree/optimizations/{replace_symbol_in_tasklet.py => replace_axis_symbol.py} (77%) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 3f10f122..e0867ede 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -16,7 +16,7 @@ no_data_dependencies_on_cartesian_axis, swap_node_position_in_tree, ) -from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( +from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ( ReplaceAxisSymbol, ) diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index c02f8af5..76b6dc54 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -10,7 +10,7 @@ is_axis_for, list_index, ) -from ndsl.dsl.dace.stree.optimizations.replace_symbol_in_tasklet import ( +from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ( ReplaceAxisSymbol, ) diff --git a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py b/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py similarity index 77% rename from ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py rename to ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py index dbc26eb0..b7f3d548 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_symbol_in_tasklet.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py @@ -14,6 +14,13 @@ def visit_TaskletNode(self, node: tn.TaskletNode) -> None: ): memlet.replace(self._axis_replacements) + if node.node.label.startswith("masklet"): + for old, new in self._axis_replacements.items(): + node.node.code.as_string = node.node.code.as_string.replace( + str(old), str(new) + ) + + def visit_IfScope(self, node: tn.IfScope) -> None: for old, new in self._axis_replacements.items(): node.condition.as_string = node.condition.as_string.replace( From 0aaa78d69070dbe601f1fd79e531671f8674b818 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 22 May 2026 17:49:53 -0400 Subject: [PATCH 026/101] Deactivate `InlineVertical2DWrite` for now --- ndsl/dsl/dace/stree/pipeline.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 9833b2eb..ed3d4d77 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -70,7 +70,8 @@ def __init__( if passes is None: passes = [ CleanUpScheduleTree(), - InlineVertical2DWrite(), + # TODO: Is it safe? Deactivate for now + # InlineVertical2DWrite(), CartesianMerge(backend), CartesianRefineTransients(backend), ] @@ -91,7 +92,8 @@ def __init__( if passes is None: passes = [ CleanUpScheduleTree(), - InlineVertical2DWrite(), + # TODO: Is it safe? Deactivate for now + # InlineVertical2DWrite(), CartesianMerge(backend), # 🐞 Transient refine can't be used # because of bugs transients showing in code generation From 634a097197072d6ba4e654337b733be4abda92ae Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 22 May 2026 17:50:57 -0400 Subject: [PATCH 027/101] Lint --- ndsl/dsl/dace/stree/optimizations/axis_merge.py | 4 +--- ndsl/dsl/dace/stree/optimizations/remove_loops.py | 4 +--- ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py | 1 - ndsl/dsl/dace/stree/pipeline.py | 1 - 4 files changed, 2 insertions(+), 8 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index e0867ede..d875082e 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -16,9 +16,7 @@ no_data_dependencies_on_cartesian_axis, swap_node_position_in_tree, ) -from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ( - ReplaceAxisSymbol, -) +from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ReplaceAxisSymbol def _both_same_single_axis_maps( diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py index 76b6dc54..89716404 100644 --- a/ndsl/dsl/dace/stree/optimizations/remove_loops.py +++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py @@ -10,9 +10,7 @@ is_axis_for, list_index, ) -from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ( - ReplaceAxisSymbol, -) +from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ReplaceAxisSymbol class InlineVertical2DWrite(tn.ScheduleNodeVisitor): diff --git a/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py b/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py index b7f3d548..c04c2fc5 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py @@ -19,7 +19,6 @@ def visit_TaskletNode(self, node: tn.TaskletNode) -> None: node.node.code.as_string = node.node.code.as_string.replace( str(old), str(new) ) - def visit_IfScope(self, node: tn.IfScope) -> None: for old, new in self._axis_replacements.items(): diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index ed3d4d77..cb3a2ec5 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -7,7 +7,6 @@ CartesianMerge, CartesianRefineTransients, CleanUpScheduleTree, - InlineVertical2DWrite, ) from ndsl.dsl.dace.stree.optimizations.statistics import TreeOptimizationStatistics from ndsl.logging import ndsl_log_on_rank_0 From 02102af6ac7c3fa421ebf5cfc81a5902c0d11ee2 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 26 May 2026 10:23:40 +0200 Subject: [PATCH 028/101] Fix tests after collapsing maps / fix non-cartesian loop inline Also adds infrastructure to override the orchestration pipeline in tests (used to allow testing `InlineVertical2Dwrite`). --- ndsl/dsl/dace/orchestration.py | 39 ++++++++---- .../dace/stree/optimizations/axis_merge.py | 2 +- ndsl/dsl/dace/stree/pipeline.py | 12 ++-- .../dace/stree/optimizations/test_merge.py | 20 +++---- .../test_offgrid_conditionals.py | 14 ++--- .../stree/optimizations/test_remove_loops.py | 60 +++++++++++++++---- tests/dsl/dace/stree/sdfg_stree_tools.py | 8 +++ 7 files changed, 107 insertions(+), 48 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 5b1de8e8..c88212c2 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -24,6 +24,7 @@ from gt4py import storage as gt_storage import ndsl.dsl.dace.replacements # noqa # We load in the DaCe replacements +from ndsl import Backend from ndsl.comm.mpi import MPI from ndsl.dsl.dace.build import get_sdfg_path, write_build_info from ndsl.dsl.dace.dace_config import ( @@ -39,6 +40,7 @@ sdfg_nan_checker, ) from ndsl.dsl.dace.stree import CPUPipeline, GPUPipeline +from ndsl.dsl.dace.stree.pipeline import StreePipeline from ndsl.dsl.dace.utils import ( DaCeProgress, memory_static_analysis, @@ -54,6 +56,8 @@ ) """INTERNAL: Developer flag to turn the untested schedule tree roundtrip optimizer.""" +_INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES: list[tn.ScheduleNodeVisitor] | None = None + def dace_inhibitor(func: Callable) -> Callable: """Triggers callback generation wrapping `func` while doing DaCe parsing.""" @@ -143,6 +147,24 @@ def _tree_as_sdfg(stree: tn.ScheduleTreeRoot) -> SDFG: return stree.as_sdfg(skip={"ScalarToSymbolPromotion", "ControlFlowRaising"}) +def _optimization_pipeline( + device_type: DeviceType, + backend: Backend, + *, + passes: list[tn.ScheduleNodeVisitor] | None = None, + cache_directory: Path | None = None, +) -> StreePipeline: + if device_type == device_type.CPU: + return CPUPipeline(backend, passes=passes, cache_directory=cache_directory) + + if device_type == DeviceType.GPU: + return GPUPipeline(backend, passes=passes, cache_directory=cache_directory) + + raise ValueError( + f"Unknown device type `{device_type}`, expected {DeviceType.CPU} or {DeviceType.GPU}." + ) + + def _build_sdfg( dace_program: DaceProgram, sdfg: SDFG, config: DaceConfig, args: Any, kwargs: Any ) -> None: @@ -203,16 +225,13 @@ def _build_sdfg( f.write(stree.as_string()) with DaCeProgress(config, "Schedule Tree: optimization"): - if device_type == device_type.CPU: - CPUPipeline( - backend=backend_name, - cache_directory=Path(sdfg.build_folder), - ).run(stree, verbose=config.verbose_schedule_tree_optimizations) - elif device_type == DeviceType.GPU: - GPUPipeline( - backend=backend_name, - cache_directory=Path(sdfg.build_folder), - ).run(stree, verbose=config.verbose_schedule_tree_optimizations) + pipeline = _optimization_pipeline( + device_type, + backend_name, + cache_directory=Path(sdfg.build_folder), + passes=_INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES, + ) + pipeline.run(stree, verbose=config.verbose_schedule_tree_optimizations) if config.verbose_orchestration: with open( os.path.abspath(f"{sdfg.build_folder}/03-post_opt.stree.txt"), diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index d875082e..0c5a476f 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -134,7 +134,7 @@ def _merge_node( def _for_merge(self, the_for_scope: tn.ForScope) -> int: merged = 0 - if is_axis_for(the_for_scope, self.axis): + if is_axis_for(the_for_scope, AxisIterator._K): # TODO: if the for scope is on a cartesian axis it can be # merged with other for scope going in the same direction pass diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index cb3a2ec5..f44a399e 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -1,6 +1,6 @@ from pathlib import Path -import dace.sdfg.analysis.schedule_tree.treenodes as stree +from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl import Backend from ndsl.dsl.dace.stree.optimizations import ( @@ -16,7 +16,7 @@ class StreePipeline: def __init__( self, *, - passes: list[stree.ScheduleNodeVisitor], + passes: list[tn.ScheduleNodeVisitor], cache_directory: Path | None = None, ) -> None: if cache_directory is None: @@ -33,9 +33,9 @@ def __repr__(self) -> str: def run( self, - stree: stree.ScheduleTreeRoot, + stree: tn.ScheduleTreeRoot, verbose: bool = False, - ) -> stree.ScheduleTreeRoot: + ) -> tn.ScheduleTreeRoot: tree_stats = TreeOptimizationStatistics() tree_stats.original(stree) @@ -63,7 +63,7 @@ def __init__( self, backend: Backend, *, - passes: list[stree.ScheduleNodeVisitor] | None = None, + passes: list[tn.ScheduleNodeVisitor] | None = None, cache_directory: Path | None = None, ) -> None: if passes is None: @@ -85,7 +85,7 @@ def __init__( self, backend: Backend, *, - passes: list[stree.ScheduleNodeVisitor] | None = None, + passes: list[tn.ScheduleNodeVisitor] | None = None, cache_directory: Path | None = None, ) -> None: if passes is None: diff --git a/tests/dsl/dace/stree/optimizations/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py index d57e758a..2e76029c 100644 --- a/tests/dsl/dace/stree/optimizations/test_merge.py +++ b/tests/dsl/dace/stree/optimizations/test_merge.py @@ -160,7 +160,7 @@ def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> No if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 + assert len(all_maps) == 1 # all merged and collapsed assert (out_qty.field[:] == 2).all() def test_missing_merge_of_forscope_and_map( @@ -179,7 +179,7 @@ def test_missing_merge_of_forscope_and_map( for map_entry, _ in sdfg.all_nodes_recursive() if isinstance(map_entry, nodes.MapEntry) ] - assert len(all_maps) == 4 # 2 IJ + 2 Ks + assert len(all_maps) == 3 # 1 IJ + 2 Ks all_loops = [ loop for loop, _ in sdfg.all_nodes_recursive() @@ -203,7 +203,7 @@ def test_overcompute_merge( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 # All maps merged + assert len(all_maps) == 1 # All maps merged and collapsed def test_block_merge_when_dependencies_are_found( self, code: OrchestratedCode, factories: Factories @@ -222,7 +222,7 @@ def test_block_merge_when_dependencies_are_found( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 4 # 2 IJ + 2 Ks (un-merged) + assert len(all_maps) == 3 # 1 IJ + 2 Ks (un-merged) def test_push_non_cartesian_for( self, code: OrchestratedCode, factories: Factories @@ -242,7 +242,7 @@ def test_push_non_cartesian_for( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 # All merged + assert len(all_maps) == 1 # All merged & collapsed for_loops = [ node for node, _ in sdfg.all_nodes_recursive() @@ -278,7 +278,7 @@ def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> No if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 + assert len(all_maps) == 1 # all maps merged and collapsed assert (out_qty.field[:] == 2).all() def test_missing_merge_of_forscope_and_map( @@ -298,7 +298,7 @@ def test_missing_merge_of_forscope_and_map( for map_entry, _ in sdfg.all_nodes_recursive() if isinstance(map_entry, nodes.MapEntry) ] - assert len(all_maps) == 8 # 2 KJI (all maps) + 1 for scope + assert len(all_maps) == 3 # 2 KJI (all maps) + 1 JI all_loops = [ loop for loop, _ in sdfg.all_nodes_recursive() @@ -323,7 +323,7 @@ def test_overcompute_merge( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 # All maps merged + assert len(all_maps) == 1 # All maps merged & collapsed def test_block_merge_when_dependencies_are_found( self, code: OrchestratedCode, factories: Factories @@ -342,7 +342,7 @@ def test_block_merge_when_dependencies_are_found( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 6 # 2 * KJI + assert len(all_maps) == 2 # 2 * KJI def test_push_non_cartesian_for( self, code: OrchestratedCode, factories: Factories @@ -362,7 +362,7 @@ def test_push_non_cartesian_for( for me, state in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 # All merged + assert len(all_maps) == 1 # All merged and collapsed for_loops = [ node for node, _ in sdfg.all_nodes_recursive() diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py index f897173f..f58c90f7 100644 --- a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -88,14 +88,13 @@ def test_happy_case(self, factories: Factories) -> None: code.happy_case(in_quantity, out_quantity) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - assert precompiled_sdfg.sdfg all_maps = [ (me, state) for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 + assert len(all_maps) == 1 # all merged and collapsed def test_happy_case_2(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories @@ -108,14 +107,13 @@ def test_happy_case_2(self, factories: Factories) -> None: code.happy_case_2(in_quantity, out_quantity) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - assert precompiled_sdfg.sdfg all_maps = [ (me, state) for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 3 + assert len(all_maps) == 1 # all merged and collapsed def test_blocked_by_else(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories @@ -128,14 +126,13 @@ def test_blocked_by_else(self, factories: Factories) -> None: code.blocked_by_else(in_quantity, out_quantity) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - assert precompiled_sdfg.sdfg all_maps = [ (me, state) for me, state in precompiled_sdfg.sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) ] - assert len(all_maps) == 9 + assert len(all_maps) == 3 # 3 * IJK/KJI def test_blocked_by_other_nodes(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories @@ -148,7 +145,6 @@ def test_blocked_by_other_nodes(self, factories: Factories) -> None: code.blocked_by_other_nodes(in_quantity, out_quantity) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - assert precompiled_sdfg.sdfg all_maps = [ (me, state) @@ -157,7 +153,7 @@ def test_blocked_by_other_nodes(self, factories: Factories) -> None: ] # ⚠️ Dev note: - # This should be just `assert len(all_maps) == 6`, but currently, the K-loops + # This should be just `assert len(all_maps) == 2`, but currently, the K-loops # can't merge because the K-iterators are different. To be fixed (and simplified # here) with a subsequent commit. - assert len(all_maps) == 9 + assert len(all_maps) == 3 diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py index 4af0474a..da38e890 100644 --- a/tests/dsl/dace/stree/optimizations/test_remove_loops.py +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -8,6 +8,12 @@ from ndsl.boilerplate import get_factories_single_tile from ndsl.config import Backend, BackendLoopOrder from ndsl.constants import I_DIM, J_DIM, K_DIM, Float +from ndsl.dsl.dace.stree.optimizations import InlineVertical2DWrite +from ndsl.dsl.dace.stree.pipeline import ( + CartesianMerge, + CartesianRefineTransients, + CleanUpScheduleTree, +) from ndsl.dsl.gt4py import FORWARD, computation, interval from ndsl.dsl.typing import FloatField, FloatFieldIJ from ndsl.stencils import copy @@ -123,12 +129,18 @@ def factories(self, request) -> Factories: def test_common_2D_write(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + CartesianRefineTransients(stencil_factory.backend), + ] in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") in_qty.field[:, :, 0] = Float(32.0) - with StreeOptimization(): + with StreeOptimization(passes=pipeline): code.write_at_0(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -143,19 +155,25 @@ def test_common_2D_write(self, factories: Factories) -> None: if isinstance(me, LoopRegion) ] - assert len(all_maps) == 2 + assert len(all_maps) == 1 # IJ/JI collapsed assert len(all_loop_region) == 0 assert (out_qty.field[:] == Float(32.0)).all() def test_2D_write_K_top(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + CartesianRefineTransients(stencil_factory.backend), + ] in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") in_qty.field[:, :, -1] = Float(32.0) - with StreeOptimization(): + with StreeOptimization(passes=pipeline): code.write_at_top(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -170,18 +188,24 @@ def test_2D_write_K_top(self, factories: Factories) -> None: if isinstance(me, LoopRegion) ] - assert len(all_maps) == 2 + assert len(all_maps) == 1 # IJ/JI collapsed assert len(all_loop_region) == 0 assert (out_qty.field[:] == Float(32.0)).all() def test_do_not_inline(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + CartesianRefineTransients(stencil_factory.backend), + ] in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): + with StreeOptimization(passes=pipeline): code.do_not_inline(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -196,19 +220,25 @@ def test_do_not_inline(self, factories: Factories) -> None: if isinstance(me, LoopRegion) ] - assert len(all_maps) == 2 + assert len(all_maps) == 1 # IJ/JI collapsed assert len(all_loop_region) == 1 assert (out_qty.field[:] == Float(1)).all() def test_combined_stencils(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + CartesianRefineTransients(stencil_factory.backend), + ] field = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") field_2 = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "") - with StreeOptimization(): + with StreeOptimization(passes=pipeline): code.combined_stencils(field, field_2, field_IJ) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -224,9 +254,9 @@ def test_combined_stencils(self, factories: Factories) -> None: ] assert ( - len(all_maps) == 3 + len(all_maps) == 2 # IJ + K if stencil_factory.backend.loop_order == BackendLoopOrder.IJK - else 5 + else 2 # KJI + JI ) assert len(all_loop_region) == 0 assert (field_IJ.field[:] == Float(1)).all() @@ -234,13 +264,19 @@ def test_combined_stencils(self, factories: Factories) -> None: def test_multiple_statements(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) + pipeline = [ + CleanUpScheduleTree(), + InlineVertical2DWrite(), + CartesianMerge(stencil_factory.backend), + CartesianRefineTransients(stencil_factory.backend), + ] field = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "") field_IJ_2 = quantity_factory.zeros([I_DIM, J_DIM], "") field.field[:, :, 0] = Float(42.0) - with StreeOptimization(): + with StreeOptimization(passes=pipeline): code.multiple_statements(field, field_IJ, field_IJ_2) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -256,9 +292,9 @@ def test_multiple_statements(self, factories: Factories) -> None: ] assert ( - len(all_maps) == 3 + len(all_maps) == 2 # IJ + K if stencil_factory.backend.loop_order == BackendLoopOrder.IJK - else 5 + else 2 # KJI + JI ) assert len(all_loop_region) == 0 assert (field_IJ.field[:] == Float(42.0)).all() diff --git a/tests/dsl/dace/stree/sdfg_stree_tools.py b/tests/dsl/dace/stree/sdfg_stree_tools.py index 6c664205..b913a134 100644 --- a/tests/dsl/dace/stree/sdfg_stree_tools.py +++ b/tests/dsl/dace/stree/sdfg_stree_tools.py @@ -1,6 +1,7 @@ from types import TracebackType import dace +from dace.sdfg.analysis.schedule_tree import treenodes as tn import ndsl.dsl.dace.orchestration as orch from ndsl import StencilFactory @@ -21,8 +22,14 @@ def get_SDFG_and_purge(stencil_factory: StencilFactory) -> dace.CompiledSDFG: class StreeOptimization: + def __init__(self, *, passes: list[tn.ScheduleNodeVisitor] | None = None) -> None: + self.passes = passes + def __enter__(self) -> None: + self.original_passes = orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES = self.passes def __exit__( self, @@ -31,3 +38,4 @@ def __exit__( exc_tb: TracebackType | None, ) -> None: orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False + orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES = self.original_passes From 43674afc9507194ba07bf925974731f9d331e3bd Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 28 May 2026 17:43:21 +0200 Subject: [PATCH 029/101] fixes to run GFLD_1M with orch:dace:cpu:KJI backend --- external/gt4py | 2 +- ndsl/dsl/stencil.py | 14 +- .../dsl/dace/stree/optimizations/__init__.py | 6 + .../dace/stree/optimizations/test_merge.py | 6 +- .../test_offgrid_conditionals.py | 17 +- .../stree/optimizations/test_remove_loops.py | 10 +- tests/dsl/orchestration/test_boundaries_k.py | 196 ++++++++++++++++++ 7 files changed, 219 insertions(+), 32 deletions(-) create mode 100644 tests/dsl/orchestration/test_boundaries_k.py diff --git a/external/gt4py b/external/gt4py index 7ba05d5d..08100b85 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 7ba05d5dc03c3d140c9074bc2b5f8e8027832842 +Subproject commit 08100b8505a6ce655a8b71043da514f3a6b8634a diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index a00adebb..26f3dce4 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -881,6 +881,8 @@ def _origin_from_dims(self, dims: Iterable[str]) -> list[int]: return_origin.append(self.origin[1]) elif dim in K_DIMS: return_origin.append(self.origin[2]) + else: + raise ValueError(f"Unknown dimension '{dim}'.") return return_origin def _domain_from_dims(self, dimensions: Iterable[str]) -> list[int]: @@ -888,16 +890,18 @@ def _domain_from_dims(self, dimensions: Iterable[str]) -> list[int]: for dimension in dimensions: if dimension == I_DIM: result.append(self.domain[0]) - if dimension == I_INTERFACE_DIM: + elif dimension == I_INTERFACE_DIM: result.append(self.domain[0] + 1) - if dimension == J_DIM: + elif dimension == J_DIM: result.append(self.domain[1]) - if dimension == J_INTERFACE_DIM: + elif dimension == J_INTERFACE_DIM: result.append(self.domain[1] + 1) - if dimension == K_DIM: + elif dimension == K_DIM: result.append(self.domain[2]) - if dimension == K_INTERFACE_DIM: + elif dimension == K_INTERFACE_DIM: result.append(self.domain[2] + 1) + else: + raise ValueError(f"Unknown dimension '{dimension}'.") return result def get_shape( diff --git a/tests/dsl/dace/stree/optimizations/__init__.py b/tests/dsl/dace/stree/optimizations/__init__.py index e69de29b..e0e56d60 100644 --- a/tests/dsl/dace/stree/optimizations/__init__.py +++ b/tests/dsl/dace/stree/optimizations/__init__.py @@ -0,0 +1,6 @@ +from typing import TypeAlias + +from ndsl import QuantityFactory, StencilFactory + + +Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] diff --git a/tests/dsl/dace/stree/optimizations/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py index 2e76029c..1a9ed508 100644 --- a/tests/dsl/dace/stree/optimizations/test_merge.py +++ b/tests/dsl/dace/stree/optimizations/test_merge.py @@ -1,5 +1,3 @@ -from typing import TypeAlias - import dace import pytest from dace import nodes @@ -13,6 +11,7 @@ from ndsl.dsl.gt4py import FORWARD, PARALLEL, K, computation, interval from ndsl.dsl.typing import FloatField from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree.optimizations import Factories def stencil(in_field: FloatField, out_field: FloatField) -> None: @@ -130,9 +129,6 @@ def push_non_cartesian_for( self.stencil(in_field, out_field) -Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] - - class TestStreeMergeMapsIJK: @pytest.fixture def factories(self) -> Factories: diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py index f58c90f7..fcfe33bc 100644 --- a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -1,20 +1,12 @@ -from typing import TypeAlias - import pytest from dace import nodes -from ndsl import ( - Backend, - NDSLRuntime, - QuantityFactory, - StencilFactory, - orchestrate, - stencils, -) +from ndsl import Backend, NDSLRuntime, StencilFactory, orchestrate, stencils from ndsl.boilerplate import get_factories_single_tile from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.typing import FloatField from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree.optimizations import Factories class OrchestratedCode(NDSLRuntime): @@ -66,12 +58,9 @@ def blocked_by_other_nodes( self._copy_stencil(in_field, out_field) -Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] - - class TestStreeInlineOffgridConditionals: @pytest.fixture(params=["orch:dace:cpu:IJK", "orch:dace:cpu:KJI"]) - def factories(self, request) -> Factories: + def factories(self, request: pytest.FixtureRequest) -> Factories: domain = (3, 3, 4) return get_factories_single_tile( domain[0], domain[1], domain[2], 0, backend=Backend(request.param) diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py index da38e890..9469f204 100644 --- a/tests/dsl/dace/stree/optimizations/test_remove_loops.py +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -1,10 +1,8 @@ -from typing import TypeAlias - import pytest from dace import nodes from dace.sdfg.state import LoopRegion -from ndsl import QuantityFactory, StencilFactory, orchestrate +from ndsl import StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile from ndsl.config import Backend, BackendLoopOrder from ndsl.constants import I_DIM, J_DIM, K_DIM, Float @@ -18,6 +16,7 @@ from ndsl.dsl.typing import FloatField, FloatFieldIJ from ndsl.stencils import copy from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree.optimizations import Factories def stencil_simple_2D_write(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> None: @@ -114,12 +113,9 @@ def multiple_statements( self.stencil_multiple_2D_write(in_field, out_field, out_field2) -Factories: TypeAlias = tuple[StencilFactory, QuantityFactory] - - class TestStree2DWriteInline: @pytest.fixture(params=["orch:dace:cpu:IJK", "orch:dace:cpu:KJI"]) - def factories(self, request) -> Factories: + def factories(self, request: pytest.FixtureRequest) -> Factories: domain = (3, 3, 4) return get_factories_single_tile( diff --git a/tests/dsl/orchestration/test_boundaries_k.py b/tests/dsl/orchestration/test_boundaries_k.py new file mode 100644 index 00000000..80ea4a84 --- /dev/null +++ b/tests/dsl/orchestration/test_boundaries_k.py @@ -0,0 +1,196 @@ +import numpy as np +import pytest + +from ndsl import Backend, NDSLRuntime, StencilFactory, orchestrate +from ndsl.boilerplate import get_factories_single_tile +from ndsl.constants import I_DIM, J_DIM, K_DIM, K_INTERFACE_DIM +from ndsl.dsl.gt4py import BACKWARD, FORWARD, computation, interval +from ndsl.dsl.typing import FloatField +from tests.dsl.dace.stree.optimizations import Factories + + +def accumulate_down(in_field: FloatField, out_field: FloatField) -> None: # type: ignore + with computation(BACKWARD): + # handle top layer separately + with interval(-1, None): + out_field = in_field + + # accumulate "downwards" + with interval(0, -1): + out_field = out_field[0, 0, 1] + in_field + + +def accumulate_down_from_interface_field(interface_field: FloatField, out_field: FloatField) -> None: # type: ignore + with computation(BACKWARD): + # handle top layer separately + with interval(-1, None): + out_field = interface_field + interface_field[0, 0, 1] + + # accumulate "downwards" + with interval(0, -1): + out_field = out_field[0, 0, 1] + interface_field + + +def accumulate_on_interface(interface_field: FloatField, out_field: FloatField) -> None: # type: ignore + with computation(BACKWARD): + # handle top layer separately + with interval(-2, -1): + out_field = interface_field + interface_field[0, 0, 1] + + # accumulate "downwards" + with interval(0, -2): + out_field = out_field[0, 0, 1] + interface_field + + +def accumulate_up(in_field: FloatField, out_field: FloatField) -> None: # type: ignore + with computation(FORWARD): + # handle bottom layer separately + with interval(0, 1): + out_field = in_field + + # accumulate "upwards" + with interval(1, None): + out_field = out_field[0, 0, -1] + in_field + + +def accumulate_up_interface(in_field: FloatField, interface_field: FloatField) -> None: # type: ignore + with computation(FORWARD): + # handle bottom layer separately + with interval(0, 1): + interface_field = in_field + + # accumulate "upwards" + with interval(1, None): + interface_field = interface_field[0, 0, -1] + in_field[0, 0, -1] + + +class OrchestratedCode(NDSLRuntime): + def __init__(self, stencil_factory: StencilFactory) -> None: + super().__init__(stencil_factory) + + methods_to_orchestrate = [ + "accumulate_down", + "accumulate_down_from_interface_field", + "accumulate_on_interface", + "accumulate_up", + "accumulate_up_interface", + ] + + for method in methods_to_orchestrate: + orchestrate( + obj=self, + method_to_orchestrate=method, + config=stencil_factory.config.dace_config, + ) + + self._accumulate_down = stencil_factory.from_dims_halo( + func=accumulate_down, compute_dims=(I_DIM, J_DIM, K_DIM) + ) + + self._accumulate_down_from_interface_field = stencil_factory.from_dims_halo( + func=accumulate_down_from_interface_field, + compute_dims=(I_DIM, J_DIM, K_DIM), + ) + + self._accumulate_on_interface = stencil_factory.from_dims_halo( + func=accumulate_on_interface, compute_dims=(I_DIM, J_DIM, K_INTERFACE_DIM) + ) + + self._accumulate_up = stencil_factory.from_dims_halo( + func=accumulate_up, compute_dims=(I_DIM, J_DIM, K_DIM) + ) + + self._accumulate_up_interface = stencil_factory.from_dims_halo( + func=accumulate_up_interface, compute_dims=(I_DIM, J_DIM, K_INTERFACE_DIM) + ) + + def accumulate_down(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._accumulate_down(in_field, out_field) + + def accumulate_down_from_interface_field(self, interface_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._accumulate_down_from_interface_field(interface_field, out_field) + + def accumulate_on_interface(self, interface_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._accumulate_on_interface(interface_field, out_field) + + def accumulate_up(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._accumulate_up(in_field, out_field) + + def accumulate_up_interface(self, in_field: FloatField, interface_field: FloatField) -> None: # type: ignore + self._accumulate_up_interface(in_field, interface_field) + + +class TestBoundariesK: + @pytest.fixture( + params=[ + "orch:dace:cpu:IJK", + "orch:dace:cpu:KJI", + "st:dace:cpu:IJK", + "st:dace:cpu:KJI", + ] + ) + def factories(self, request: pytest.FixtureRequest) -> Factories: + domain = (3, 4, 5) + return get_factories_single_tile( + nx=domain[0], + ny=domain[1], + nz=domain[2], + nhalo=0, + backend=Backend(request.param), + ) + + def test_accumulate_down(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + in_field = quantity_factory.ones((I_DIM, J_DIM, K_DIM), units="") + out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), units="") + + code.accumulate_down(in_field, out_field) + assert np.array_equal(out_field.field[0, 0, :], [5, 4, 3, 2, 1]) + + def test_accumulate_interface_field(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + interface_field = quantity_factory.ones( + (I_DIM, J_DIM, K_INTERFACE_DIM), units="" + ) + out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), units="") + + code.accumulate_down_from_interface_field(interface_field, out_field) + assert np.array_equal(out_field.field[0, 0, :], [6, 5, 4, 3, 2]) + + def test_accumulate_interface_domain(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + interface_field = quantity_factory.ones( + (I_DIM, J_DIM, K_INTERFACE_DIM), units="" + ) + out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), units="") + + code.accumulate_on_interface(interface_field, out_field) + assert np.array_equal(out_field.field[0, 0, :], [6, 5, 4, 3, 2]) + + def test_accumulate_up(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + in_field = quantity_factory.ones((I_DIM, J_DIM, K_DIM), units="") + out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), units="") + + code.accumulate_up(in_field, out_field) + assert np.array_equal(out_field.field[0, 0, :], [1, 2, 3, 4, 5]) + + def test_accumulate_up_interface(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + in_field = quantity_factory.ones((I_DIM, J_DIM, K_DIM), units="") + interface_field = quantity_factory.zeros( + (I_DIM, J_DIM, K_INTERFACE_DIM), units="" + ) + + code.accumulate_up_interface(in_field, interface_field) + assert np.array_equal(interface_field.field[0, 0, :], [1, 2, 3, 4, 5, 6]) From 1ea1314dc0a956cd66609fefcae48cea864abb1e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 29 May 2026 09:05:02 +0200 Subject: [PATCH 030/101] ci: gt4py update (restore temp dace working branch) --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index 08100b85..d19bf894 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 08100b8505a6ce655a8b71043da514f3a6b8634a +Subproject commit d19bf894f2361c26e5030facd2d06a19ea2af157 From fa60dccfe6d33a15ba2f9a60012358481c0f01ab Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 29 May 2026 10:27:36 +0200 Subject: [PATCH 031/101] remove extra `f` in result report header --- ndsl/stencils/testing/test_translate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index af9f8ae9..266e811a 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -469,7 +469,7 @@ def _report_results( os.makedirs(detail_dir, exist_ok=True) # Summary - header = f"{savepoint_name} w/ f{backend.as_humanly_readable()}" + header = f"{savepoint_name} w/ {backend.as_humanly_readable()}" lines = [] for varname, metric in results.items(): lines.append(f"{varname}: {metric.one_line_report()}") From 160923f14056f20016afdc3044c175c3ac27f83c Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 1 Jun 2026 16:24:50 +0200 Subject: [PATCH 032/101] unrelated dace/gt4py update: just test fixes and a typo --- external/dace | 2 +- external/gt4py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/dace b/external/dace index 99a9360d..4da9d096 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit 99a9360d35f458c328b204860d59a365522484ab +Subproject commit 4da9d096ed3454ffa6dcb7b5233c281dc90696c2 diff --git a/external/gt4py b/external/gt4py index d19bf894..210fcbd8 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit d19bf894f2361c26e5030facd2d06a19ea2af157 +Subproject commit 210fcbd8c78800bf26421fac3c49c5b22e59d4e4 From c2bd78d5fd7763d6a047882e0248210f78f222fc Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 2 Jun 2026 11:34:31 -0400 Subject: [PATCH 033/101] Expose `gpu:IJK` backends to NDSL --- ndsl/config/backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index 2807cf6a..605b86d7 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -52,6 +52,8 @@ class BackendLoopOrder(Enum): "orch:dace:cpu:KJI": "dace:cpu_KJI", "st:dace:gpu:KJI": "dace:gpu", "orch:dace:gpu:KJI": "dace:gpu", + "st:dace:gpu:IJK": "dace:gpu_IJK", + "orch:dace:gpu:IJK": "dace:gpu_IJK", } """Internal: match the NDSL backend names with the GT4Py names""" From eaaa0cc0a579d1546ed8ed57099cf88857292e0b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 3 Jun 2026 11:48:07 +0200 Subject: [PATCH 034/101] Disable DaceConfig.from_dict() as it is incomplete While the functions creates an inconstent DaceConfig by creating a config first and then tempering with some properites without re-evaluating computed properties. In particular `code_path`, `do_compile` and distributed caches are potentially out of sync with the layout information. --- ndsl/dsl/dace/dace_config.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 62b679a5..c53ebc85 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -166,8 +166,8 @@ def __init__( Args: communicator: used for setting the distributed caches backend: string for the backend - tile_nx: x/y domain size for a single time - tile_nz: z domain size for a single time + tile_nx: x/y domain size for a single tile + tile_nz: z domain size for a single tile orchestration: orchestration mode from DaCeOrchestration time: trigger performance collection, available to user with `performance_collector` @@ -412,4 +412,11 @@ def from_dict(cls, data: dict) -> Self: config.rank_size = data["rank_size"] config.layout = data["layout"] config.tile_resolution = data["tile_resolution"] - return config + # TODO + # Computed properties like `self.code_path` and `self.do_compile` + # aren't updated. + # We also don't `set_distributed_caches()` based on that updated + # information. + raise NotImplementedError( + "Implementation of `DaceConfig.from_dict()` is incomplete." + ) From 749f8a3c5387dd934e6c42ac301bf274e5d5f639 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 3 Jun 2026 11:54:44 +0200 Subject: [PATCH 035/101] readability of cache location code --- ndsl/dsl/caches/cache_location.py | 64 ++++++++++++++++--------------- ndsl/dsl/caches/codepath.py | 2 + 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/ndsl/dsl/caches/cache_location.py b/ndsl/dsl/caches/cache_location.py index 87d608dd..d4313815 100644 --- a/ndsl/dsl/caches/cache_location.py +++ b/ndsl/dsl/caches/cache_location.py @@ -7,46 +7,48 @@ def identify_code_path( partitioner: Partitioner, single_code_path: bool, ) -> FV3CodePath: - """Determine which code path your rank will hit. + """ + Determine which code path your rank will hit. - If single_code_path is True, single_code_path is True, - only one code path exists (case of doubly periodic grid). + If single_code_path is True, only one code path exists, + e.g. in case of a doubly periodic grid. If single_code_path is False, we are in the case of the - cube-sphere and we will look at our position on the tile.""" + cube-sphere and we will look at our position on the tile. + """ # Doubly-periodic or single tile grid - if single_code_path: + if single_code_path or partitioner.layout == (1, 1): return FV3CodePath.All # Cube-sphere - if partitioner.layout == (1, 1): - return FV3CodePath.All - elif partitioner.layout[0] == 1 or partitioner.layout[1] == 1: + if partitioner.layout[0] <= 1 or partitioner.layout[1] <= 1: raise NotImplementedError( - f"Build for layout {partitioner.layout} is not handled" + f"Build for layout {partitioner.layout} is not handled." ) - else: - if partitioner.tile.on_tile_bottom(rank): - if partitioner.tile.on_tile_left(rank): - return FV3CodePath.BottomLeft - if partitioner.tile.on_tile_right(rank): - return FV3CodePath.BottomRight - else: - return FV3CodePath.Bottom - if partitioner.tile.on_tile_top(rank): - if partitioner.tile.on_tile_left(rank): - return FV3CodePath.TopLeft - if partitioner.tile.on_tile_right(rank): - return FV3CodePath.TopRight - else: - return FV3CodePath.Top - else: - if partitioner.tile.on_tile_left(rank): - return FV3CodePath.Left - if partitioner.tile.on_tile_right(rank): - return FV3CodePath.Right - else: - return FV3CodePath.Center + + # Bottom row + if partitioner.tile.on_tile_bottom(rank): + if partitioner.tile.on_tile_left(rank): + return FV3CodePath.BottomLeft + if partitioner.tile.on_tile_right(rank): + return FV3CodePath.BottomRight + return FV3CodePath.Bottom + + # Top row + if partitioner.tile.on_tile_top(rank): + if partitioner.tile.on_tile_left(rank): + return FV3CodePath.TopLeft + if partitioner.tile.on_tile_right(rank): + return FV3CodePath.TopRight + return FV3CodePath.Top + + # Left & right column with corners already handled + if partitioner.tile.on_tile_left(rank): + return FV3CodePath.Left + if partitioner.tile.on_tile_right(rank): + return FV3CodePath.Right + + return FV3CodePath.Center def get_cache_fullpath(code_path: FV3CodePath) -> str: diff --git a/ndsl/dsl/caches/codepath.py b/ndsl/dsl/caches/codepath.py index 61591ccf..3d90a9e2 100644 --- a/ndsl/dsl/caches/codepath.py +++ b/ndsl/dsl/caches/codepath.py @@ -3,10 +3,12 @@ class FV3CodePath(enum.Enum): """Enum listing all possible code paths on a cube sphere. + For any layout the cube sphere has up to 9 different code paths depending on the positioning of the rank on the tile and which of the edge/corner cases it has to handle, as well as the possibility for all boundary computations in the 1x1 layout case. + Since the framework inlines code to optimize, we _cannot_ pre-suppose which code being kept and/or ejected. This enum serves as the ground truth to map rank to the proper generated code. From 0d860bf4a4f0c65d36ab893dd6b6117184b71ba1 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 4 Jun 2026 14:46:32 +0200 Subject: [PATCH 036/101] translate tests: fix crash in reporting when comparing scalars UW translate test compares a scalar value (`dotransport`) as part of the translate test. Doing so trips reporting and this change makes it work again \o/ --- ndsl/testing/comparison.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/testing/comparison.py b/ndsl/testing/comparison.py index 3acfd723..e7fc93eb 100644 --- a/ndsl/testing/comparison.py +++ b/ndsl/testing/comparison.py @@ -339,7 +339,7 @@ def one_line_report(self) -> str: return f"❌ Numerical failures: {failed_indices}/{all_indices} failed - metric: {metric_thresholds}" def report(self, file_path: str | None = None) -> list[str]: - failed_indices = np.logical_not(self.success).nonzero() + failed_indices = np.atleast_1d(np.logical_not(self.success)).nonzero() # List all errors to terminal and file bad_indices_count = len(failed_indices[0]) if self.changing_column_map is not None: From 5ee3bb9217f29f7cd23d1ae95dc14109867c8ba4 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 5 Jun 2026 12:20:10 -0400 Subject: [PATCH 037/101] Weaken the cube-sphere communicator hard ranks limit. We need "at least" not "exactly" the rank number --- ndsl/comm/communicator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 65d72018..abb70ec8 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -786,7 +786,7 @@ def __init__( "Communicator needs to be instantiated with communication subsystem" f" derived from `comm_abc.Comm`, got {type(comm)}." ) - if comm.Get_size() != partitioner.total_ranks: + if comm.Get_size() < partitioner.total_ranks: raise ValueError( f"was given a partitioner for {partitioner.total_ranks} ranks but a " f"comm object with only {comm.Get_size()} ranks, are we running " From da82edefd73728cea16c6288297422e7f6bd7792 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 5 Jun 2026 14:03:22 -0400 Subject: [PATCH 038/101] Adjust `cflags` format read for orchestrated compile Protect `performance_timer` for `time==False` and add external setup --- ndsl/dsl/dace/dace_config.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index c53ebc85..579ddcb0 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -181,16 +181,10 @@ def __init__( # ToDo: DaceConfig becomes a bit more than a read-only config # with this. Should be refactored into a DaceExecutor carrying a config self.loaded_dace_executables: DaceExecutables = {} - self.performance_collector = ( - PerformanceCollector( - "InternalOrchestrationTimer", - comm=( - LocalComm(0, 6, {}) if communicator is None else communicator.comm - ), - ) - if time - else NullPerformanceCollector() - ) + if not time: + self.performance_collector = NullPerformanceCollector() + else: + self.set_timer(communicator.comm if communicator else None) # Temporary. This is a bit too out of the ordinary for the common user. # We should refactor the architecture to allow for a `gtc:orchestrated:dace:X` @@ -264,11 +258,12 @@ def __init__( march_option = "-mcpu=native" if is_arm_neoverse else "-march=native" # Removed --fast-math gpu_config = gpu_configuration(GT4PY_COMPILE_OPT_LEVEL) + gpu_cflags = " ".join(gpu_config.gpu_compile_flags).strip() dace.config.Config.set( "compiler", "cuda", "args", - value=f"-std=c++14 -Xcompiler -fPIC -O{optimization_level} -Xcompiler {march_option} {gpu_config.gpu_compile_flags}", + value=f"-std=c++14 -Xcompiler -fPIC -O{optimization_level} -Xcompiler {march_option} {gpu_cflags}", ) cuda_sm = cp.cuda.Device(0).compute_capability if cp else 60 @@ -420,3 +415,12 @@ def from_dict(cls, data: dict) -> Self: raise NotImplementedError( "Implementation of `DaceConfig.from_dict()` is incomplete." ) + + def set_timer(self, comm): + """Set timer on configuration externally""" + # TODO: this absolutely should not be a on a Configuration object + # and even less setup outside. Madness, we have lost our ways... + self.performance_collector = PerformanceCollector( + "InternalOrchestrationTimer", + comm=(LocalComm(0, 6, {}) if comm is None else comm), + ) From 64dd47c5022663105923d9345bd7529bbd79dd34 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 5 Jun 2026 14:05:44 -0400 Subject: [PATCH 039/101] Lint --- ndsl/dsl/dace/dace_config.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 579ddcb0..0059ff39 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -10,6 +10,7 @@ from gt4py.cartesian.utils.compiler import cxx_compiler_defaults, gpu_configuration from ndsl import LocalComm +from ndsl.comm import Comm from ndsl.comm.communicator import Communicator from ndsl.comm.partitioner import Partitioner from ndsl.config import Backend @@ -17,7 +18,11 @@ from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath from ndsl.optional_imports import cupy as cp -from ndsl.performance.collector import NullPerformanceCollector, PerformanceCollector +from ndsl.performance.collector import ( + AbstractPerformanceCollector, + NullPerformanceCollector, + PerformanceCollector, +) if TYPE_CHECKING: @@ -182,7 +187,9 @@ def __init__( # with this. Should be refactored into a DaceExecutor carrying a config self.loaded_dace_executables: DaceExecutables = {} if not time: - self.performance_collector = NullPerformanceCollector() + self.performance_collector: AbstractPerformanceCollector = ( + NullPerformanceCollector() + ) else: self.set_timer(communicator.comm if communicator else None) @@ -416,7 +423,7 @@ def from_dict(cls, data: dict) -> Self: "Implementation of `DaceConfig.from_dict()` is incomplete." ) - def set_timer(self, comm): + def set_timer(self, comm: Comm | None) -> None: """Set timer on configuration externally""" # TODO: this absolutely should not be a on a Configuration object # and even less setup outside. Madness, we have lost our ways... From 26fb0ef2f60168371fc401520ecdd15c1a1842ae Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sat, 6 Jun 2026 17:13:31 -0400 Subject: [PATCH 040/101] Introduce hardware configuration good defaults --- ndsl/dsl/dace/dace_config.py | 22 ++++-- ndsl/dsl/dace/hardware_config.py | 112 +++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 7 deletions(-) create mode 100644 ndsl/dsl/dace/hardware_config.py diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 0059ff39..f73bc249 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -17,6 +17,7 @@ from ndsl.dsl import NDSL_GLOBAL_PRECISION from ndsl.dsl.caches.cache_location import identify_code_path from ndsl.dsl.caches.codepath import FV3CodePath +from ndsl.dsl.dace.hardware_config import get_gpu_hardware_defaults from ndsl.optional_imports import cupy as cp from ndsl.performance.collector import ( AbstractPerformanceCollector, @@ -273,14 +274,21 @@ def __init__( value=f"-std=c++14 -Xcompiler -fPIC -O{optimization_level} -Xcompiler {march_option} {gpu_cflags}", ) - cuda_sm = cp.cuda.Device(0).compute_capability if cp else 60 - dace.config.Config.set("compiler", "cuda", "cuda_arch", value=f"{cuda_sm}") - # Block size/thread count is defaulted to an average value for recent - # hardware (Pascal and upward). The problem of setting an optimized - # block/thread is both hardware and problem dependant. Fine tuners - # available in DaCe should be relied on for further tuning of this value. + # Target compilation for hardware micro-code capacities + gpu_defaults = get_gpu_hardware_defaults() dace.config.Config.set( - "compiler", "cuda", "default_block_size", value="64,8,1" + "compiler", + "cuda", + "cuda_arch", + value=f"{gpu_defaults.compute_capability}", + ) + + # Default block size for kernels launch + dace.config.Config.set( + "compiler", + "cuda", + "default_block_size", + value=str(gpu_defaults.block_size)[1:-1], ) # Potentially buggy - deactivate dace.config.Config.set( diff --git a/ndsl/dsl/dace/hardware_config.py b/ndsl/dsl/dace/hardware_config.py new file mode 100644 index 00000000..ca28ac3b --- /dev/null +++ b/ndsl/dsl/dace/hardware_config.py @@ -0,0 +1,112 @@ +import dataclasses +import os +import sys + +from ndsl import ndsl_log +from ndsl.optional_imports import cupy as cp + + +# Taken straight out of https://pcisig.com/membership/member-companies +_VENDOR_PCI_SIGNAURES = { + 0x10DE: "Nvidia", + 0x1002: "AMD", + 0x8086: "Intel", + 0x0: "Unknown", +} + +# Cached copy of the hardware default +_GPU_HARDWARE_DEFAULTS = None + + +def _get_vendor() -> str: + """Retrieve vendor using the current device PCI id to query the PCI vendor + from the kernel logs + + ⚠️ Only works on Linux - kicks back to "Unknwon" in other cases + """ + if not sys.platform.startswith("linux"): + return _VENDOR_PCI_SIGNAURES[0x0] + + pci_device_id = cp.cuda.runtime.deviceGetPCIBusId(0) + dev_path = f"/sys/bus/pci/devices/{pci_device_id}" + if not os.path.exists(dev_path): + return "Unknown" + + with open(os.path.join(dev_path, "vendor"), "r") as f: + vendor_str = f.read().strip().replace("0x", "") + vendor_id = int(vendor_str, 16) + + if vendor_id not in _VENDOR_PCI_SIGNAURES: + ndsl_log.error(f"Unknown GPU vendor with PCI-SIG ID of {vendor_id:#X}") + return "Unknown" + return _VENDOR_PCI_SIGNAURES[int(vendor_str, 16)] + + +@dataclasses.dataclass +class GPUHardwareDefaults: + """Compute defaults for common GPUs""" + + vendor: str + block_size: list[int] = dataclasses.field(default_factory=list) + compute_capability: int = -1 # Nvidia specific + + +def get_gpu_hardware_defaults() -> GPUHardwareDefaults: + """Retrieve default values for GPU computation configuration""" + global _GPU_HARDWARE_DEFAULTS + if _GPU_HARDWARE_DEFAULTS is not None: + return _GPU_HARDWARE_DEFAULTS # type: ignore[unreachable] + + if not cp: + raise ModuleNotFoundError("Cupy must be installed to read hardware defaults") + if not cp.cuda.is_available(): + raise RuntimeError("No device available for hardware defaults read") + + # Who goes there + vendor = _get_vendor() + if vendor == "Nvidia": + compute_capability = int(cp.cuda.Device(0).compute_capability) + # Default block size based on compute capability + if compute_capability > 80: + # Covers: + # - Blackwell (100+) + # - Hopper (90-100) + # - Ampere (80-90) + block_sizes = [128, 1, 1] + elif compute_capability > 60: + # Covers: + # - Volta (70-80) + # - Pascal (60-70) + block_sizes = [64, 8, 1] + else: + # For older hardware - we default to the safe warp-size since + # the dawn of GPGPU on Nvidia hardware + block_sizes = [32, 1, 1] + + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor=vendor, + block_size=block_sizes, + compute_capability=compute_capability, + ) + elif vendor == "AMD": + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor=vendor, block_size=[64, 1, 1] # Default RDNA architectue is Wave64 + ) + elif vendor == "Intel": + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor=vendor, + block_size=[32, 1, 1], # Intel can run 8, 16 or 32 - but SIMD betters in 32 + ) + else: + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor=vendor, + block_size=[ + 8, + 1, + 1, + ], # Smaller common denominator of massively parallel hardware + ) + + ndsl_log.info(f"GPU vendor detected: {_GPU_HARDWARE_DEFAULTS.vendor}") + + return _GPU_HARDWARE_DEFAULTS From 7bdd3fa39f1317955824be4016ab9bf1b7139e1f Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sat, 6 Jun 2026 17:15:06 -0400 Subject: [PATCH 041/101] Fix double load for compiling rank Split Simplify2 pass into a GPU centric with block_size on maps & apply_gpu_xform Remove useless code - legacy code bleed Verbose the steps better --- ndsl/dsl/dace/orchestration.py | 112 +++++++++++++++++++-------------- 1 file changed, 65 insertions(+), 47 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index c88212c2..cf64fa8a 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -33,6 +33,7 @@ DaCeOrchestration, ) from ndsl.dsl.dace.dace_executable import DaceExecutable +from ndsl.dsl.dace.hardware_config import get_gpu_hardware_defaults from ndsl.dsl.dace.labeler import set_label from ndsl.dsl.dace.sdfg_debug_passes import ( negative_delp_checker, @@ -248,43 +249,52 @@ def _build_sdfg( ) sdfg.apply_transformations_repeated(MapCollapse) - # Make the transients array persistents - if config.is_gpu_backend(): - # TODO - # The following should happen on the stree level - _to_gpu(sdfg) - - sdfg.apply_gpu_transformations() + with DaCeProgress(config, "Make transient persistents"): + # Make the transients array persistents + if config.is_gpu_backend(): + # TODO + # The following should happen on the stree level + _to_gpu(sdfg) + make_transients_persistent(sdfg=sdfg, device=device_type) - make_transients_persistent(sdfg=sdfg, device=device_type) + # Upload args to device + _upload_to_device(list(args) + list(kwargs.values())) + else: + # TODO + # The following should happen on the stree level + for _sd, _aname, arr in sdfg.arrays_recursive(): + if arr.shape == (1,): + arr.storage = DaceStorageType.Register + make_transients_persistent(sdfg=sdfg, device=device_type) - # Upload args to device - _upload_to_device(list(args) + list(kwargs.values())) + if config.is_gpu_backend(): + with DaCeProgress(config, "Apply GPU transformations"): + # Set block size on GPU maps + gpu_defaults = get_gpu_hardware_defaults() + for me, _state in sdfg.all_nodes_recursive(): + if ( + isinstance(me, nodes.MapEntry) + and me.map.schedule == ScheduleType.GPU_Device + ): + if me.map.gpu_block_size is None: + me.map.gpu_block_size = gpu_defaults.block_size + # Apply common GPU transforms (includes a simplify) + sdfg.apply_gpu_transformations() + if config.verbose_orchestration: + sdfg.save( + os.path.abspath( + f"{sdfg.build_folder}/05-apply_gpu_xforms.sdfgz" + ), + compress=True, + ) else: - # TODO - # The following should happen on the stree level - for _sd, _aname, arr in sdfg.arrays_recursive(): - if arr.shape == (1,): - arr.storage = DaceStorageType.Register - make_transients_persistent(sdfg=sdfg, device=device_type) - - # Build non-constants & non-transients from the sdfg_kwargs - sdfg_kwargs = dace_program._create_sdfg_args(sdfg, args, kwargs) - for k in dace_program.constant_args: - if k in sdfg_kwargs: - del sdfg_kwargs[k] - sdfg_kwargs = {k: v for k, v in sdfg_kwargs.items() if v is not None} - for k, tup in dace_program.resolver.closure_arrays.items(): - if k in sdfg_kwargs and tup[1].transient: - del sdfg_kwargs[k] - - with DaCeProgress(config, "Simplify (2)"): - _simplify(sdfg) - if config.verbose_orchestration: - sdfg.save( - os.path.abspath(f"{sdfg.build_folder}/05-simplify_2.sdfgz"), - compress=True, - ) + with DaCeProgress(config, "Simplify (2)"): + _simplify(sdfg) + if config.verbose_orchestration: + sdfg.save( + os.path.abspath(f"{sdfg.build_folder}/05-simplify_2.sdfgz"), + compress=True, + ) # Move all memory that can be into a pool to lower memory pressure for GPU # We skip this memory optimization for CPU because we don't have a memory # pool available yet (DaCe v1) @@ -313,7 +323,12 @@ def _build_sdfg( # Compile with DaCeProgress(config, "Codegen & compile"): - sdfg.compile() + compiled_sdfg = sdfg.compile() + config.loaded_dace_executables[dace_program] = DaceExecutable( + compiled_sdfg=compiled_sdfg, + arguments={}, + arguments_hash=0, + ) # Printing analysis of the compiled SDFG with DaCeProgress(config, "Build finished. Running memory static analysis"): @@ -352,18 +367,21 @@ def _build_sdfg( ) MPI.COMM_WORLD.Barrier() - with DaCeProgress(config, "Loading"): - sdfg_path = get_sdfg_path(dace_program.name, config, override_run_only=True) - if sdfg_path is None: - raise ValueError("Couldn't load SDFG post build") - compiledSDFG, _ = dace_program.load_precompiled_sdfg( - sdfg_path, *args, **kwargs - ) - config.loaded_dace_executables[dace_program] = DaceExecutable( - compiled_sdfg=compiledSDFG, - arguments={}, - arguments_hash=0, - ) + if not is_compiling: + with DaCeProgress(config, "Loading"): + sdfg_path = get_sdfg_path( + dace_program.name, config, override_run_only=True + ) + if sdfg_path is None: + raise ValueError("Couldn't load SDFG post build") + compiledSDFG, _ = dace_program.load_precompiled_sdfg( + sdfg_path, *args, **kwargs + ) + config.loaded_dace_executables[dace_program] = DaceExecutable( + compiled_sdfg=compiledSDFG, + arguments={}, + arguments_hash=0, + ) def _call_sdfg( From 0fcd9bd0e5894c8995cfbe170b9489a057eb91f7 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sat, 6 Jun 2026 17:46:03 -0400 Subject: [PATCH 042/101] Hardware default: gives back default when no `cp` instead of raising --- ndsl/dsl/dace/hardware_config.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/ndsl/dsl/dace/hardware_config.py b/ndsl/dsl/dace/hardware_config.py index ca28ac3b..ebcdbfee 100644 --- a/ndsl/dsl/dace/hardware_config.py +++ b/ndsl/dsl/dace/hardware_config.py @@ -57,10 +57,17 @@ def get_gpu_hardware_defaults() -> GPUHardwareDefaults: if _GPU_HARDWARE_DEFAULTS is not None: return _GPU_HARDWARE_DEFAULTS # type: ignore[unreachable] - if not cp: - raise ModuleNotFoundError("Cupy must be installed to read hardware defaults") - if not cp.cuda.is_available(): - raise RuntimeError("No device available for hardware defaults read") + if not cp or not cp.cuda.is_available(): + ndsl_log.warning("No cupy - defaulting for GPU hardware") + _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( + vendor="Unknown", + block_size=[ + 8, + 1, + 1, + ], # Smaller common denominator of massively parallel hardware + ) + return _GPU_HARDWARE_DEFAULTS # Who goes there vendor = _get_vendor() From d843a2cd7785a3a4584c42c82720ebcc11fce365 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sun, 7 Jun 2026 16:49:24 -0400 Subject: [PATCH 043/101] Orch: always collapse maps to maximize the kernel parallel basis --- ndsl/dsl/dace/orchestration.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index cf64fa8a..07cdeeb5 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -247,7 +247,11 @@ def _build_sdfg( os.path.abspath(f"{sdfg.build_folder}/04-from_stree.sdfgz"), compress=True, ) - sdfg.apply_transformations_repeated(MapCollapse) + + # We want all maps properly collapse to make sure the codegen will see nD parallel + # axis as a single kernelizable map + with DaCeProgress(config, "Collapse maps"): + sdfg.apply_transformations_repeated(MapCollapse) with DaCeProgress(config, "Make transient persistents"): # Make the transients array persistents From fd9e588256d204ce6aff6db7e1beea522e70e023 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 9 Jun 2026 11:39:16 +0200 Subject: [PATCH 044/101] gt4py update to latest romanc/fix-log10-precision --- external/gt4py | 2 +- ndsl/quantity/quantity.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/gt4py b/external/gt4py index 210fcbd8..fa099e6e 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 210fcbd8c78800bf26421fac3c49c5b22e59d4e4 +Subproject commit fa099e6e3a80bf0feb357698eb4fd4a13c753cbd diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 0624a8c0..5d310674 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -287,7 +287,7 @@ def field(self) -> np.ndarray | cupy.ndarray: def data(self) -> np.ndarray | cupy.ndarray: """The underlying array of data""" warnings.warn( - "Quantity.data accessor is now deprecated. Use a slicing operation directly on" + "Quantity.data accessor is now deprecated. Use a slicing operation directly on " "the quantity, e.g. `my_quantity[:]` instead of `my_quantity.data[:]`", category=UserWarning, stacklevel=2, From c32fabf41ee0330a1d91fc8fc537515ce4028de4 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 9 Jun 2026 14:34:45 +0200 Subject: [PATCH 045/101] review of new gpu hardware detection --- ndsl/dsl/dace/hardware_config.py | 39 +++++++++++++++++++------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/ndsl/dsl/dace/hardware_config.py b/ndsl/dsl/dace/hardware_config.py index ebcdbfee..bbd367dc 100644 --- a/ndsl/dsl/dace/hardware_config.py +++ b/ndsl/dsl/dace/hardware_config.py @@ -1,13 +1,16 @@ import dataclasses -import os import sys +from pathlib import Path +from typing import Literal from ndsl import ndsl_log from ndsl.optional_imports import cupy as cp +GPUVendor = Literal["Nvidia"] | Literal["AMD"] | Literal["Intel"] | Literal["Unknown"] + # Taken straight out of https://pcisig.com/membership/member-companies -_VENDOR_PCI_SIGNAURES = { +_VENDOR_PCI_SIGNATURES: dict[int, GPUVendor] = { 0x10DE: "Nvidia", 0x1002: "AMD", 0x8086: "Intel", @@ -18,46 +21,49 @@ _GPU_HARDWARE_DEFAULTS = None -def _get_vendor() -> str: +def _get_vendor() -> GPUVendor: """Retrieve vendor using the current device PCI id to query the PCI vendor - from the kernel logs + from the kernel logs. - ⚠️ Only works on Linux - kicks back to "Unknwon" in other cases + ⚠️ Only works on Linux - kicks back to "Unknown" in other cases. """ if not sys.platform.startswith("linux"): - return _VENDOR_PCI_SIGNAURES[0x0] + ndsl_log.info("GPU hardware detection only possible on Linux system.") + return "Unknown" pci_device_id = cp.cuda.runtime.deviceGetPCIBusId(0) - dev_path = f"/sys/bus/pci/devices/{pci_device_id}" - if not os.path.exists(dev_path): + dev_path = Path("/sys", "bus", "pci", "devices", f"{pci_device_id}") + if not dev_path.exists(): + ndsl_log.info(f"GPU detection: PCI device not found at {dev_path}.") return "Unknown" - with open(os.path.join(dev_path, "vendor"), "r") as f: + with open(dev_path / "vendor", "r") as f: vendor_str = f.read().strip().replace("0x", "") vendor_id = int(vendor_str, 16) - if vendor_id not in _VENDOR_PCI_SIGNAURES: - ndsl_log.error(f"Unknown GPU vendor with PCI-SIG ID of {vendor_id:#X}") + if vendor_id not in _VENDOR_PCI_SIGNATURES: + ndsl_log.error(f"Unknown GPU vendor with PCI-SIG ID of {vendor_id:#X}.") return "Unknown" - return _VENDOR_PCI_SIGNAURES[int(vendor_str, 16)] + + return _VENDOR_PCI_SIGNATURES[vendor_id] @dataclasses.dataclass class GPUHardwareDefaults: """Compute defaults for common GPUs""" - vendor: str + vendor: GPUVendor block_size: list[int] = dataclasses.field(default_factory=list) compute_capability: int = -1 # Nvidia specific def get_gpu_hardware_defaults() -> GPUHardwareDefaults: - """Retrieve default values for GPU computation configuration""" + """Retrieve default values for GPU computation configuration.""" global _GPU_HARDWARE_DEFAULTS if _GPU_HARDWARE_DEFAULTS is not None: return _GPU_HARDWARE_DEFAULTS # type: ignore[unreachable] - if not cp or not cp.cuda.is_available(): + if cp is None or not cp.cuda.is_available(): ndsl_log.warning("No cupy - defaulting for GPU hardware") _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( vendor="Unknown", @@ -97,7 +103,8 @@ def get_gpu_hardware_defaults() -> GPUHardwareDefaults: ) elif vendor == "AMD": _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( - vendor=vendor, block_size=[64, 1, 1] # Default RDNA architectue is Wave64 + vendor=vendor, + block_size=[64, 1, 1], # Default RDNA architecture is Wave64 ) elif vendor == "Intel": _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults( From 7e45ff7ce27750f5d5012e3f760f595da5cf0d41 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 9 Jun 2026 15:04:43 +0200 Subject: [PATCH 046/101] update dace/gt4py to bring a typehint fix from dace --- external/dace | 2 +- external/gt4py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/dace b/external/dace index 4da9d096..c8cb49cc 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit 4da9d096ed3454ffa6dcb7b5233c281dc90696c2 +Subproject commit c8cb49cc29cc281a5c92bd84e231e31cb6dc561d diff --git a/external/gt4py b/external/gt4py index fa099e6e..cdca87a5 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit fa099e6e3a80bf0feb357698eb4fd4a13c753cbd +Subproject commit cdca87a5835062353aa9ae266b6e4d0adf9bb010 From 099cd91e48c8aba5800c30ef275ff67925040fc1 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 12 Jun 2026 08:12:20 +0200 Subject: [PATCH 047/101] gt4py update: unit-aligned dace gpu backends --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index cdca87a5..65299a79 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit cdca87a5835062353aa9ae266b6e4d0adf9bb010 +Subproject commit 65299a797510d3a621d614b5ca544709985656eb From ea873ec18effadfccdca848bd56d0ea579cac3e6 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 12 Jun 2026 11:25:56 +0200 Subject: [PATCH 048/101] unrelated: use dace import shortcut convetions in ndsl code --- .../stree/optimizations/refine_transients.py | 26 +++++++++---------- .../stree/optimizations/specialize_maps.py | 7 ++--- .../dace/stree/optimizations/statistics.py | 18 ++++++------- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index cd8e2703..89b1692a 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -1,7 +1,7 @@ import warnings import dace.data -import dace.sdfg.analysis.schedule_tree.treenodes as stree +from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl.config import Backend, BackendFramework from ndsl.dsl.dace.stree.optimizations.common import AxisIterator @@ -74,7 +74,7 @@ def _reduce_cartesian_axis_size_to_1( return True -class CollectTransientRangeAccess(stree.ScheduleNodeVisitor): +class CollectTransientRangeAccess(tn.ScheduleNodeVisitor): """Unionize all transient arrays access into a single Range.""" def __init__(self) -> None: @@ -96,12 +96,12 @@ def __str__(self) -> str: def _find_first_map_or_loop( self, - node: stree.TaskletNode, + node: tn.TaskletNode, axis: AxisIterator, ) -> dace.nodes.MapEntry | None: parent = node.parent while parent is not None: - if isinstance(parent, stree.MapScope): + if isinstance(parent, tn.MapScope): for p in parent.node.params: if p.startswith(axis.as_str()): return parent.node @@ -111,8 +111,8 @@ def _find_first_map_or_loop( def _record_access( self, - node: stree.TaskletNode, - memlets: stree.MemletSet, + node: tn.TaskletNode, + memlets: tn.MemletSet, recording_set: dict[str, dace.subsets.Range | None], ) -> None: for memlet in memlets: @@ -145,11 +145,11 @@ def _record_access( AxisIterator._K.as_cartesian_index() ].add(map_entry) - def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + def visit_TaskletNode(self, node: tn.TaskletNode) -> None: self._record_access(node, node.input_memlets(), self.transients_range_writes) self._record_access(node, node.output_memlets(), self.transients_range_reads) - def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: self.containers = node.containers for name, data in self.containers.items(): if data.transient and isinstance(data, dace.data.Array): @@ -161,7 +161,7 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: self.visit(child) -class RebuildMemletsFromContainers(stree.ScheduleNodeVisitor): +class RebuildMemletsFromContainers(tn.ScheduleNodeVisitor): """Rebuild memlets from containers to ensure they are scope to the right size.""" def __init__(self, refined_arrays: set[str]) -> None: @@ -170,7 +170,7 @@ def __init__(self, refined_arrays: set[str]) -> None: def __str__(self) -> str: return "RefineTransientAxis" - def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + def visit_TaskletNode(self, node: tn.TaskletNode) -> None: for memlet in [*node.output_memlets(), *node.input_memlets()]: if memlet.data not in self._refined_arrays: continue @@ -187,13 +187,13 @@ def visit_TaskletNode(self, node: stree.TaskletNode) -> None: if array.shape[index] == 1: memlet.subset.ranges[index] = (0, 0, 1) - def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: self.containers = node.containers for child in node.children: self.visit(child) -class CartesianRefineTransients(stree.ScheduleNodeTransformer): +class CartesianRefineTransients(tn.ScheduleNodeTransformer): """Refine (reduce dimensionality) of transients based on their true use in the cartesian dimensions. @@ -254,7 +254,7 @@ def __init__(self, backend: Backend) -> None: def __str__(self) -> str: return "CartesianRefineTransients" - def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: collect_map = CollectTransientRangeAccess() collect_map.visit(node) diff --git a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py index e2409e1a..9f7e4be4 100644 --- a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py +++ b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py @@ -1,15 +1,16 @@ -import dace.sdfg.analysis.schedule_tree.treenodes as stree import dace.subsets as sbs +from dace.sdfg.analysis.schedule_tree import treenodes as tn -class SpecializeCartesianMaps(stree.ScheduleNodeVisitor): +class SpecializeCartesianMaps(tn.ScheduleNodeVisitor): def __init__(self, mappings: dict[str, int]) -> None: super().__init__() self._mappings = mappings - def visit_MapScope(self, node: stree.MapScope) -> None: + def visit_MapScope(self, node: tn.MapScope) -> None: dims = [] for p in node.node.map.params: + assert isinstance(p, str) if p == "__i": dims.append((0, self._mappings["__I"], 1)) if p == "__j": diff --git a/ndsl/dsl/dace/stree/optimizations/statistics.py b/ndsl/dsl/dace/stree/optimizations/statistics.py index 6e5fe3af..6fc927f9 100644 --- a/ndsl/dsl/dace/stree/optimizations/statistics.py +++ b/ndsl/dsl/dace/stree/optimizations/statistics.py @@ -1,7 +1,7 @@ import dataclasses import dace -import dace.sdfg.analysis.schedule_tree.treenodes as stree +from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl.dsl.dace.stree.optimizations.common import ( AxisIterator, @@ -10,20 +10,20 @@ ) -class CountCartesianLoops(stree.ScheduleNodeVisitor): +class CountCartesianLoops(tn.ScheduleNodeVisitor): def __init__(self) -> None: super().__init__() self._maps = [0, 0, 0] self._fors = [0, 0, 0] - def visit_MapScope(self, node: stree.MapScope) -> None: + def visit_MapScope(self, node: tn.MapScope) -> None: for axis in AxisIterator: if is_axis_map(node, axis): self._maps[axis.as_cartesian_index()] += 1 self.visit(node.children) - def visit_ForScope(self, node: stree.ForScope) -> None: + def visit_ForScope(self, node: tn.ForScope) -> None: for axis in AxisIterator: if is_axis_for(node, axis): self._fors[axis.as_cartesian_index()] += 1 @@ -31,12 +31,12 @@ def visit_ForScope(self, node: stree.ForScope) -> None: self.visit(node.children) -class CountTransient(stree.ScheduleNodeVisitor): +class CountTransient(tn.ScheduleNodeVisitor): def __init__(self) -> None: super().__init__() self._counts = [0, 0, 0, 0, 0] - def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None: + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: for data in node.containers.values(): non_atomic_dims_count = sum(1 for x in data.shape if x != 1) if isinstance(data, dace.data.Array) and data.transient: @@ -72,7 +72,7 @@ def __init__(self) -> None: def _record( self, record: Record, - tree_root: stree.ScheduleTreeRoot, + tree_root: tn.ScheduleTreeRoot, ) -> None: """Record the state of a tree""" c = CountCartesianLoops() @@ -84,11 +84,11 @@ def _record( c.visit(tree_root) record.transients = c._counts - def original(self, tree_root: stree.ScheduleTreeRoot) -> None: + def original(self, tree_root: tn.ScheduleTreeRoot) -> None: """Record the original state of the tree, before optimization""" self._record(self._original_record, tree_root) - def optimized(self, tree_root: stree.ScheduleTreeRoot) -> None: + def optimized(self, tree_root: tn.ScheduleTreeRoot) -> None: """Record the state of the tree after optimization""" self._record(self._optimized_record, tree_root) From a94d5a0232314daece9ca0064dc23d468f00068e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 12 Jun 2026 15:03:33 +0200 Subject: [PATCH 049/101] Transform to kernelize maps on GPU --- ndsl/dsl/dace/stree/optimizations/__init__.py | 6 +- .../stree/optimizations/common/__init__.py | 3 +- .../dace/stree/optimizations/common/loops.py | 11 ++ .../stree/optimizations/kernelize_maps.py | 80 ++++++++ ndsl/dsl/dace/stree/pipeline.py | 2 + tests/dsl/dace/stree/common/test_loops.py | 27 +++ .../optimizations/test_kernelize_maps.py | 181 ++++++++++++++++++ 7 files changed, 307 insertions(+), 3 deletions(-) create mode 100644 ndsl/dsl/dace/stree/optimizations/kernelize_maps.py create mode 100644 tests/dsl/dace/stree/optimizations/test_kernelize_maps.py diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index 8cd77f55..2c1d7c3d 100644 --- a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -1,6 +1,7 @@ from .axis_merge import CartesianAxisMerge from .cartesian_merge import CartesianMerge from .clean_tree import CleanUpScheduleTree +from .kernelize_maps import KernelizeMaps from .offgrid_conditionals import ( ExtractOffgridConditionals, InlineOffgridConditionals, @@ -13,10 +14,11 @@ __all__ = [ "CartesianAxisMerge", "CartesianMerge", - "CartesianRefineTransients", "CleanUpScheduleTree", - "InlineVertical2DWrite", + "KernelizeMaps", "ExtractOffgridConditionals", "InlineOffgridConditionals", "MergeConditionals", + "CartesianRefineTransients", + "InlineVertical2DWrite", ] diff --git a/ndsl/dsl/dace/stree/optimizations/common/__init__.py b/ndsl/dsl/dace/stree/optimizations/common/__init__.py index c76887fb..2e342912 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/common/__init__.py @@ -1,5 +1,5 @@ from .memlet import AxisIterator, no_data_dependencies_on_cartesian_axis # isort: skip -from .loops import is_axis_for, is_axis_map +from .loops import is_axis_for, is_axis_map, is_cartesian_axis from .topology import ( detect_cycle, get_next_node, @@ -14,6 +14,7 @@ "AxisIterator", "no_data_dependencies_on_cartesian_axis", "is_axis_map", + "is_cartesian_axis", "is_axis_for", "get_next_node", "last_node", diff --git a/ndsl/dsl/dace/stree/optimizations/common/loops.py b/ndsl/dsl/dace/stree/optimizations/common/loops.py index 5d414915..9433eb76 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/loops.py +++ b/ndsl/dsl/dace/stree/optimizations/common/loops.py @@ -12,6 +12,17 @@ def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: return axis.is_equal(map_parameter[0]) +def is_cartesian_axis(node: tn.MapScope | tn.ForScope) -> bool: + """Returns true if the given node is a map over any cartesian axis.""" + for axis in AxisIterator: + if (isinstance(node, tn.MapScope) and is_axis_map(node, axis)) or ( + isinstance(node, tn.ForScope) and is_axis_for(node, axis) + ): + return True + + return False + + def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool: """Returns true if node is a For over the given axis.""" return axis.is_equal(node.loop.loop_variable) diff --git a/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py b/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py new file mode 100644 index 00000000..4beb766a --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py @@ -0,0 +1,80 @@ +from copy import deepcopy + +from dace.sdfg.analysis.schedule_tree import treenodes as tn + +from ndsl import Backend +from ndsl.config import BackendLoopOrder, BackendTargetDevice +from ndsl.dsl.dace.stree.optimizations.common import ( + AxisIterator, + is_axis_map, + is_cartesian_axis, +) + + +class _KernelizeMap(tn.ScheduleNodeTransformer): + def __init__(self, axis: AxisIterator) -> None: + super().__init__() + self._axis = axis + + def __str__(self) -> str: + return f"KernelizeMap_{self._axis}" + + def _count_cartesian_children(self, node: tn.ScheduleTreeScope) -> int: + cartesian_children = 0 + for child in node.children: + if isinstance(child, (tn.MapScope, tn.ForScope)) and is_cartesian_axis( + child + ): + cartesian_children += 1 + return cartesian_children + + def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope | list[tn.MapScope]: + # if this is a map on a cartesian axis + # and the children contain two or more cartesian axes + if is_axis_map(node, self._axis) and self._count_cartesian_children(node) > 1: + kernelized_maps: list[tn.MapScope] = [] + current_children: list[tn.ScheduleTreeNode] = [] + + for child in node.children: + current_children.append(child) + if isinstance(child, (tn.MapScope, tn.ForScope)) and is_cartesian_axis( + child + ): + kernelized_maps.append( + tn.MapScope( + node=deepcopy(node.node), + children=[child for child in current_children], + parent=node.parent, + state=node.state, + ) + ) + current_children = [] + return kernelized_maps + + return self.generic_visit(node) + + +class KernelizeMaps(tn.ScheduleNodeVisitor): + def __init__(self, backend: Backend) -> None: + super().__init__() + self._backend = backend + + if self._backend.device != BackendTargetDevice.GPU: + raise ValueError( + "The transformation `KernelizeMaps` is only intended to run on GPUs." + ) + + def __str__(self) -> str: + return "KernelizeMaps" + + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: + for axis in self._axis_order(): + _KernelizeMap(axis).visit(node) + + def _axis_order(self) -> list[AxisIterator]: + if self._backend.loop_order == BackendLoopOrder.IJK: + return [AxisIterator._J, AxisIterator._I] + + raise NotImplementedError( + f"KernelizeMaps is not configured for loop order {self._backend.loop_order}." + ) diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index f44a399e..8a74849c 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -7,6 +7,7 @@ CartesianMerge, CartesianRefineTransients, CleanUpScheduleTree, + KernelizeMaps, ) from ndsl.dsl.dace.stree.optimizations.statistics import TreeOptimizationStatistics from ndsl.logging import ndsl_log_on_rank_0 @@ -94,6 +95,7 @@ def __init__( # TODO: Is it safe? Deactivate for now # InlineVertical2DWrite(), CartesianMerge(backend), + KernelizeMaps(backend), # 🐞 Transient refine can't be used # because of bugs transients showing in code generation # CartesianRefineTransients(backend), diff --git a/tests/dsl/dace/stree/common/test_loops.py b/tests/dsl/dace/stree/common/test_loops.py index 3020f551..a2c12d76 100644 --- a/tests/dsl/dace/stree/common/test_loops.py +++ b/tests/dsl/dace/stree/common/test_loops.py @@ -6,6 +6,7 @@ AxisIterator, is_axis_for, is_axis_map, + is_cartesian_axis, ) @@ -49,6 +50,32 @@ def test_is_axis_map_wrong_iterator() -> None: assert not is_axis_map(node, AxisIterator._J) +def test_is_cartesian_axis() -> None: + map_i = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", ["__i"], [(0, 3, 1)])), children=[] + ) + assert is_cartesian_axis(map_i) + + map_j = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_j", ["__j"], [(0, 3, 1)])), children=[] + ) + assert is_cartesian_axis(map_j) + + map_k = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_k", ["__k_1234"], [(0, 3, 1)])), children=[] + ) + assert is_cartesian_axis(map_k) + + for_k = tn.ForScope(loop=LoopRegion("for_k", loop_var="__k_1234"), children=[]) + assert is_cartesian_axis(for_k) + + map_non_cartesian = tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_other_i", ["__i0"], [(0, 3, 1)])), + children=[], + ) + assert not is_cartesian_axis(map_non_cartesian) + + def test_is_axis_for_k() -> None: node = tn.ForScope(loop=LoopRegion("for_k", loop_var="__k_1234"), children=[]) assert is_axis_for(node, AxisIterator._K) diff --git a/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py new file mode 100644 index 00000000..e16790e1 --- /dev/null +++ b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py @@ -0,0 +1,181 @@ +import pytest +from dace import nodes +from dace.sdfg.state import LoopRegion + +from ndsl import Backend, NDSLRuntime, orchestrate +from ndsl.boilerplate import get_factories_single_tile +from ndsl.constants import I_DIM, J_DIM, K_DIM +from ndsl.dsl.gt4py import BACKWARD, FORWARD, PARALLEL, computation, interval +from ndsl.dsl.stencil import StencilFactory +from ndsl.dsl.typing import FloatField +from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree.optimizations import Factories + + +def stencil_kernelize(in_field: FloatField, out_field: FloatField) -> None: # type: ignore + with computation(PARALLEL), interval(...): + value = in_field * 2 + tmp = value + + with computation(FORWARD), interval(0, -1): + tmp = 0.5 * (tmp + tmp[0, 0, 1]) + + with computation(PARALLEL), interval(...): + out_field = tmp + + +def stencil_only_serial_noop( + in_field: FloatField, out_field: FloatField +) -> None: # type:ignore + with computation(FORWARD), interval(...): + tmp = in_field + + with computation(BACKWARD), interval(...): + out_field = tmp + + +def stencil_only_parallel_noop( + in_field: FloatField, out_field: FloatField +) -> None: # type:ignore + with computation(PARALLEL), interval(0, 2): + out_field = in_field + + with computation(PARALLEL), interval(-2, None): + out_field = in_field + 1 + + +class OrchestratedCode(NDSLRuntime): + def __init__(self, stencil_factory: StencilFactory) -> None: + super().__init__(stencil_factory) + + methods_to_orchestrate = [ + "kernelize_k", + "only_serial_noop", + "only_parallel_noop", + ] + for method in methods_to_orchestrate: + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + method_to_orchestrate=method, + ) + + self._stencil_kernelize_k = stencil_factory.from_dims_halo( + func=stencil_kernelize, + compute_dims=(I_DIM, J_DIM, K_DIM), + ) + self._stencil_only_serial_noop = stencil_factory.from_dims_halo( + func=stencil_only_serial_noop, + compute_dims=(I_DIM, J_DIM, K_DIM), + ) + self._stencil_only_parallel_noop = stencil_factory.from_dims_halo( + func=stencil_only_parallel_noop, + compute_dims=(I_DIM, J_DIM, K_DIM), + ) + + def kernelize_k(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._stencil_kernelize_k(in_field, out_field) + + def only_serial_noop(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._stencil_only_serial_noop(in_field, out_field) + + def only_parallel_noop(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore + self._stencil_only_parallel_noop(in_field, out_field) + + +class TestKernelizeMaps: + @pytest.fixture( + params=[ + "orch:dace:cpu:IJK", + pytest.param("orch:dace:gpu:IJK", marks=pytest.mark.gpu), + ] + ) + def factories(self, request: pytest.FixtureRequest) -> Factories: + domain = (3, 4, 5) + return get_factories_single_tile( + nx=domain[0], + ny=domain[1], + nz=domain[2], + nhalo=0, + backend=Backend(request.param), + ) + + def test_kernelize_k_gpu(self, factories: Factories) -> None: + stencil_factory, quantity_factory = factories + code = OrchestratedCode(stencil_factory) + + in_field = quantity_factory.ones((I_DIM, J_DIM, K_DIM), "") + out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), "") + + with StreeOptimization(): + code.kernelize_k(in_field, out_field) + + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + + if stencil_factory.backend.is_gpu_backend(): + # check for kernelization + all_maps = [ + node + for node, _ in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(node, nodes.MapEntry) + ] + + ij_maps = 0 + ijk_maps = 0 + for map_entry in all_maps: + if map_entry.map.params == ["__i", "__j"]: + ij_maps += 1 + elif len(map_entry.map.params) == 3: + params = map_entry.map.params + k_param = params[2] + if ( + params[0:2] == ["__i", "__j"] + and isinstance(k_param, str) + and k_param.startswith("__k") + ): + ijk_maps += 1 + + # expect two IJK-maps and one IJ-map + assert ij_maps == 1 + assert ijk_maps == 2 + assert len(all_maps) == 3 + + all_loop_regions = [ + node + for node, _ in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(node, LoopRegion) + ] + # expect one k-loop is preserved + assert len(all_loop_regions) == 1 + assert all_loop_regions[0].loop_variable.startswith("__k") + else: + # check that we keep IJ loops merged + all_maps = [ + node + for node, _ in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(node, nodes.MapEntry) + ] + + ij_maps = 0 + k_maps = 0 + for map_entry in all_maps: + if map_entry.map.params == ["__i", "__j"]: + ij_maps += 1 + elif len(map_entry.map.params) == 1: + param = map_entry.map.params[0] + if isinstance(param, str) and param.startswith("__k"): + k_maps += 1 + + # expect one IJ-map and two K-maps + assert ij_maps == 1 + assert k_maps == 2 + assert len(all_maps) == 3 + + all_loop_regions = [ + node + for node, _ in precompiled_sdfg.sdfg.all_nodes_recursive() + if isinstance(node, LoopRegion) + ] + # expect one k-loop is preserved + assert len(all_loop_regions) == 1 + assert all_loop_regions[0].loop_variable.startswith("__k") From 42e16de0d49f39bc3839ec956f0906ebf583d2fb Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 12 Jun 2026 17:35:05 +0200 Subject: [PATCH 050/101] fix some red squiggly lines in vscode :) --- ndsl/dsl/dace/stree/optimizations/__init__.py | 2 ++ ndsl/dsl/dace/stree/optimizations/axis_merge.py | 7 +++++-- ndsl/dsl/dace/stree/optimizations/common/loops.py | 7 ++++--- ndsl/dsl/dace/stree/optimizations/refine_transients.py | 3 ++- ndsl/dsl/dace/stree/pipeline.py | 4 +++- 5 files changed, 16 insertions(+), 7 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index 2c1d7c3d..b1f69aa1 100644 --- a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -9,6 +9,7 @@ ) from .refine_transients import CartesianRefineTransients from .remove_loops import InlineVertical2DWrite +from .statistics import TreeOptimizationStatistics __all__ = [ @@ -21,4 +22,5 @@ "MergeConditionals", "CartesianRefineTransients", "InlineVertical2DWrite", + "TreeOptimizationStatistics", ] diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 0c5a476f..3a0ea1f6 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -222,7 +222,8 @@ def _map_overcompute_merge( ) # - then, guard children to only run in their respective range - axis_as_str = the_map.node.params[0] + axis_as_str = the_map.node.map.params[0] + assert isinstance(axis_as_str, str) first_map = InsertOvercomputationGuard( axis_as_str, merged_range=merged_range, original_range=first_range ).visit(the_map) @@ -231,7 +232,9 @@ def _map_overcompute_merge( merged_range=merged_range, original_range=second_range, ).visit(next_node) - merged_children: list[tn.MapScope] = [ + assert isinstance(first_map, tn.MapScope) + assert isinstance(second_map, tn.MapScope) + merged_children: list[tn.ScheduleTreeNode] = [ *first_map.children, *second_map.children, ] diff --git a/ndsl/dsl/dace/stree/optimizations/common/loops.py b/ndsl/dsl/dace/stree/optimizations/common/loops.py index 9433eb76..1f057954 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/loops.py +++ b/ndsl/dsl/dace/stree/optimizations/common/loops.py @@ -5,11 +5,12 @@ def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool: """Returns true if node is a Map over the given axis.""" - map_parameter = node.node.map.params - if len(map_parameter) != 1: + if len(node.node.map.params) != 1: return False - return axis.is_equal(map_parameter[0]) + param = node.node.map.params[0] + assert isinstance(param, str) + return axis.is_equal(param) def is_cartesian_axis(node: tn.MapScope | tn.ForScope) -> bool: diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 6c83d04d..39d213fc 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -102,7 +102,8 @@ def _find_first_map_or_loop( parent = node.parent while parent is not None: if isinstance(parent, tn.MapScope): - for p in parent.node.params: + for p in parent.node.map.params: + assert isinstance(p, str) if p.startswith(axis.as_str()): return parent.node diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 37247324..3b31296c 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -8,8 +8,8 @@ CartesianRefineTransients, CleanUpScheduleTree, KernelizeMaps, + TreeOptimizationStatistics, ) -from ndsl.dsl.dace.stree.optimizations.statistics import TreeOptimizationStatistics class StreePipeline: @@ -40,6 +40,7 @@ def run( tree_stats.original(stree) for i, p in enumerate(self.passes): + path: Path | None = None if verbose: path = self.cache_directory / f"pass{i}_{p}.txt" ndsl_log_on_rank_0.info(f"[Stree OPT] {p} (saving {path} after)") @@ -47,6 +48,7 @@ def run( p.visit(stree) if verbose: + assert path is not None with open(path, "w+") as f: f.write(stree.as_string()) From 2c649925564c22c9bd184a4d73039446f85cd538 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Sat, 13 Jun 2026 09:39:27 +0200 Subject: [PATCH 051/101] Don't raise kernelize_maps in KJI layout --- ndsl/dsl/dace/stree/optimizations/kernelize_maps.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py b/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py index 4beb766a..11135ef6 100644 --- a/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py +++ b/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py @@ -74,6 +74,8 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: def _axis_order(self) -> list[AxisIterator]: if self._backend.loop_order == BackendLoopOrder.IJK: return [AxisIterator._J, AxisIterator._I] + if self._backend.loop_order == BackendLoopOrder.KJI: + return [AxisIterator._J, AxisIterator._K] raise NotImplementedError( f"KernelizeMaps is not configured for loop order {self._backend.loop_order}." From 546ce16da59302ab34646f8bab86499ca15c49f7 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 15 Jun 2026 12:55:45 -0400 Subject: [PATCH 052/101] Move sdfg save on verbose into a DaceProgress --- ndsl/dsl/dace/orchestration.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 3bfebab6..a3c960df 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -186,11 +186,11 @@ def _build_sdfg( repl_dict[sym] = val my_sdfg.replace_dict(repl_dict) - if config.verbose_orchestration: - sdfg.save( - os.path.abspath(f"{sdfg.build_folder}/00-combined_from_stencils.sdfgz"), - compress=True, - ) + if config.verbose_orchestration: + sdfg.save( + os.path.abspath(f"{sdfg.build_folder}/00-combined_from_stencils.sdfgz"), + compress=True, + ) with DaCeProgress(config, "Simplify (1)"): _simplify(sdfg) From 145ee101e0a6cf76531d58936a8a880f37b91034 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 15 Jun 2026 12:55:58 -0400 Subject: [PATCH 053/101] Fix init of StreePipeline --- ndsl/dsl/dace/stree/pipeline.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 3b31296c..ad02dfbe 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -77,7 +77,7 @@ def __init__( CartesianRefineTransients(backend), ] super().__init__( - passes=passes, + passes=passes if passes is not None else [], cache_directory=cache_directory, ) @@ -101,10 +101,6 @@ def __init__( # because of bugs transients showing in code generation # CartesianRefineTransients(backend), ] - super().__init__( - passes=passes, - cache_directory=cache_directory, - ) super().__init__( passes=passes if passes is not None else [], cache_directory=cache_directory, From c997988179049aa3919f562d0dbf7a148e0ce22c Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 15 Jun 2026 12:56:38 -0400 Subject: [PATCH 054/101] Lint --- ndsl/dsl/dace/orchestration.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index a3c960df..71c020e3 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -188,7 +188,9 @@ def _build_sdfg( if config.verbose_orchestration: sdfg.save( - os.path.abspath(f"{sdfg.build_folder}/00-combined_from_stencils.sdfgz"), + os.path.abspath( + f"{sdfg.build_folder}/00-combined_from_stencils.sdfgz" + ), compress=True, ) From a390e532f30562c6eeda411ef81df50f5957be3b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 17 Jun 2026 18:29:47 +0200 Subject: [PATCH 055/101] allow map collapse with different schedules, unique loop region names --- external/dace | 2 +- external/gt4py | 2 +- ndsl/dsl/dace/orchestration.py | 10 +++++++++- ndsl/dsl/dace/stree/optimizations/clean_tree.py | 15 +++++---------- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/external/dace b/external/dace index c8cb49cc..c6bc57a3 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit c8cb49cc29cc281a5c92bd84e231e31cb6dc561d +Subproject commit c6bc57a3f23d2427da3cb23ece13255de4a9af47 diff --git a/external/gt4py b/external/gt4py index 65299a79..ddf3cb33 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 65299a797510d3a621d614b5ca544709985656eb +Subproject commit ddf3cb337a7d21545825feb416627e44a1ab1876 diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 71c020e3..854b5ffa 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -187,6 +187,7 @@ def _build_sdfg( my_sdfg.replace_dict(repl_dict) if config.verbose_orchestration: + ndsl_log.debug("saving 00-combined_from_stencils.sdfgz") sdfg.save( os.path.abspath( f"{sdfg.build_folder}/00-combined_from_stencils.sdfgz" @@ -197,6 +198,7 @@ def _build_sdfg( with DaCeProgress(config, "Simplify (1)"): _simplify(sdfg) if config.verbose_orchestration: + ndsl_log.debug("saving 01-simplify.sdfgz") sdfg.save( os.path.abspath(f"{sdfg.build_folder}/01-simplify_1.sdfgz"), compress=True, @@ -220,6 +222,7 @@ def _build_sdfg( ) stree = sdfg.as_schedule_tree() if config.verbose_orchestration: + ndsl_log.debug("saving 02-pre_opt.stree.txt") with open( os.path.abspath(f"{sdfg.build_folder}/02-pre_opt.stree.txt"), "w+", @@ -235,6 +238,7 @@ def _build_sdfg( ) pipeline.run(stree, verbose=config.verbose_schedule_tree_optimizations) if config.verbose_orchestration: + ndsl_log.debug("saving 03-post_opt.stree.txt") with open( os.path.abspath(f"{sdfg.build_folder}/03-post_opt.stree.txt"), "w+", @@ -244,6 +248,7 @@ def _build_sdfg( with DaCeProgress(config, "Schedule Tree: go back to SDFG"): sdfg = _tree_as_sdfg(stree) if config.verbose_orchestration: + ndsl_log.debug("saving 04-from_stree.sdfgz") sdfg.save( os.path.abspath(f"{sdfg.build_folder}/04-from_stree.sdfgz"), compress=True, @@ -252,7 +257,8 @@ def _build_sdfg( # We want all maps properly collapse to make sure the codegen will see nD parallel # axis as a single kernelizable map with DaCeProgress(config, "Collapse maps"): - sdfg.apply_transformations_repeated(MapCollapse) + # allow `MapCollapse` to collapse maps with different schedules + sdfg.apply_transformations_repeated(MapCollapse, permissive=True) with DaCeProgress(config, "Make transient persistents"): # Make the transients array persistents @@ -286,6 +292,7 @@ def _build_sdfg( # Apply common GPU transforms (includes a simplify) sdfg.apply_gpu_transformations() if config.verbose_orchestration: + ndsl_log.debug("saving 05-apply_gpu_xforms.sdfgz") sdfg.save( os.path.abspath( f"{sdfg.build_folder}/05-apply_gpu_xforms.sdfgz" @@ -296,6 +303,7 @@ def _build_sdfg( with DaCeProgress(config, "Simplify (2)"): _simplify(sdfg) if config.verbose_orchestration: + ndsl_log.debug("saving 05-simplify_2.sdfgz") sdfg.save( os.path.abspath(f"{sdfg.build_folder}/05-simplify_2.sdfgz"), compress=True, diff --git a/ndsl/dsl/dace/stree/optimizations/clean_tree.py b/ndsl/dsl/dace/stree/optimizations/clean_tree.py index 7d3b5558..93798f42 100644 --- a/ndsl/dsl/dace/stree/optimizations/clean_tree.py +++ b/ndsl/dsl/dace/stree/optimizations/clean_tree.py @@ -26,32 +26,28 @@ def _remove_state_boundaries_from_children( def visit_WhileScope(self, node: tn.WhileScope) -> tn.WhileScope: self._remove_state_boundaries_from_children(node) - for child in node.children: - self.visit(child) + self.generic_visit(node) return node def visit_ForScope(self, node: tn.ForScope) -> tn.ForScope: self._remove_state_boundaries_from_children(node) - for child in node.children: - self.visit(child) + self.generic_visit(node) return node def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope: self._remove_state_boundaries_from_children(node) - for child in node.children: - self.visit(child) + self.generic_visit(node) return node def visit_IfScope(self, node: tn.IfScope) -> tn.IfScope: self._remove_state_boundaries_from_children(node) - for child in node.children: - self.visit(child) + self.generic_visit(node) return node @@ -60,8 +56,7 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> tn.ScheduleTreeRo self._remove_state_boundaries_from_children(node) - for child in node.children: - self.visit(child) + self.generic_visit(node) ndsl_log.debug(f"{self}: removed {self._removed_state_boundaries} nodes") return node From 43a2903fe4ffdb8bf6a6dbe78d111663a00ddad7 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 17 Jun 2026 17:07:39 -0400 Subject: [PATCH 056/101] Detect write-after-write where offset/index differs --- .../dace/stree/optimizations/axis_merge.py | 7 +-- .../dace/stree/optimizations/common/memlet.py | 62 +++++++------------ 2 files changed, 25 insertions(+), 44 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 3a0ea1f6..5671059b 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -34,10 +34,9 @@ def _both_same_single_axis_maps( def _can_merge_axis_maps( first: tn.MapScope, second: tn.MapScope, axis: AxisIterator ) -> bool: - if _both_same_single_axis_maps(first, second, axis): - if no_data_dependencies_on_cartesian_axis(first, second, axis): - return True - return False + return _both_same_single_axis_maps( + first, second, axis + ) and no_data_dependencies_on_cartesian_axis(first, second, axis) class InsertOvercomputationGuard(tn.ScheduleNodeTransformer): diff --git a/ndsl/dsl/dace/stree/optimizations/common/memlet.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py index 8c745a32..426f4bcc 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/memlet.py +++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py @@ -39,13 +39,16 @@ def no_data_dependencies_on_cartesian_axis( second: stree.MapScope, axis: AxisIterator, ) -> bool: - """Check for read after write. Allow when indexation on the axis - is not offset.""" + """Check for read after write and write after write with different offsets.""" write_collector = MemletCollector(collect_reads=False) write_collector.visit(first) + other_writes = MemletCollector(collect_reads=False) + other_writes.visit(second) read_collector = MemletCollector(collect_writes=False) read_collector.visit(second) + write_index = {} + for write in write_collector.out_memlets: # TODO: this can be optimized to allow non-overlapping intervals and such in the future @@ -58,6 +61,23 @@ def no_data_dependencies_on_cartesian_axis( previous_axis_index = normalize_cartesian_indexation( write.subset[axis_index][0], axis ) + + # Write-after-write with an offset case + for other_write in other_writes.out_memlets: + if write.data == other_write.data: + if previous_axis_index != normalize_cartesian_indexation( + other_write.subset[axis_index][0], axis + ): + ndsl_log.debug( + f"[{axis.name} Merge] Found write after write conflict " + f"for {write.data} " + f"w/ different offset to {axis.name} (" + f"first write at {previous_axis_index}, " + f"second write at {other_write.subset[axis_index][0]})" + ) + return False + + # Read-after-write with an offset case for read in read_collector.in_memlets: if write.data == read.data: if previous_axis_index != normalize_cartesian_indexation( @@ -71,45 +91,7 @@ def no_data_dependencies_on_cartesian_axis( f"read at {read.subset[axis_index][0]})" ) return False - return True - -def no_data_dependencies( - first: stree.MapScope, - second: stree.MapScope, - restrict_check_to_k: bool = False, -) -> bool: - write_collector = MemletCollector(collect_reads=False) - write_collector.visit(first) - read_collector = MemletCollector(collect_writes=False) - read_collector.visit(second) - for write in write_collector.out_memlets: - # Make sure we don't have read after write conditions. - # TODO: this can be optimized to allow non-overlapping intervals and such in the future - if restrict_check_to_k: - if write.subset.dims() < 3: - # Case of 2D write - no K dependency - continue - - previous_k_index = write.subset[2][0] - for read in read_collector.in_memlets: - if write.data == read.data: - if previous_k_index != read.subset[2][0]: - print( - "[K Merge] Found read after write conflict " - f"for {write.data} " - "w/ different offset to K (" - f"write at {write.subset[2][0]}, " - f"read at {read.subset[2][0]})" - ) - return False - - else: - if write.data in [read.data for read in read_collector.in_memlets]: - print( - f"[All dims merge] Found potential read after write conflict for {write.data}" - ) - return False return True From 5124545a003e5f2ae60f744d994d5eb7a54c02e2 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 17 Jun 2026 17:08:10 -0400 Subject: [PATCH 057/101] Lint --- ndsl/dsl/dace/stree/optimizations/common/memlet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ndsl/dsl/dace/stree/optimizations/common/memlet.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py index 426f4bcc..266be19d 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/memlet.py +++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py @@ -47,7 +47,6 @@ def no_data_dependencies_on_cartesian_axis( other_writes.visit(second) read_collector = MemletCollector(collect_writes=False) read_collector.visit(second) - write_index = {} for write in write_collector.out_memlets: # TODO: this can be optimized to allow non-overlapping intervals and such in the future From bd14cc6e51b6b9d43a56000922d8f4fb5f593da6 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 18 Jun 2026 14:12:49 -0400 Subject: [PATCH 058/101] Remove `lineinfo` in `DaCe` --- ndsl/dsl/dace/dace_config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 125d6008..b722df77 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -356,6 +356,9 @@ def __init__( value="c", ) + # Debug lineinfo is incorrect anyway for the stencils + dace.config.Config.set("compiler", "lineinfo", value="none") + # Attempt to kill the dace.conf to avoid confusion dace_conf_to_kill = dace.config.Config.cfg_filename() if dace_conf_to_kill is not None: From c46c12efb5a01d509da4b2de9939216346c3ea51 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 18 Jun 2026 14:26:33 -0400 Subject: [PATCH 059/101] Mvoe `gt4py` to `tmp_June26_01` to bring in the `better_parallel_kernel` branch --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index ddf3cb33..e256ec5f 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit ddf3cb337a7d21545825feb416627e44a1ab1876 +Subproject commit e256ec5f2ae79e6240f7b5a4a29c9647877c12f4 From 741994f1df4f584205afac6cf7783e8f22a9f993 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 18 Jun 2026 14:27:21 -0400 Subject: [PATCH 060/101] dace: fix for read-after-write in input_memlets --- external/dace | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/dace b/external/dace index c6bc57a3..d186d86d 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit c6bc57a3f23d2427da3cb23ece13255de4a9af47 +Subproject commit d186d86dea15f7852545dcde0c4f5b9e6d4f072b From 63c57736f15a98a2db82aa6ba9e980f2faf5b2fb Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 19 Jun 2026 14:04:41 +0200 Subject: [PATCH 061/101] feat: enable/disable stree via dace_config --- ndsl/dsl/dace/dace_config.py | 30 ++++++++ ndsl/dsl/dace/orchestration.py | 11 +-- tests/dsl/dace/stree/__init__.py | 4 +- .../optimizations/test_kernelize_maps.py | 9 ++- .../dace/stree/optimizations/test_merge.py | 56 ++++++-------- .../test_offgrid_conditionals.py | 18 ++--- .../dace/stree/optimizations/test_pipeline.py | 5 +- .../stree/optimizations/test_remove_loops.py | 16 ++-- .../optimizations/test_transient_refine.py | 76 +++++++++---------- tests/dsl/dace/stree/sdfg_stree_tools.py | 5 +- tests/dsl/test_dace_config.py | 36 +++++++++ 11 files changed, 159 insertions(+), 107 deletions(-) diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index b722df77..3f3f70fc 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -2,6 +2,7 @@ import enum import os +import warnings from pathlib import Path from typing import TYPE_CHECKING, Any, Self @@ -215,6 +216,11 @@ def __init__( os.getenv("NDSL_VERBOSE_SCHEDULE_TREE_OPTIMIZATIONS", "False") == "True" ) + # Schedule tree optimization + self._schedule_tree_optimization = ( + os.getenv("NDSL_STREE_OPT", "False") == "True" + ) + # We hijack the optimization level of GT4Py because we don't # have the configuration at NDSL level, but we do use the GT4Py # level @@ -405,6 +411,30 @@ def get_orchestrate(self) -> DaCeOrchestration: def get_sync_debug(self) -> bool: return dace.config.Config.get_bool("compiler", "cuda", "syncdebug") + def enable_schedule_tree(self) -> None: + """Enables optimizations based on the schedule tree.""" + if not self.is_dace_orchestrated(): + warnings.warn( + "Enabling schedule tree optimization on a non-orchestrated backend has no effect.", + UserWarning, + stacklevel=2, + ) + self._schedule_tree_optimization = True + + def disable_schedule_tree(self) -> None: + """Disables optimizations based on the schedule tree.""" + if not self.is_dace_orchestrated(): + warnings.warn( + "Disabling schedule tree optimization on a non-orchestrated backend has no effect.", + UserWarning, + stacklevel=2, + ) + self._schedule_tree_optimization = False + + def schedule_tree_enabled(self) -> bool: + """Returns true iff schedule tree optimizations are enabled.""" + return self._schedule_tree_optimization + def as_dict(self) -> dict[str, Any]: return { "_orchestrate": str(self._orchestrate.name), diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 854b5ffa..97b96836 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -51,11 +51,6 @@ from ndsl.quantity import Quantity, State -_INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = ( - os.environ.get("NDSL_STREE_OPT", "False") == "True" -) -"""INTERNAL: Developer flag to turn the untested schedule tree roundtrip optimizer.""" - _INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES: list[tn.ScheduleNodeVisitor] | None = None @@ -204,7 +199,7 @@ def _build_sdfg( compress=True, ) - if _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: + if config.schedule_tree_enabled(): # Here be 🐉 - but tests exists in test_optimization.py with DaCeProgress(config, "Schedule Tree: generate from SDFG"): # Break all loops into uni-dimensional loops to simplify optimizations @@ -599,7 +594,7 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] **kwargs, ) # Label the code (this is the topmost code) - if sdfg is not None and _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: + if sdfg is not None and self.lazy_method.config.schedule_tree_enabled(): set_label(sdfg, type(self.obj_to_bind).__qualname__, is_top_sdfg=True) return _call_sdfg( self.daceprog, @@ -612,7 +607,7 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] def __sdfg__(self, *args, **kwargs): # type: ignore[no-untyped-def] sdfg = _parse_sdfg(self.daceprog, self.lazy_method.config, *args, **kwargs) # Label the code - if sdfg is not None and _INTERNAL__SCHEDULE_TREE_OPTIMIZATION: + if sdfg is not None and self.lazy_method.config.schedule_tree_enabled(): set_label(sdfg, type(self.obj_to_bind).__qualname__, is_top_sdfg=False) return sdfg diff --git a/tests/dsl/dace/stree/__init__.py b/tests/dsl/dace/stree/__init__.py index 2fa38d13..b43c1d92 100644 --- a/tests/dsl/dace/stree/__init__.py +++ b/tests/dsl/dace/stree/__init__.py @@ -1,7 +1,7 @@ -from .sdfg_stree_tools import StreeOptimization, get_SDFG_and_purge +from .sdfg_stree_tools import StreePipeline, get_SDFG_and_purge __all__ = [ - "StreeOptimization", + "StreePipeline", "get_SDFG_and_purge", ] diff --git a/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py index e16790e1..8ac4abf7 100644 --- a/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py +++ b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py @@ -8,7 +8,7 @@ from ndsl.dsl.gt4py import BACKWARD, FORWARD, PARALLEL, computation, interval from ndsl.dsl.stencil import StencilFactory from ndsl.dsl.typing import FloatField -from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree import get_SDFG_and_purge from tests.dsl.dace.stree.optimizations import Factories @@ -92,13 +92,15 @@ class TestKernelizeMaps: ) def factories(self, request: pytest.FixtureRequest) -> Factories: domain = (3, 4, 5) - return get_factories_single_tile( + stencil_factory, quantity_factory = get_factories_single_tile( nx=domain[0], ny=domain[1], nz=domain[2], nhalo=0, backend=Backend(request.param), ) + stencil_factory.config.dace_config.enable_schedule_tree() + return stencil_factory, quantity_factory def test_kernelize_k_gpu(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories @@ -107,8 +109,7 @@ def test_kernelize_k_gpu(self, factories: Factories) -> None: in_field = quantity_factory.ones((I_DIM, J_DIM, K_DIM), "") out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), "") - with StreeOptimization(): - code.kernelize_k(in_field, out_field) + code.kernelize_k(in_field, out_field) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) diff --git a/tests/dsl/dace/stree/optimizations/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py index 1a9ed508..53db91d5 100644 --- a/tests/dsl/dace/stree/optimizations/test_merge.py +++ b/tests/dsl/dace/stree/optimizations/test_merge.py @@ -10,7 +10,7 @@ from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import FORWARD, PARALLEL, K, computation, interval from ndsl.dsl.typing import FloatField -from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree import get_SDFG_and_purge from tests.dsl.dace.stree.optimizations import Factories @@ -133,9 +133,11 @@ class TestStreeMergeMapsIJK: @pytest.fixture def factories(self) -> Factories: domain = (3, 3, 4) - return get_factories_single_tile_orchestrated( + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( domain[0], domain[1], domain[2], 0, backend=Backend("orch:dace:cpu:IJK") ) + stencil_factory.config.dace_config.enable_schedule_tree() + return stencil_factory, quantity_factory @pytest.fixture def code(self, factories: Factories) -> OrchestratedCode: @@ -146,8 +148,7 @@ def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> No in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - code.trivial_merge(in_qty, out_qty) + code.trivial_merge(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) all_maps = [ @@ -166,8 +167,7 @@ def test_missing_merge_of_forscope_and_map( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - code.missing_merge_of_forscope_and_map(in_qty, out_qty) + code.missing_merge_of_forscope_and_map(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -190,8 +190,7 @@ def test_overcompute_merge( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - code.overcompute_merge(in_qty, out_qty) + code.overcompute_merge(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -208,9 +207,8 @@ def test_block_merge_when_dependencies_are_found( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - # Forbid merging when data dependencies are detected - code.block_merge_when_dependencies_are_found(in_qty, out_qty) + # Forbid merging when data dependencies are detected + code.block_merge_when_dependencies_are_found(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -227,10 +225,9 @@ def test_push_non_cartesian_for( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - # Push non-cartesian ForScope inwards, which allow to potentially - # merge cartesian maps - code.push_non_cartesian_for(in_qty, out_qty) + # Push non-cartesian ForScope inwards, which allow to potentially + # merge cartesian maps + code.push_non_cartesian_for(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -251,9 +248,11 @@ class TestStreeMergeMapsKJI: @pytest.fixture def factories(self) -> Factories: domain = (3, 3, 4) - return get_factories_single_tile_orchestrated( + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( domain[0], domain[1], domain[2], 0, backend=Backend("orch:dace:cpu:KJI") ) + stencil_factory.config.dace_config.enable_schedule_tree() + return stencil_factory, quantity_factory @pytest.fixture def code(self, factories: Factories) -> OrchestratedCode: @@ -264,8 +263,7 @@ def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> No in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - code.trivial_merge(in_qty, out_qty) + code.trivial_merge(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) all_maps = [ @@ -284,9 +282,8 @@ def test_missing_merge_of_forscope_and_map( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - # K iterative loop - blocks all merges - code.missing_merge_of_forscope_and_map(in_qty, out_qty) + # K iterative loop - blocks all merges + code.missing_merge_of_forscope_and_map(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -309,9 +306,8 @@ def test_overcompute_merge( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - # Overcompute merge in K - we merge and introduce an If guard - code.overcompute_merge(in_qty, out_qty) + # Overcompute merge in K - we merge and introduce an If guard + code.overcompute_merge(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -328,9 +324,8 @@ def test_block_merge_when_dependencies_are_found( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - # Forbid merging when data dependencies are detected - code.block_merge_when_dependencies_are_found(in_qty, out_qty) + # Forbid merging when data dependencies are detected + code.block_merge_when_dependencies_are_found(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ @@ -347,10 +342,9 @@ def test_push_non_cartesian_for( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - # Push non-cartesian ForScope inwards, which allow to potentially - # merge cartesian maps - code.push_non_cartesian_for(in_qty, out_qty) + # Push non-cartesian ForScope inwards, which allow to potentially + # merge cartesian maps + code.push_non_cartesian_for(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg all_maps = [ diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py index fcfe33bc..e9979dcb 100644 --- a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -5,7 +5,7 @@ from ndsl.boilerplate import get_factories_single_tile from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.typing import FloatField -from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree import get_SDFG_and_purge from tests.dsl.dace.stree.optimizations import Factories @@ -62,9 +62,11 @@ class TestStreeInlineOffgridConditionals: @pytest.fixture(params=["orch:dace:cpu:IJK", "orch:dace:cpu:KJI"]) def factories(self, request: pytest.FixtureRequest) -> Factories: domain = (3, 3, 4) - return get_factories_single_tile( + stencil_factory, quantity_factory = get_factories_single_tile( domain[0], domain[1], domain[2], 0, backend=Backend(request.param) ) + stencil_factory.config.dace_config.enable_schedule_tree() + return stencil_factory, quantity_factory def test_happy_case(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories @@ -73,8 +75,7 @@ def test_happy_case(self, factories: Factories) -> None: in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - code.happy_case(in_quantity, out_quantity) + code.happy_case(in_quantity, out_quantity) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -92,8 +93,7 @@ def test_happy_case_2(self, factories: Factories) -> None: in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - code.happy_case_2(in_quantity, out_quantity) + code.happy_case_2(in_quantity, out_quantity) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -111,8 +111,7 @@ def test_blocked_by_else(self, factories: Factories) -> None: in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - code.blocked_by_else(in_quantity, out_quantity) + code.blocked_by_else(in_quantity, out_quantity) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -130,8 +129,7 @@ def test_blocked_by_other_nodes(self, factories: Factories) -> None: in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - code.blocked_by_other_nodes(in_quantity, out_quantity) + code.blocked_by_other_nodes(in_quantity, out_quantity) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) diff --git a/tests/dsl/dace/stree/optimizations/test_pipeline.py b/tests/dsl/dace/stree/optimizations/test_pipeline.py index 677790bc..b9c4eb78 100644 --- a/tests/dsl/dace/stree/optimizations/test_pipeline.py +++ b/tests/dsl/dace/stree/optimizations/test_pipeline.py @@ -4,7 +4,6 @@ from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import PARALLEL, computation, interval from ndsl.dsl.typing import FloatField -from tests.dsl.dace.stree import StreeOptimization def double_map(in_field: FloatField, out_field: FloatField): @@ -32,12 +31,12 @@ def test_stree_roundtrip_no_opt(): stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( domain[0], domain[1], domain[2], 0, backend=Backend.cpu() ) + stencil_factory.config.dace_config.enable_schedule_tree() code = TriviallyMergeableCode(stencil_factory) in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(): - code(in_qty, out_qty) + code(in_qty, out_qty) assert (out_qty.field[:] == 4).all() diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py index 9469f204..77082ef4 100644 --- a/tests/dsl/dace/stree/optimizations/test_remove_loops.py +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -15,7 +15,7 @@ from ndsl.dsl.gt4py import FORWARD, computation, interval from ndsl.dsl.typing import FloatField, FloatFieldIJ from ndsl.stencils import copy -from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree import StreePipeline, get_SDFG_and_purge from tests.dsl.dace.stree.optimizations import Factories @@ -118,9 +118,11 @@ class TestStree2DWriteInline: def factories(self, request: pytest.FixtureRequest) -> Factories: domain = (3, 3, 4) - return get_factories_single_tile( + stencil_factory, quantity_factory = get_factories_single_tile( domain[0], domain[1], domain[2], 0, backend=Backend(request.param) ) + stencil_factory.config.dace_config.enable_schedule_tree() + return stencil_factory, quantity_factory def test_common_2D_write(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories @@ -136,7 +138,7 @@ def test_common_2D_write(self, factories: Factories) -> None: out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") in_qty.field[:, :, 0] = Float(32.0) - with StreeOptimization(passes=pipeline): + with StreePipeline(passes=pipeline): code.write_at_0(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -169,7 +171,7 @@ def test_2D_write_K_top(self, factories: Factories) -> None: out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") in_qty.field[:, :, -1] = Float(32.0) - with StreeOptimization(passes=pipeline): + with StreePipeline(passes=pipeline): code.write_at_top(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -201,7 +203,7 @@ def test_do_not_inline(self, factories: Factories) -> None: in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreeOptimization(passes=pipeline): + with StreePipeline(passes=pipeline): code.do_not_inline(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -234,7 +236,7 @@ def test_combined_stencils(self, factories: Factories) -> None: field_2 = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "") - with StreeOptimization(passes=pipeline): + with StreePipeline(passes=pipeline): code.combined_stencils(field, field_2, field_IJ) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -272,7 +274,7 @@ def test_multiple_statements(self, factories: Factories) -> None: field_IJ_2 = quantity_factory.zeros([I_DIM, J_DIM], "") field.field[:, :, 0] = Float(42.0) - with StreeOptimization(passes=pipeline): + with StreePipeline(passes=pipeline): code.multiple_statements(field, field_IJ, field_IJ_2) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) diff --git a/tests/dsl/dace/stree/optimizations/test_transient_refine.py b/tests/dsl/dace/stree/optimizations/test_transient_refine.py index 9795957a..bf23cbbc 100644 --- a/tests/dsl/dace/stree/optimizations/test_transient_refine.py +++ b/tests/dsl/dace/stree/optimizations/test_transient_refine.py @@ -4,7 +4,7 @@ from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import IJK, PARALLEL, Field, J, K, computation, interval from ndsl.dsl.typing import Float, FloatField -from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge +from tests.dsl.dace.stree import get_SDFG_and_purge DATADIM_SIZE = 8 @@ -95,6 +95,7 @@ def test_stree_roundtrip_transient_is_refined() -> None: stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( domain[0], domain[1], domain[2], 0, backend=Backend.cpu() ) + stencil_factory.config.dace_config.enable_schedule_tree() in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") @@ -105,40 +106,39 @@ def test_stree_roundtrip_transient_is_refined() -> None: code = TransientRefineableCode(stencil_factory, quantity_factory) - with StreeOptimization(): - # Refine to scalar - code.refine_to_scalar(in_qty, out_qty) - precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - for array in precompiled_sdfg.sdfg.arrays.values(): - if array.transient: - assert array.shape == (1, 1, 1) - - # Refine cartesian axis to buffers - # IJ merges - K is a buffer - code.refine_to_K_buffer(in_qty, out_qty) - precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - for array in precompiled_sdfg.sdfg.arrays.values(): - if array.transient: - assert array.shape == ( - 1, - 1, - domain[2] + 1, # Quantity are domain size + 1 - ) - - # I merges - JK buffer - code.refine_to_JK_buffer(in_qty, out_qty) - precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - for array in precompiled_sdfg.sdfg.arrays.values(): - if array.transient: - assert array.shape == ( - 1, - domain[1] + 1, # Quantity are domain size + 1 - domain[2] + 1, - ) - - # Refine to remaining data dimensions - code.do_not_refine_datadims(in_qty_ddim, out_qty_ddim) - precompiled_sdfg = get_SDFG_and_purge(stencil_factory) - for array in precompiled_sdfg.sdfg.arrays.values(): - if array.transient: - assert array.shape == (1, 1, 1, DATADIM_SIZE) or len(array.shape) == 1 + # Refine to scalar + code.refine_to_scalar(in_qty, out_qty) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == (1, 1, 1) + + # Refine cartesian axis to buffers + # IJ merges - K is a buffer + code.refine_to_K_buffer(in_qty, out_qty) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == ( + 1, + 1, + domain[2] + 1, # Quantity are domain size + 1 + ) + + # I merges - JK buffer + code.refine_to_JK_buffer(in_qty, out_qty) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == ( + 1, + domain[1] + 1, # Quantity are domain size + 1 + domain[2] + 1, + ) + + # Refine to remaining data dimensions + code.do_not_refine_datadims(in_qty_ddim, out_qty_ddim) + precompiled_sdfg = get_SDFG_and_purge(stencil_factory) + for array in precompiled_sdfg.sdfg.arrays.values(): + if array.transient: + assert array.shape == (1, 1, 1, DATADIM_SIZE) or len(array.shape) == 1 diff --git a/tests/dsl/dace/stree/sdfg_stree_tools.py b/tests/dsl/dace/stree/sdfg_stree_tools.py index b913a134..aeb149a5 100644 --- a/tests/dsl/dace/stree/sdfg_stree_tools.py +++ b/tests/dsl/dace/stree/sdfg_stree_tools.py @@ -21,14 +21,12 @@ def get_SDFG_and_purge(stencil_factory: StencilFactory) -> dace.CompiledSDFG: return sdfg -class StreeOptimization: +class StreePipeline: def __init__(self, *, passes: list[tn.ScheduleNodeVisitor] | None = None) -> None: self.passes = passes def __enter__(self) -> None: self.original_passes = orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES - - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES = self.passes def __exit__( @@ -37,5 +35,4 @@ def __exit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES = self.original_passes diff --git a/tests/dsl/test_dace_config.py b/tests/dsl/test_dace_config.py index 8003937f..0384ab42 100644 --- a/tests/dsl/test_dace_config.py +++ b/tests/dsl/test_dace_config.py @@ -1,5 +1,8 @@ +import os import unittest.mock +import pytest + from ndsl import CubedSpherePartitioner, DaceConfig, DaCeOrchestration, TilePartitioner from ndsl.comm.partitioner import Partitioner from ndsl.config import Backend @@ -155,3 +158,36 @@ def _does_compile(rank: int, partitioner: Partitioner) -> bool: for i in range(layout[0] * layout[1] * 6): compiling += 1 if _does_compile(i, partition) else 0 assert compiling == 9 + + +def test_schedule_tree_enable_disable() -> None: + config = DaceConfig( + communicator=None, + backend=Backend("orch:dace:cpu:KIJ"), + orchestration=DaCeOrchestration.BuildAndRun, + ) + + default = os.getenv("NDSL_STREE_OPT", "False") == "True" + assert config.schedule_tree_enabled() == default + + if default: + config.disable_schedule_tree() + assert not config.schedule_tree_enabled() + config.enable_schedule_tree() + assert config.schedule_tree_enabled() + else: + config.enable_schedule_tree() + assert config.schedule_tree_enabled() + config.disable_schedule_tree() + assert not config.schedule_tree_enabled() + + +def test_schedule_tree_warns_when_not_orchestrated(): + config = DaceConfig( + communicator=None, + backend=Backend("st:dace:cpu:KIJ"), + orchestration=DaCeOrchestration.BuildAndRun, + ) + + with pytest.warns(UserWarning, match="non-orchestrated backend"): + config.enable_schedule_tree() From 839af5d3c39b2c02aee1af5850734d4bed9b9ed5 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 19 Jun 2026 15:02:13 +0200 Subject: [PATCH 062/101] feat: turn on/off overcompute merge `NDSL_STREE_OVERCOMPUTE_MERGE` --- .../dace/stree/optimizations/axis_merge.py | 15 ++++++-- .../stree/optimizations/cartesian_merge.py | 12 +++++-- ndsl/dsl/dace/stree/pipeline.py | 19 ++++++---- .../dace/stree/optimizations/test_merge.py | 36 ++++++++++++++++++- 4 files changed, 69 insertions(+), 13 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py index 5671059b..196d3d0e 100644 --- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py @@ -90,20 +90,23 @@ class CartesianAxisMerge(tn.ScheduleNodeTransformer): Can do: - merge a given axis with the next maps at the same recursion level - - does overcompute to allow for more merging at the cost of an if + - can overcompute to allow for more merging at the cost of an if It expects: - All Maps and ForLoop are on a single axis - but doesn't check for it. Args: axis: AxisIterator to be merged + overcompute: merge at the cost of an if statement. """ - def __init__(self, axis: AxisIterator) -> None: + def __init__(self, axis: AxisIterator, *, overcompute: bool) -> None: self.axis = axis + self.overcompute = overcompute def __str__(self) -> str: - return f"CartesianAxisMerge_{self.axis.name}" + suffix = "_overcompute" if self.overcompute else "" + return f"CartesianAxisMerge_{self.axis.name}{suffix}" def _merge_node( self, node: tn.ScheduleTreeNode, nodes: list[tn.ScheduleTreeNode] @@ -220,6 +223,12 @@ def _map_overcompute_merge( ] ) + # only overcompute if configured - otherwise no merge + if not self.overcompute and ( + first_range != merged_range or second_range != merged_range + ): + return 0 + # - then, guard children to only run in their respective range axis_as_str = the_map.node.map.params[0] assert isinstance(axis_as_str, str) diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py index 1dd64458..16d72380 100644 --- a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -11,11 +11,17 @@ class CartesianMerge(tn.ScheduleNodeTransformer): - """Merge Cartesian computation blocks""" + """Merge Cartesian computation blocks. - def __init__(self, backend: Backend) -> None: + Args: + backend: The loop order influences the merge order. + overcompute: Whether to merge at the cost of an if statement. Defaults to True. + """ + + def __init__(self, backend: Backend, *, overcompute: bool = True) -> None: super().__init__() self._backend = backend + self._overcompute = overcompute def __str__(self) -> str: return "CartesianMerge" @@ -26,7 +32,7 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: MergeConditionals().visit(node) for axis in self._backend_order(): - CartesianAxisMerge(axis).visit(node) + CartesianAxisMerge(axis, overcompute=self._overcompute).visit(node) ExtractOffgridConditionals().visit(node) MergeConditionals().visit(node) diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index ad02dfbe..6fdbaace 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -1,3 +1,4 @@ +import os from pathlib import Path from dace.sdfg.analysis.schedule_tree import treenodes as tn @@ -69,15 +70,18 @@ def __init__( cache_directory: Path | None = None, ) -> None: if passes is None: - passes = [ + overcompute = os.getenv("NDSL_STREE_OVERCOMPUTE_MERGE", "True") == "True" + ppl_passes = [ CleanUpScheduleTree(), # TODO: Is it safe? Deactivate for now # InlineVertical2DWrite(), - CartesianMerge(backend), + CartesianMerge(backend, overcompute=overcompute), CartesianRefineTransients(backend), ] + else: + ppl_passes = passes super().__init__( - passes=passes if passes is not None else [], + passes=ppl_passes, cache_directory=cache_directory, ) @@ -91,17 +95,20 @@ def __init__( cache_directory: Path | None = None, ) -> None: if passes is None: - passes = [ + overcompute = os.getenv("NDSL_STREE_OVERCOMPUTE_MERGE", "True") == "True" + ppl_passes = [ CleanUpScheduleTree(), # TODO: Is it safe? Deactivate for now # InlineVertical2DWrite(), - CartesianMerge(backend), + CartesianMerge(backend, overcompute=overcompute), KernelizeMaps(backend), # 🐞 Transient refine can't be used # because of bugs transients showing in code generation # CartesianRefineTransients(backend), ] + else: + ppl_passes = passes super().__init__( - passes=passes if passes is not None else [], + passes=ppl_passes, cache_directory=cache_directory, ) diff --git a/tests/dsl/dace/stree/optimizations/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py index 53db91d5..b074c255 100644 --- a/tests/dsl/dace/stree/optimizations/test_merge.py +++ b/tests/dsl/dace/stree/optimizations/test_merge.py @@ -8,9 +8,10 @@ from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM +from ndsl.dsl.dace.stree.pipeline import CartesianMerge, CleanUpScheduleTree from ndsl.dsl.gt4py import FORWARD, PARALLEL, K, computation, interval from ndsl.dsl.typing import FloatField -from tests.dsl.dace.stree import get_SDFG_and_purge +from tests.dsl.dace.stree import StreePipeline, get_SDFG_and_purge from tests.dsl.dace.stree.optimizations import Factories @@ -200,6 +201,39 @@ def test_overcompute_merge( ] assert len(all_maps) == 1 # All maps merged and collapsed + def test_no_overcompute_merge( + self, code: OrchestratedCode, factories: Factories + ) -> None: + stencil_factory, quantity_factory = factories + in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") + out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") + + no_overcompute = [ + CleanUpScheduleTree(), + CartesianMerge(stencil_factory.backend, overcompute=False), + ] + + with StreePipeline(passes=no_overcompute): + code.overcompute_merge(in_qty, out_qty) + + sdfg = get_SDFG_and_purge(stencil_factory).sdfg + + all_maps = [ + me for me, _ in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry) + ] + k_maps = 0 + ij_maps = 0 + for map_entry in all_maps: + if len(map_entry.map.params) == 1 and map_entry.map.params[0].startswith( + "__k" + ): + k_maps += 1 + if map_entry.map.params == ["__i", "__j"]: + ij_maps += 1 + + assert ij_maps == 1 + assert k_maps == 2 + def test_block_merge_when_dependencies_are_found( self, code: OrchestratedCode, factories: Factories ) -> None: From af8c2d474713bf62402d36cff65cde3119af6d86 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 19 Jun 2026 11:39:10 -0400 Subject: [PATCH 063/101] Re-work GPU xforms to exclude callback from going to host --- ndsl/dsl/dace/orchestration.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 97b96836..3ebbc9e9 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -275,8 +275,10 @@ def _build_sdfg( if config.is_gpu_backend(): with DaCeProgress(config, "Apply GPU transformations"): - # Set block size on GPU maps + # Set block size on GPU maps and collect callback + # tasklets to exclude next gpu_defaults = get_gpu_hardware_defaults() + exclude_taskslets_list = [] for me, _state in sdfg.all_nodes_recursive(): if ( isinstance(me, nodes.MapEntry) @@ -284,8 +286,22 @@ def _build_sdfg( ): if me.map.gpu_block_size is None: me.map.gpu_block_size = gpu_defaults.block_size + + if isinstance(me, nodes.Tasklet) and "callback_" in me.label: + exclude_taskslets_list.append(me.label) + # Apply common GPU transforms (includes a simplify) - sdfg.apply_gpu_transformations() + # while making sure tasklet remain on the host + from dace.transformation.interstate import GPUTransformSDFG + + sdfg.apply_transformations( + GPUTransformSDFG, + options={ + "exclude_tasklets": ",".join(exclude_taskslets_list), + "host_data": ["__pystate"], + }, + ) + if config.verbose_orchestration: ndsl_log.debug("saving 05-apply_gpu_xforms.sdfgz") sdfg.save( From de89c11b629750996943e74c739fb608873438e0 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 19 Jun 2026 17:13:46 +0200 Subject: [PATCH 064/101] feat: optimization config for orchestrated code This commit brings an optimization config, which will allow to teak optmiization parameters per NDSLRuntime and/or `orchestrate()` call. --- ndsl/__init__.py | 2 + ndsl/dsl/dace/dace_config.py | 30 -------- ndsl/dsl/dace/orchestration.py | 68 +++++++++++++++---- ndsl/dsl/dace/stree/pipeline.py | 15 ++-- ndsl/dsl/ndsl_runtime.py | 16 ++++- ndsl/dsl/optimization_config.py | 18 +++++ .../optimizations/test_kernelize_maps.py | 12 ++-- .../dace/stree/optimizations/test_merge.py | 12 ++-- .../test_offgrid_conditionals.py | 17 +++-- .../dace/stree/optimizations/test_pipeline.py | 10 ++- .../stree/optimizations/test_remove_loops.py | 8 +-- .../optimizations/test_transient_refine.py | 14 +++- tests/dsl/test_dace_config.py | 36 ---------- tests/test_ndsl_runtime.py | 50 +++++++++++--- 14 files changed, 186 insertions(+), 122 deletions(-) create mode 100644 ndsl/dsl/optimization_config.py diff --git a/ndsl/__init__.py b/ndsl/__init__.py index d468d58f..72ea237f 100644 --- a/ndsl/__init__.py +++ b/ndsl/__init__.py @@ -10,6 +10,7 @@ from .constants import ConstantVersions from .dsl.caches.codepath import FV3CodePath from .quantity import Quantity +from .dsl.optimization_config import OptimizationConfig from .dsl.ndsl_runtime import NDSLRuntime from .dsl.stencil import FrozenStencil, GridIndexing, StencilFactory, TimingCollector from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig @@ -90,6 +91,7 @@ "MetaEnumStr", "State", "LocalState", + "OptimizationConfig", "NDSLRuntime", "Local", "DiagManagerMonitor", diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py index 3f3f70fc..b722df77 100644 --- a/ndsl/dsl/dace/dace_config.py +++ b/ndsl/dsl/dace/dace_config.py @@ -2,7 +2,6 @@ import enum import os -import warnings from pathlib import Path from typing import TYPE_CHECKING, Any, Self @@ -216,11 +215,6 @@ def __init__( os.getenv("NDSL_VERBOSE_SCHEDULE_TREE_OPTIMIZATIONS", "False") == "True" ) - # Schedule tree optimization - self._schedule_tree_optimization = ( - os.getenv("NDSL_STREE_OPT", "False") == "True" - ) - # We hijack the optimization level of GT4Py because we don't # have the configuration at NDSL level, but we do use the GT4Py # level @@ -411,30 +405,6 @@ def get_orchestrate(self) -> DaCeOrchestration: def get_sync_debug(self) -> bool: return dace.config.Config.get_bool("compiler", "cuda", "syncdebug") - def enable_schedule_tree(self) -> None: - """Enables optimizations based on the schedule tree.""" - if not self.is_dace_orchestrated(): - warnings.warn( - "Enabling schedule tree optimization on a non-orchestrated backend has no effect.", - UserWarning, - stacklevel=2, - ) - self._schedule_tree_optimization = True - - def disable_schedule_tree(self) -> None: - """Disables optimizations based on the schedule tree.""" - if not self.is_dace_orchestrated(): - warnings.warn( - "Disabling schedule tree optimization on a non-orchestrated backend has no effect.", - UserWarning, - stacklevel=2, - ) - self._schedule_tree_optimization = False - - def schedule_tree_enabled(self) -> bool: - """Returns true iff schedule tree optimizations are enabled.""" - return self._schedule_tree_optimization - def as_dict(self) -> dict[str, Any]: return { "_orchestrate": str(self._orchestrate.name), diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 3ebbc9e9..e6f5d5be 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -24,7 +24,7 @@ from gt4py import storage as gt_storage import ndsl.dsl.dace.replacements # noqa # We load in the DaCe replacements -from ndsl import Backend, ndsl_log +from ndsl import Backend, OptimizationConfig, ndsl_log from ndsl.comm.mpi import MPI from ndsl.dsl.dace.build import get_sdfg_path, write_build_info from ndsl.dsl.dace.dace_config import ( @@ -143,6 +143,7 @@ def _tree_as_sdfg(stree: tn.ScheduleTreeRoot) -> SDFG: def _optimization_pipeline( + config: OptimizationConfig, device_type: DeviceType, backend: Backend, *, @@ -150,10 +151,14 @@ def _optimization_pipeline( cache_directory: Path | None = None, ) -> StreePipeline: if device_type == device_type.CPU: - return CPUPipeline(backend, passes=passes, cache_directory=cache_directory) + return CPUPipeline( + config, backend, passes=passes, cache_directory=cache_directory + ) if device_type == DeviceType.GPU: - return GPUPipeline(backend, passes=passes, cache_directory=cache_directory) + return GPUPipeline( + config, backend, passes=passes, cache_directory=cache_directory + ) raise ValueError( f"Unknown device type `{device_type}`, expected {DeviceType.CPU} or {DeviceType.GPU}." @@ -161,7 +166,12 @@ def _optimization_pipeline( def _build_sdfg( - dace_program: DaceProgram, sdfg: SDFG, config: DaceConfig, args: Any, kwargs: Any + dace_program: DaceProgram, + sdfg: SDFG, + config: DaceConfig, + optimization_config: OptimizationConfig, + args: Any, + kwargs: Any, ) -> None: """Build the .so out of the SDFG on the top tile ranks only.""" is_compiling = True if DEACTIVATE_DISTRIBUTED_DACE_COMPILE else config.do_compile @@ -199,7 +209,7 @@ def _build_sdfg( compress=True, ) - if config.schedule_tree_enabled(): + if optimization_config.stree.enabled: # Here be 🐉 - but tests exists in test_optimization.py with DaCeProgress(config, "Schedule Tree: generate from SDFG"): # Break all loops into uni-dimensional loops to simplify optimizations @@ -226,6 +236,7 @@ def _build_sdfg( with DaCeProgress(config, "Schedule Tree: optimization"): pipeline = _optimization_pipeline( + optimization_config, device_type, backend_name, cache_directory=Path(sdfg.build_folder), @@ -409,7 +420,12 @@ def _build_sdfg( def _call_sdfg( - dace_program: DaceProgram, sdfg: SDFG, config: DaceConfig, args: Any, kwargs: Any + dace_program: DaceProgram, + sdfg: SDFG, + config: DaceConfig, + optimization_config: OptimizationConfig, + args: Any, + kwargs: Any, ) -> list | None: """Dispatch to either SDFG execution and/or build.""" @@ -421,7 +437,7 @@ def _call_sdfg( and dace_program not in config.loaded_dace_executables # already cached ): ndsl_log.info("Building DaCe orchestration") - _build_sdfg(dace_program, sdfg, config, args, kwargs) + _build_sdfg(dace_program, sdfg, config, optimization_config, args, kwargs) if mode not in [DaCeOrchestration.BuildAndRun, DaCeOrchestration.Run]: raise ValueError(f"Unexpected DaceOrchestration mode `{mode}`.") @@ -528,9 +544,15 @@ class _LazyComputepathFunction(SDFGConvertible): that will be compiled but not regenerated. """ - def __init__(self, func: Callable, config: DaceConfig) -> None: + def __init__( + self, + func: Callable, + config: DaceConfig, + optimization_config: OptimizationConfig, + ) -> None: self.func = func self.config = config + self.optimization_config = optimization_config self.daceprog: DaceProgram = dace_program(self.func) self._sdfg = None @@ -546,6 +568,7 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] self.daceprog, sdfg, self.config, + self.optimization_config, args, kwargs, ) @@ -610,12 +633,13 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] **kwargs, ) # Label the code (this is the topmost code) - if sdfg is not None and self.lazy_method.config.schedule_tree_enabled(): + if sdfg is not None and self.lazy_method.optimization_config.stree.enabled: set_label(sdfg, type(self.obj_to_bind).__qualname__, is_top_sdfg=True) return _call_sdfg( self.daceprog, sdfg, self.lazy_method.config, + self.lazy_method.optimization_config, args, kwargs, ) @@ -623,7 +647,7 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] def __sdfg__(self, *args, **kwargs): # type: ignore[no-untyped-def] sdfg = _parse_sdfg(self.daceprog, self.lazy_method.config, *args, **kwargs) # Label the code - if sdfg is not None and self.lazy_method.config.schedule_tree_enabled(): + if sdfg is not None and self.lazy_method.optimization_config.stree.enabled: set_label(sdfg, type(self.obj_to_bind).__qualname__, is_top_sdfg=False) return sdfg @@ -638,9 +662,15 @@ def closure_resolver(self, constant_args, given_args, parent_closure=None): # t constant_args, given_args, parent_closure ) - def __init__(self, func: Callable, config: DaceConfig): + def __init__( + self, + func: Callable, + config: DaceConfig, + optimization_config: OptimizationConfig, + ) -> None: self.func = func self.config = config + self.optimization_config = optimization_config def __get__(self, obj: object, objtype: Any = None) -> SDFGEnabledCallable: """Return SDFGEnabledCallable wrapping original obj.method from cache. @@ -659,6 +689,7 @@ def orchestrate( config: DaceConfig, method_to_orchestrate: str = "__call__", dace_compiletime_args: Sequence[str] | None = None, + optimization_config: OptimizationConfig | None = None, ) -> None: """ Orchestrate a method of an object with DaCe. @@ -689,6 +720,11 @@ def orchestrate( if dace_compiletime_args is None: dace_compiletime_args = [] + if optimization_config is None: + opt_config = OptimizationConfig() + else: + opt_config = optimization_config + func: Callable = type.__getattribute__(type(obj), method_to_orchestrate) # Flag argument as dace.constant @@ -711,7 +747,7 @@ def orchestrate( # Build DaCe orchestrated wrapper # This is a JIT object, e.g. DaCe compilation will happen on call - wrapped = _LazyComputepathMethod(func, config).__get__(obj) + wrapped = _LazyComputepathMethod(func, config, opt_config).__get__(obj) if method_to_orchestrate == "__call__": # Grab the function from the type of the child class @@ -763,6 +799,7 @@ def closure_resolver(self, constant_args, given_args, parent_closure=None): # t def orchestrate_function( config: DaceConfig, dace_compiletime_args: Sequence[str] | None = None, + optimization_config: OptimizationConfig | None = None, ) -> Callable[..., Any] | _LazyComputepathFunction: """ Decorator orchestrating a method of an object with DaCe. @@ -777,11 +814,16 @@ def orchestrate_function( if dace_compiletime_args is None: dace_compiletime_args = [] + if optimization_config is None: + opt_config = OptimizationConfig() + else: + opt_config = optimization_config + def _decorator(func: Callable[..., Any]): # type: ignore[no-untyped-def] def _wrapper(*args, **kwargs): # type: ignore[no-untyped-def] for argument in dace_compiletime_args: func.__annotations__[argument] = DaceCompiletime - return _LazyComputepathFunction(func, config) + return _LazyComputepathFunction(func, config, opt_config) return _wrapper(func) if config.is_dace_orchestrated() else func diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 6fdbaace..0b7d4713 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -1,9 +1,8 @@ -import os from pathlib import Path from dace.sdfg.analysis.schedule_tree import treenodes as tn -from ndsl import Backend, ndsl_log_on_rank_0 +from ndsl import Backend, OptimizationConfig, ndsl_log_on_rank_0 from ndsl.dsl.dace.stree.optimizations import ( CartesianMerge, CartesianRefineTransients, @@ -16,6 +15,7 @@ class StreePipeline: def __init__( self, + config: OptimizationConfig, *, passes: list[tn.ScheduleNodeVisitor], cache_directory: Path | None = None, @@ -25,6 +25,7 @@ def __init__( self.cache_directory = cache_directory self.passes = passes + self.config = config def __hash__(self) -> int: return hash(repr(self)) @@ -64,23 +65,24 @@ def run( class CPUPipeline(StreePipeline): def __init__( self, + config: OptimizationConfig, backend: Backend, *, passes: list[tn.ScheduleNodeVisitor] | None = None, cache_directory: Path | None = None, ) -> None: if passes is None: - overcompute = os.getenv("NDSL_STREE_OVERCOMPUTE_MERGE", "True") == "True" ppl_passes = [ CleanUpScheduleTree(), # TODO: Is it safe? Deactivate for now # InlineVertical2DWrite(), - CartesianMerge(backend, overcompute=overcompute), + CartesianMerge(backend, overcompute=config.stree.merger.overcompute), CartesianRefineTransients(backend), ] else: ppl_passes = passes super().__init__( + config=config, passes=ppl_passes, cache_directory=cache_directory, ) @@ -89,18 +91,18 @@ def __init__( class GPUPipeline(StreePipeline): def __init__( self, + config: OptimizationConfig, backend: Backend, *, passes: list[tn.ScheduleNodeVisitor] | None = None, cache_directory: Path | None = None, ) -> None: if passes is None: - overcompute = os.getenv("NDSL_STREE_OVERCOMPUTE_MERGE", "True") == "True" ppl_passes = [ CleanUpScheduleTree(), # TODO: Is it safe? Deactivate for now # InlineVertical2DWrite(), - CartesianMerge(backend, overcompute=overcompute), + CartesianMerge(backend, overcompute=config.stree.merger.overcompute), KernelizeMaps(backend), # 🐞 Transient refine can't be used # because of bugs transients showing in code generation @@ -109,6 +111,7 @@ def __init__( else: ppl_passes = passes super().__init__( + config=config, passes=ppl_passes, cache_directory=cache_directory, ) diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py index a994c61a..294f5711 100644 --- a/ndsl/dsl/ndsl_runtime.py +++ b/ndsl/dsl/ndsl_runtime.py @@ -5,6 +5,7 @@ from collections.abc import Callable from typing import Any, Sequence +from ndsl import OptimizationConfig from ndsl.dsl.dace.orchestration import orchestrate from ndsl.dsl.stencil import StencilFactory from ndsl.dsl.typing import Float @@ -21,10 +22,22 @@ class NDSLRuntime: The __call__ function will automatically be orchestrated.""" - def __init__(self, stencil_factory: StencilFactory) -> None: + def __init__( + self, + stencil_factory: StencilFactory, + optimization_config: OptimizationConfig | None = None, + ) -> None: self._stencil_factory = stencil_factory # Use this flag to detect that the init wasn't done properly self._base_class_was_properly_super_init = True + if optimization_config is None: + # TODO + # - Decide where to put defaults. + # - For now, they are in the OptimizationConfig object itself. + # - We could have specialized defaults here for NDSLRuntime code. + self._optimization_config = OptimizationConfig() + else: + self._optimization_config = optimization_config def __init_subclass__(cls: type[NDSLRuntime], **kwargs: dict[str, Any]) -> None: # WARNING: no code outside the `init_decorator` this is cls @@ -75,6 +88,7 @@ def check_for_quantity(object_: object) -> None: orchestrate( obj=self, config=self._stencil_factory.config.dace_config, + optimization_config=self._optimization_config, ) def __getattribute__(self, name: str) -> Any: diff --git a/ndsl/dsl/optimization_config.py b/ndsl/dsl/optimization_config.py new file mode 100644 index 00000000..ada339ce --- /dev/null +++ b/ndsl/dsl/optimization_config.py @@ -0,0 +1,18 @@ +import os +from dataclasses import dataclass, field + + +@dataclass +class OptimizationConfig: + @dataclass + class TreeConfig: + @dataclass + class MergerConfig: + overcompute: bool = ( + os.getenv("NDSL_STREE_OVERCOMPUTE_MERGE", "True") == "True" + ) + + enabled: bool = os.getenv("NDSL_STREE_OPT", "False") == "True" + merger: MergerConfig = field(default_factory=MergerConfig) + + stree: TreeConfig = field(default_factory=TreeConfig) diff --git a/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py index 8ac4abf7..423986e6 100644 --- a/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py +++ b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py @@ -2,7 +2,7 @@ from dace import nodes from dace.sdfg.state import LoopRegion -from ndsl import Backend, NDSLRuntime, orchestrate +from ndsl import Backend, NDSLRuntime, OptimizationConfig, orchestrate from ndsl.boilerplate import get_factories_single_tile from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.gt4py import BACKWARD, FORWARD, PARALLEL, computation, interval @@ -46,7 +46,10 @@ def stencil_only_parallel_noop( class OrchestratedCode(NDSLRuntime): def __init__(self, stencil_factory: StencilFactory) -> None: - super().__init__(stencil_factory) + optimization_config = OptimizationConfig( + OptimizationConfig.TreeConfig(enabled=True) + ) + super().__init__(stencil_factory, optimization_config) methods_to_orchestrate = [ "kernelize_k", @@ -58,6 +61,7 @@ def __init__(self, stencil_factory: StencilFactory) -> None: obj=self, config=stencil_factory.config.dace_config, method_to_orchestrate=method, + optimization_config=optimization_config, ) self._stencil_kernelize_k = stencil_factory.from_dims_halo( @@ -92,15 +96,13 @@ class TestKernelizeMaps: ) def factories(self, request: pytest.FixtureRequest) -> Factories: domain = (3, 4, 5) - stencil_factory, quantity_factory = get_factories_single_tile( + return get_factories_single_tile( nx=domain[0], ny=domain[1], nz=domain[2], nhalo=0, backend=Backend(request.param), ) - stencil_factory.config.dace_config.enable_schedule_tree() - return stencil_factory, quantity_factory def test_kernelize_k_gpu(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories diff --git a/tests/dsl/dace/stree/optimizations/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py index b074c255..86cef061 100644 --- a/tests/dsl/dace/stree/optimizations/test_merge.py +++ b/tests/dsl/dace/stree/optimizations/test_merge.py @@ -4,7 +4,7 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.sdfg.state import LoopRegion -from ndsl import QuantityFactory, StencilFactory, orchestrate +from ndsl import OptimizationConfig, QuantityFactory, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM @@ -54,6 +54,7 @@ def __init__( stencil_factory: StencilFactory, quantity_factory: QuantityFactory, ) -> None: + config = OptimizationConfig(stree=OptimizationConfig.TreeConfig(enabled=True)) orchestratable_methods = [ "trivial_merge", "missing_merge_of_forscope_and_map", @@ -66,6 +67,7 @@ def __init__( obj=self, config=stencil_factory.config.dace_config, method_to_orchestrate=method, + optimization_config=config, ) self.stencil = stencil_factory.from_dims_halo( @@ -134,11 +136,9 @@ class TestStreeMergeMapsIJK: @pytest.fixture def factories(self) -> Factories: domain = (3, 3, 4) - stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + return get_factories_single_tile_orchestrated( domain[0], domain[1], domain[2], 0, backend=Backend("orch:dace:cpu:IJK") ) - stencil_factory.config.dace_config.enable_schedule_tree() - return stencil_factory, quantity_factory @pytest.fixture def code(self, factories: Factories) -> OrchestratedCode: @@ -282,11 +282,9 @@ class TestStreeMergeMapsKJI: @pytest.fixture def factories(self) -> Factories: domain = (3, 3, 4) - stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + return get_factories_single_tile_orchestrated( domain[0], domain[1], domain[2], 0, backend=Backend("orch:dace:cpu:KJI") ) - stencil_factory.config.dace_config.enable_schedule_tree() - return stencil_factory, quantity_factory @pytest.fixture def code(self, factories: Factories) -> OrchestratedCode: diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py index e9979dcb..88cab2b2 100644 --- a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -1,7 +1,14 @@ import pytest from dace import nodes -from ndsl import Backend, NDSLRuntime, StencilFactory, orchestrate, stencils +from ndsl import ( + Backend, + NDSLRuntime, + OptimizationConfig, + StencilFactory, + orchestrate, + stencils, +) from ndsl.boilerplate import get_factories_single_tile from ndsl.constants import I_DIM, J_DIM, K_DIM from ndsl.dsl.typing import FloatField @@ -11,7 +18,8 @@ class OrchestratedCode(NDSLRuntime): def __init__(self, stencil_factory: StencilFactory) -> None: - super().__init__(stencil_factory) + config = OptimizationConfig(stree=OptimizationConfig.TreeConfig(enabled=True)) + super().__init__(stencil_factory, config) methods_to_orchestrate = [ "happy_case", @@ -25,6 +33,7 @@ def __init__(self, stencil_factory: StencilFactory) -> None: obj=self, config=stencil_factory.config.dace_config, method_to_orchestrate=method, + optimization_config=config, ) self._copy_stencil = stencil_factory.from_dims_halo( @@ -62,11 +71,9 @@ class TestStreeInlineOffgridConditionals: @pytest.fixture(params=["orch:dace:cpu:IJK", "orch:dace:cpu:KJI"]) def factories(self, request: pytest.FixtureRequest) -> Factories: domain = (3, 3, 4) - stencil_factory, quantity_factory = get_factories_single_tile( + return get_factories_single_tile( domain[0], domain[1], domain[2], 0, backend=Backend(request.param) ) - stencil_factory.config.dace_config.enable_schedule_tree() - return stencil_factory, quantity_factory def test_happy_case(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories diff --git a/tests/dsl/dace/stree/optimizations/test_pipeline.py b/tests/dsl/dace/stree/optimizations/test_pipeline.py index b9c4eb78..39997cde 100644 --- a/tests/dsl/dace/stree/optimizations/test_pipeline.py +++ b/tests/dsl/dace/stree/optimizations/test_pipeline.py @@ -1,4 +1,4 @@ -from ndsl import StencilFactory, orchestrate +from ndsl import OptimizationConfig, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM @@ -16,7 +16,12 @@ def double_map(in_field: FloatField, out_field: FloatField): class TriviallyMergeableCode: def __init__(self, stencil_factory: StencilFactory): - orchestrate(obj=self, config=stencil_factory.config.dace_config) + config = OptimizationConfig(stree=OptimizationConfig.TreeConfig(enabled=True)) + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + optimization_config=config, + ) self.stencil = stencil_factory.from_dims_halo( func=double_map, compute_dims=[I_DIM, J_DIM, K_DIM], @@ -31,7 +36,6 @@ def test_stree_roundtrip_no_opt(): stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( domain[0], domain[1], domain[2], 0, backend=Backend.cpu() ) - stencil_factory.config.dace_config.enable_schedule_tree() code = TriviallyMergeableCode(stencil_factory) in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py index 77082ef4..378056da 100644 --- a/tests/dsl/dace/stree/optimizations/test_remove_loops.py +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -2,7 +2,7 @@ from dace import nodes from dace.sdfg.state import LoopRegion -from ndsl import StencilFactory, orchestrate +from ndsl import OptimizationConfig, StencilFactory, orchestrate from ndsl.boilerplate import get_factories_single_tile from ndsl.config import Backend, BackendLoopOrder from ndsl.constants import I_DIM, J_DIM, K_DIM, Float @@ -44,6 +44,7 @@ def stencil_forward_at_K(in_field: FloatField, out_field: FloatField) -> None: class OrchestratedCode: def __init__(self, stencil_factory: StencilFactory) -> None: + config = OptimizationConfig(stree=OptimizationConfig.TreeConfig(enabled=True)) methods_to_orchestrate = [ "write_at_0", "write_at_top", @@ -56,6 +57,7 @@ def __init__(self, stencil_factory: StencilFactory) -> None: obj=self, config=stencil_factory.config.dace_config, method_to_orchestrate=method, + optimization_config=config, ) self.stencil_simple_2D_write = stencil_factory.from_dims_halo( @@ -118,11 +120,9 @@ class TestStree2DWriteInline: def factories(self, request: pytest.FixtureRequest) -> Factories: domain = (3, 3, 4) - stencil_factory, quantity_factory = get_factories_single_tile( + return get_factories_single_tile( domain[0], domain[1], domain[2], 0, backend=Backend(request.param) ) - stencil_factory.config.dace_config.enable_schedule_tree() - return stencil_factory, quantity_factory def test_common_2D_write(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories diff --git a/tests/dsl/dace/stree/optimizations/test_transient_refine.py b/tests/dsl/dace/stree/optimizations/test_transient_refine.py index bf23cbbc..21bbc98c 100644 --- a/tests/dsl/dace/stree/optimizations/test_transient_refine.py +++ b/tests/dsl/dace/stree/optimizations/test_transient_refine.py @@ -1,4 +1,11 @@ -from ndsl import NDSLRuntime, Quantity, QuantityFactory, StencilFactory, orchestrate +from ndsl import ( + NDSLRuntime, + OptimizationConfig, + Quantity, + QuantityFactory, + StencilFactory, + orchestrate, +) from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM @@ -39,7 +46,8 @@ class TransientRefineableCode(NDSLRuntime): def __init__( self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory ) -> None: - super().__init__(stencil_factory) + config = OptimizationConfig(stree=OptimizationConfig.TreeConfig(enabled=True)) + super().__init__(stencil_factory, optimization_config=config) orchestratable_methods = [ "refine_to_scalar", "refine_to_K_buffer", @@ -51,6 +59,7 @@ def __init__( obj=self, config=stencil_factory.config.dace_config, method_to_orchestrate=method, + optimization_config=config, ) self.stencil = stencil_factory.from_dims_halo( func=stencil, @@ -95,7 +104,6 @@ def test_stree_roundtrip_transient_is_refined() -> None: stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( domain[0], domain[1], domain[2], 0, backend=Backend.cpu() ) - stencil_factory.config.dace_config.enable_schedule_tree() in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") diff --git a/tests/dsl/test_dace_config.py b/tests/dsl/test_dace_config.py index 0384ab42..8003937f 100644 --- a/tests/dsl/test_dace_config.py +++ b/tests/dsl/test_dace_config.py @@ -1,8 +1,5 @@ -import os import unittest.mock -import pytest - from ndsl import CubedSpherePartitioner, DaceConfig, DaCeOrchestration, TilePartitioner from ndsl.comm.partitioner import Partitioner from ndsl.config import Backend @@ -158,36 +155,3 @@ def _does_compile(rank: int, partitioner: Partitioner) -> bool: for i in range(layout[0] * layout[1] * 6): compiling += 1 if _does_compile(i, partition) else 0 assert compiling == 9 - - -def test_schedule_tree_enable_disable() -> None: - config = DaceConfig( - communicator=None, - backend=Backend("orch:dace:cpu:KIJ"), - orchestration=DaCeOrchestration.BuildAndRun, - ) - - default = os.getenv("NDSL_STREE_OPT", "False") == "True" - assert config.schedule_tree_enabled() == default - - if default: - config.disable_schedule_tree() - assert not config.schedule_tree_enabled() - config.enable_schedule_tree() - assert config.schedule_tree_enabled() - else: - config.enable_schedule_tree() - assert config.schedule_tree_enabled() - config.disable_schedule_tree() - assert not config.schedule_tree_enabled() - - -def test_schedule_tree_warns_when_not_orchestrated(): - config = DaceConfig( - communicator=None, - backend=Backend("st:dace:cpu:KIJ"), - orchestration=DaCeOrchestration.BuildAndRun, - ) - - with pytest.warns(UserWarning, match="non-orchestrated backend"): - config.enable_schedule_tree() diff --git a/tests/test_ndsl_runtime.py b/tests/test_ndsl_runtime.py index 67e4f226..83694274 100644 --- a/tests/test_ndsl_runtime.py +++ b/tests/test_ndsl_runtime.py @@ -2,20 +2,19 @@ import pytest -from ndsl import NDSLRuntime, QuantityFactory, StencilFactory +from ndsl import ( + NDSLRuntime, + OptimizationConfig, + QuantityFactory, + StencilFactory, + stencils, +) from ndsl.boilerplate import ( get_factories_single_tile, get_factories_single_tile_orchestrated, ) from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM -from ndsl.dsl.gt4py import PARALLEL, computation, interval -from ndsl.dsl.typing import FloatField - - -def the_copy_stencil(from_: FloatField, to: FloatField) -> None: - with computation(PARALLEL), interval(...): - to = from_ class Code(NDSLRuntime): @@ -24,7 +23,7 @@ def __init__( ) -> None: super().__init__(stencil_factory) self.copy = stencil_factory.from_dims_halo( - the_copy_stencil, compute_dims=[I_DIM, J_DIM, K_DIM] + stencils.copy, compute_dims=[I_DIM, J_DIM, K_DIM] ) self.local = self.make_local(quantity_factory, [I_DIM, J_DIM, K_DIM]) @@ -105,3 +104,36 @@ def test_runtime_fail_when_not_super_init() -> None: RuntimeError, match="inherit from NDSLRuntime but didn't call super()" ): bad_code = BadCode_NoSuperInit() + + +def test_runtime_with_performance_config() -> None: + class CustomPerformanceConfig(NDSLRuntime): + def __init__( + self, + stencil_factory: StencilFactory, + optimization_config: OptimizationConfig, + ) -> None: + super().__init__(stencil_factory, optimization_config) + self.copy = stencil_factory.from_dims_halo( + stencils.copy, compute_dims=[I_DIM, J_DIM, K_DIM] + ) + + def __call__(self, src, dst) -> None: # type: ignore[no-untyped-def] + self.copy(src, dst) + + stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( + nx=5, ny=5, nz=3, nhalo=0, backend=Backend.cpu() + ) + + # setup code + config = OptimizationConfig() + code = CustomPerformanceConfig(stencil_factory, config) + + # setup inputs/outputs + src = quantity_factory.ones(dims=[I_DIM, J_DIM, K_DIM], units="n/a") + dst = quantity_factory.zeros(dims=[I_DIM, J_DIM, K_DIM], units="n/a") + + # call code with inputs/outputs + code(src, dst) + + assert (src.field[:] == dst.field[:]).all() From 20018ff15b0c30de9f8edfa416fc83589a38e420 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 19 Jun 2026 13:12:52 -0400 Subject: [PATCH 065/101] `OptimizationConfig` tweaks: - Remove "Config" from nested name - Add `GPU` & apply to orchestration - Some docs - Display configuration for building ranks --- ndsl/dsl/dace/orchestration.py | 29 ++++++++++++++++++----------- ndsl/dsl/optimization_config.py | 24 ++++++++++++++++++------ 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index e6f5d5be..2807507a 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -4,6 +4,7 @@ import os from collections.abc import Callable, Sequence from pathlib import Path +from pprint import pformat from typing import Any from dace import SDFG, CompiledSDFG, DeviceType @@ -179,6 +180,7 @@ def _build_sdfg( backend_name = config.get_backend() if is_compiling: + ndsl_log.debug(f"Compiling config:\n{pformat(optimization_config, indent=2)}") # Fully specialize all known symbols and then propagate these changes in the simplify # pass that follows. This is not only a smart idea in general, but also simplifies (haha) # the schedule tree (optimization) roundtrip. @@ -301,17 +303,22 @@ def _build_sdfg( if isinstance(me, nodes.Tasklet) and "callback_" in me.label: exclude_taskslets_list.append(me.label) - # Apply common GPU transforms (includes a simplify) - # while making sure tasklet remain on the host - from dace.transformation.interstate import GPUTransformSDFG - - sdfg.apply_transformations( - GPUTransformSDFG, - options={ - "exclude_tasklets": ",".join(exclude_taskslets_list), - "host_data": ["__pystate"], - }, - ) + if optimization_config.gpu.common_gpu_xforms: + with DaCeProgress(config, "Apply common GPU xforms"): + # Apply common GPU transforms (includes a simplify) + # while making sure tasklet remain on the host + from dace.transformation.interstate import GPUTransformSDFG + + sdfg.apply_transformations( + GPUTransformSDFG, + options={ + "exclude_tasklets": ",".join(exclude_taskslets_list), + "host_data": ["__pystate"], + }, + ) + else: + with DaCeProgress(config, "GPU simplify"): + _simplify(sdfg) if config.verbose_orchestration: ndsl_log.debug("saving 05-apply_gpu_xforms.sdfgz") diff --git a/ndsl/dsl/optimization_config.py b/ndsl/dsl/optimization_config.py index ada339ce..42a3c0cb 100644 --- a/ndsl/dsl/optimization_config.py +++ b/ndsl/dsl/optimization_config.py @@ -5,14 +5,26 @@ @dataclass class OptimizationConfig: @dataclass - class TreeConfig: + class Tree: + """Optimization using the Schedule Tree IR""" + @dataclass - class MergerConfig: + class Merger: overcompute: bool = ( - os.getenv("NDSL_STREE_OVERCOMPUTE_MERGE", "True") == "True" + os.getenv("NDSL_STREE_OVERCOMPUTE_MERGE", "True").lower() == "true" ) + """When merging allow map of different size to merge by inserting an if guard""" + + enabled: bool = os.getenv("NDSL_STREE_OPT", "False").lower() == "true" + """Enable Schedule Tree transformations""" + merger: Merger = field(default_factory=Merger) + + @dataclass + class GPU: + """Optimization dedicated for GPU""" - enabled: bool = os.getenv("NDSL_STREE_OPT", "False") == "True" - merger: MergerConfig = field(default_factory=MergerConfig) + common_gpu_xforms: bool = True + """DaCe common xforms bundled in `apply_gpu_transformations`""" - stree: TreeConfig = field(default_factory=TreeConfig) + stree: Tree = field(default_factory=Tree) + gpu: GPU = field(default_factory=GPU) From ed298be27b3ca2c66cd9f49c2ddf1de11c3e899c Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 19 Jun 2026 14:39:46 -0400 Subject: [PATCH 066/101] GPU opt: apply AddThreadBlock so we have proper thread-blocking --- ndsl/dsl/dace/orchestration.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 2807507a..03068abf 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -21,6 +21,7 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.transformation.auto.auto_optimize import make_transients_persistent from dace.transformation.dataflow import MapCollapse, MapExpansion +from dace.transformation.dataflow.add_threadblock_map import AddThreadBlockMap from dace.transformation.helpers import get_parent_map from gt4py import storage as gt_storage @@ -292,6 +293,7 @@ def _build_sdfg( # tasklets to exclude next gpu_defaults = get_gpu_hardware_defaults() exclude_taskslets_list = [] + for me, _state in sdfg.all_nodes_recursive(): if ( isinstance(me, nodes.MapEntry) @@ -303,6 +305,8 @@ def _build_sdfg( if isinstance(me, nodes.Tasklet) and "callback_" in me.label: exclude_taskslets_list.append(me.label) + sdfg.apply_transformations_repeated(AddThreadBlockMap) + if optimization_config.gpu.common_gpu_xforms: with DaCeProgress(config, "Apply common GPU xforms"): # Apply common GPU transforms (includes a simplify) From df1aac36356a8f300e21353c31b48609e419fb85 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 19 Jun 2026 14:54:58 -0400 Subject: [PATCH 067/101] Fix `OptimizationConfig` tests --- tests/dsl/dace/stree/optimizations/test_kernelize_maps.py | 2 +- tests/dsl/dace/stree/optimizations/test_merge.py | 2 +- tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py | 2 +- tests/dsl/dace/stree/optimizations/test_pipeline.py | 2 +- tests/dsl/dace/stree/optimizations/test_remove_loops.py | 2 +- tests/dsl/dace/stree/optimizations/test_transient_refine.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py index 423986e6..12ebba15 100644 --- a/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py +++ b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py @@ -47,7 +47,7 @@ def stencil_only_parallel_noop( class OrchestratedCode(NDSLRuntime): def __init__(self, stencil_factory: StencilFactory) -> None: optimization_config = OptimizationConfig( - OptimizationConfig.TreeConfig(enabled=True) + OptimizationConfig.Tree(enabled=True) ) super().__init__(stencil_factory, optimization_config) diff --git a/tests/dsl/dace/stree/optimizations/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py index 86cef061..d20e3c4c 100644 --- a/tests/dsl/dace/stree/optimizations/test_merge.py +++ b/tests/dsl/dace/stree/optimizations/test_merge.py @@ -54,7 +54,7 @@ def __init__( stencil_factory: StencilFactory, quantity_factory: QuantityFactory, ) -> None: - config = OptimizationConfig(stree=OptimizationConfig.TreeConfig(enabled=True)) + config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) orchestratable_methods = [ "trivial_merge", "missing_merge_of_forscope_and_map", diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py index 88cab2b2..ce1446b2 100644 --- a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -18,7 +18,7 @@ class OrchestratedCode(NDSLRuntime): def __init__(self, stencil_factory: StencilFactory) -> None: - config = OptimizationConfig(stree=OptimizationConfig.TreeConfig(enabled=True)) + config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) super().__init__(stencil_factory, config) methods_to_orchestrate = [ diff --git a/tests/dsl/dace/stree/optimizations/test_pipeline.py b/tests/dsl/dace/stree/optimizations/test_pipeline.py index 39997cde..89662b4e 100644 --- a/tests/dsl/dace/stree/optimizations/test_pipeline.py +++ b/tests/dsl/dace/stree/optimizations/test_pipeline.py @@ -16,7 +16,7 @@ def double_map(in_field: FloatField, out_field: FloatField): class TriviallyMergeableCode: def __init__(self, stencil_factory: StencilFactory): - config = OptimizationConfig(stree=OptimizationConfig.TreeConfig(enabled=True)) + config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) orchestrate( obj=self, config=stencil_factory.config.dace_config, diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py index 378056da..331fb44b 100644 --- a/tests/dsl/dace/stree/optimizations/test_remove_loops.py +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -44,7 +44,7 @@ def stencil_forward_at_K(in_field: FloatField, out_field: FloatField) -> None: class OrchestratedCode: def __init__(self, stencil_factory: StencilFactory) -> None: - config = OptimizationConfig(stree=OptimizationConfig.TreeConfig(enabled=True)) + config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) methods_to_orchestrate = [ "write_at_0", "write_at_top", diff --git a/tests/dsl/dace/stree/optimizations/test_transient_refine.py b/tests/dsl/dace/stree/optimizations/test_transient_refine.py index 21bbc98c..3190f39d 100644 --- a/tests/dsl/dace/stree/optimizations/test_transient_refine.py +++ b/tests/dsl/dace/stree/optimizations/test_transient_refine.py @@ -46,7 +46,7 @@ class TransientRefineableCode(NDSLRuntime): def __init__( self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory ) -> None: - config = OptimizationConfig(stree=OptimizationConfig.TreeConfig(enabled=True)) + config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) super().__init__(stencil_factory, optimization_config=config) orchestratable_methods = [ "refine_to_scalar", From 5bdec8f3f67bbeb9369b3903a1a30ba3daf6df7a Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 19 Jun 2026 17:11:45 -0400 Subject: [PATCH 068/101] Lint --- tests/dsl/dace/stree/optimizations/test_kernelize_maps.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py index 12ebba15..3343b3eb 100644 --- a/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py +++ b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py @@ -46,9 +46,7 @@ def stencil_only_parallel_noop( class OrchestratedCode(NDSLRuntime): def __init__(self, stencil_factory: StencilFactory) -> None: - optimization_config = OptimizationConfig( - OptimizationConfig.Tree(enabled=True) - ) + optimization_config = OptimizationConfig(OptimizationConfig.Tree(enabled=True)) super().__init__(stencil_factory, optimization_config) methods_to_orchestrate = [ From f2430f4380569df0a33a9a94dcf4c5d98d9af46c Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 19 Jun 2026 18:07:05 -0400 Subject: [PATCH 069/101] Default `common_gpu_xforms` to False as it crashes more often than not (sic) --- ndsl/dsl/optimization_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/dsl/optimization_config.py b/ndsl/dsl/optimization_config.py index 42a3c0cb..049bc8d0 100644 --- a/ndsl/dsl/optimization_config.py +++ b/ndsl/dsl/optimization_config.py @@ -23,7 +23,7 @@ class Merger: class GPU: """Optimization dedicated for GPU""" - common_gpu_xforms: bool = True + common_gpu_xforms: bool = False """DaCe common xforms bundled in `apply_gpu_transformations`""" stree: Tree = field(default_factory=Tree) From 1c8d377661c48f7a0792117946d7c7107758d6bc Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Fri, 19 Jun 2026 21:53:06 -0400 Subject: [PATCH 070/101] Add option `pad_non_interface_dimensions` to GridSizer to undo the behavior of allocation when`fortran_aligned` that does not pad dimensions to be the same length Added a `--pad_non_interface_dimensions` to the unit tests to pass the flag during testing --- ndsl/initialization/subtile_grid_sizer.py | 20 ++++++++++++++++++-- ndsl/stencils/testing/conftest.py | 16 ++++++++++++++++ ndsl/stencils/testing/grid.py | 14 +++++++++++++- ndsl/stencils/testing/translate.py | 23 +++++++++++++++++------ 4 files changed, 64 insertions(+), 9 deletions(-) diff --git a/ndsl/initialization/subtile_grid_sizer.py b/ndsl/initialization/subtile_grid_sizer.py index 4a257080..c923afee 100644 --- a/ndsl/initialization/subtile_grid_sizer.py +++ b/ndsl/initialization/subtile_grid_sizer.py @@ -17,11 +17,18 @@ def __init__( n_halo: int, data_dimensions: dict[str, int], backend: Backend, + *, + pad_non_interface_dimensions: bool = False, ) -> None: super().__init__(nx, ny, nz, n_halo, data_dimensions) fortran_style_memory = backend.is_fortran_aligned() - self._pad_non_interface_dimensions = not fortran_style_memory + + # TODO: pad_non_interface_dimensions should not be kept. In general + # this should _always_ be False and non-interface dimensions never padded by default + self._pad_non_interface_dimensions = ( + not fortran_style_memory or pad_non_interface_dimensions + ) @classmethod def from_tile_params( @@ -36,6 +43,7 @@ def from_tile_params( data_dimensions: dict[str, int] | None = None, tile_partitioner: TilePartitioner | None = None, tile_rank: int = 0, + pad_non_interface_dimensions: bool = False, ) -> Self: """Create a SubtileGridSizer from parameters about the full tile. @@ -76,7 +84,15 @@ def from_tile_params( "SubtileGridSizer::from_tile_params: Compute domain extent must be greater than halo size" ) - return cls(nx, ny, nz, n_halo, data_dimensions, backend) + return cls( + nx, + ny, + nz, + n_halo, + data_dimensions, + backend, + pad_non_interface_dimensions=pad_non_interface_dimensions, + ) @classmethod def from_namelist( diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index 6e5b17af..652e132a 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -105,6 +105,12 @@ def pytest_addoption(parser: pytest.Parser) -> None: default=False, help="Do not generate logging report or NetCDF in .translate-errors", ) + parser.addoption( + "--pad_non_interface_dimensions", + action="store_true", + default=False, + help="Pad the non interface dimensions in all backends. Default to False.", + ) def pytest_configure(config: pytest.Config) -> None: @@ -255,6 +261,9 @@ def _sequential_savepoint_cases( topology_mode = metafunc.config.getoption("topology") sort_report = metafunc.config.getoption("sort_report") no_report = metafunc.config.getoption("no_report") + pad_non_interface_dimensions = metafunc.config.getoption( + "pad_non_interface_dimensions" + ) return _savepoint_cases( savepoint_names, @@ -268,6 +277,7 @@ def _sequential_savepoint_cases( topology_mode, sort_report=sort_report, no_report=no_report, + pad_non_interface_dimensions=pad_non_interface_dimensions, ) @@ -283,6 +293,7 @@ def _savepoint_cases( topology_mode: str, sort_report: str, no_report: bool, + pad_non_interface_dimensions: bool, ) -> list[SavepointCase]: grid_params = grid_params_from_f90nml(namelist) return_list = [] @@ -305,6 +316,7 @@ def _savepoint_cases( rank=rank, layout=grid_params["layout"], backend=backend, + pad_non_interface_dimensions=pad_non_interface_dimensions, ).python_grid() if grid_mode == "compute": _compute_grid_data( @@ -377,6 +389,9 @@ def _parallel_savepoint_cases( savepoint_names = _parallel_savepoint_names(metafunc, data_path) grid_mode = metafunc.config.getoption("grid") savepoint_to_replay = _get_savepoint_restriction(metafunc) + pad_non_interface_dimensions = metafunc.config.getoption( + "pad_non_interface_dimensions" + ) return _savepoint_cases( savepoint_names, @@ -390,6 +405,7 @@ def _parallel_savepoint_cases( topology_mode, sort_report=sort_report, no_report=no_report, + pad_non_interface_dimensions=pad_non_interface_dimensions, ) diff --git a/ndsl/stencils/testing/grid.py b/ndsl/stencils/testing/grid.py index 3af290e4..db24fd13 100644 --- a/ndsl/stencils/testing/grid.py +++ b/ndsl/stencils/testing/grid.py @@ -60,6 +60,7 @@ def _make( layout: tuple[int, int], rank: int, backend: Backend, + pad_non_interface_dimensions: bool = False, ) -> "Grid": shape_params = { "npx": npx, @@ -81,7 +82,15 @@ def _make( "js": N_HALO_DEFAULT, "je": ny + N_HALO_DEFAULT - 1, } - return cls(indices, shape_params, rank, layout, backend, local_indices=True) + return cls( + indices, + shape_params, + rank, + layout, + backend, + local_indices=True, + pad_non_interface_dimensions=pad_non_interface_dimensions, + ) @classmethod def from_namelist(cls, namelist: Namelist, rank: int, backend: Backend) -> "Grid": @@ -112,6 +121,7 @@ def __init__( backend: Backend, data_fields: dict | None = None, local_indices: bool = False, + pad_non_interface_dimensions: bool = False, ) -> None: if data_fields is None: data_fields = {} @@ -162,6 +172,7 @@ def __init__( self._grid_data: GridData | None = None self._driver_grid_data: DriverGridData | None = None self._damping_coefficients: DampingCoefficients | None = None + self._pad_non_interface_dimensions = pad_non_interface_dimensions @property def sizer(self) -> GridSizer: @@ -180,6 +191,7 @@ def sizer(self) -> GridSizer: }, layout=self.layout, backend=self.backend, + pad_non_interface_dimensions=self._pad_non_interface_dimensions, ) return self._sizer diff --git a/ndsl/stencils/testing/translate.py b/ndsl/stencils/testing/translate.py index 29afc577..4011a1c0 100644 --- a/ndsl/stencils/testing/translate.py +++ b/ndsl/stencils/testing/translate.py @@ -68,10 +68,7 @@ def __init__( self.ordered_input_vars = None self.ignore_near_zero_errors: dict[str, Any] = {} self.skip_test = skip_test - if self.stencil_factory.backend.is_fortran_aligned(): - self.maxshape = self.grid.domain_shape_full() - else: - self.maxshape = self.grid.domain_shape_full(add=(1, 1, 1)) + self.maxshape = self.grid.domain_shape_full(add=(1, 1, 1)) def extra_data_load(self, data_loader: DataLoader): pass @@ -322,7 +319,15 @@ def new_from_serialized_data(cls, serializer, rank, layout, backend: Backend): grid_data[field] = read_serialized_data(serializer, grid_savepoint, field) return cls(grid_data, rank, layout, backend=backend) - def __init__(self, inputs, rank, layout, *, backend: Backend): + def __init__( + self, + inputs, + rank, + layout, + *, + backend: Backend, + pad_non_interface_dimensions: bool = False, + ): self.backend = backend self.indices = {} self.shape_params = {} @@ -338,6 +343,7 @@ def __init__(self, inputs, rank, layout, *, backend: Backend): del inputs[index] self.data = inputs + self._pad_non_interface_dimensions = pad_non_interface_dimensions def _make_composite_var_storage(self, varname, data3d, shape, count): for s in range(count): @@ -444,7 +450,12 @@ def make_grid_storage(self, pygrid): def python_grid(self): pygrid = Grid( - self.indices, self.shape_params, self.rank, self.layout, self.backend + self.indices, + self.shape_params, + self.rank, + self.layout, + self.backend, + pad_non_interface_dimensions=self._pad_non_interface_dimensions, ) self.make_grid_storage(pygrid) pygrid.add_data(self.data) From 86567041586fcf80970d40baaece50db1f33babb Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sun, 21 Jun 2026 09:49:40 -0400 Subject: [PATCH 071/101] [Opt config] Add `kernalize` to `stree` and `enabled` to `stree.merger` --- ndsl/dsl/dace/stree/pipeline.py | 22 ++++++++++++---------- ndsl/dsl/optimization_config.py | 5 +++++ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 0b7d4713..08e29f5e 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -98,16 +98,18 @@ def __init__( cache_directory: Path | None = None, ) -> None: if passes is None: - ppl_passes = [ - CleanUpScheduleTree(), - # TODO: Is it safe? Deactivate for now - # InlineVertical2DWrite(), - CartesianMerge(backend, overcompute=config.stree.merger.overcompute), - KernelizeMaps(backend), - # 🐞 Transient refine can't be used - # because of bugs transients showing in code generation - # CartesianRefineTransients(backend), - ] + ppl_passes = [CleanUpScheduleTree()] + # TODO: Is it safe? Deactivate for now + # ppl_passes.append(InlineVertical2DWrite()) + if config.stree.merger.enabled: + ppl_passes.append( + CartesianMerge(backend, overcompute=config.stree.merger.overcompute) + ) + if config.stree.kernalize: + ppl_passes.append(KernelizeMaps(backend)) + # 🐞 Transient refine can't be used + # because of bugs transients showing in code generation + # CartesianRefineTransients(backend), else: ppl_passes = passes super().__init__( diff --git a/ndsl/dsl/optimization_config.py b/ndsl/dsl/optimization_config.py index 049bc8d0..2049ff84 100644 --- a/ndsl/dsl/optimization_config.py +++ b/ndsl/dsl/optimization_config.py @@ -10,6 +10,8 @@ class Tree: @dataclass class Merger: + enabled: bool = False + """Enable cartesian merging""" overcompute: bool = ( os.getenv("NDSL_STREE_OVERCOMPUTE_MERGE", "True").lower() == "true" ) @@ -17,6 +19,9 @@ class Merger: enabled: bool = os.getenv("NDSL_STREE_OPT", "False").lower() == "true" """Enable Schedule Tree transformations""" + kernalize: bool = True + """Enable maximizing 3-axis kernalization by duplicating maps (GPU only)""" + merger: Merger = field(default_factory=Merger) @dataclass From d26bdf49a3fb1ea989cb9c3cc8f47874c2b210cf Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sun, 21 Jun 2026 14:32:24 -0400 Subject: [PATCH 072/101] Always print report of schedule tree opt --- ndsl/dsl/dace/stree/pipeline.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 08e29f5e..25ecd4a3 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -55,10 +55,7 @@ def run( f.write(stree.as_string()) tree_stats.optimized(stree) - - if verbose: - ndsl_log_on_rank_0.info(tree_stats.report()) - + ndsl_log_on_rank_0.info(tree_stats.report()) return stree From 24c33dcf1aef98bad799967c80a0f544afe92569 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sun, 21 Jun 2026 15:19:59 -0400 Subject: [PATCH 073/101] Add 3D kernel count to stree stats --- ndsl/dsl/dace/stree/optimizations/statistics.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ndsl/dsl/dace/stree/optimizations/statistics.py b/ndsl/dsl/dace/stree/optimizations/statistics.py index 6fc927f9..54c0b09d 100644 --- a/ndsl/dsl/dace/stree/optimizations/statistics.py +++ b/ndsl/dsl/dace/stree/optimizations/statistics.py @@ -15,12 +15,18 @@ def __init__(self) -> None: super().__init__() self._maps = [0, 0, 0] self._fors = [0, 0, 0] + self._3D_kernels = 0 def visit_MapScope(self, node: tn.MapScope) -> None: for axis in AxisIterator: if is_axis_map(node, axis): self._maps[axis.as_cartesian_index()] += 1 + if isinstance(node.children[0], tn.MapScope) and isinstance( + node.children[0].children[0], tn.MapScope + ): + self._3D_kernels += 1 + self.visit(node.children) def visit_ForScope(self, node: tn.ForScope) -> None: @@ -61,6 +67,7 @@ class Record: cartesian_maps: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0]) cartesian_fors: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0]) + threeD_kernels: int = 0 transients: list[int] = dataclasses.field( default_factory=lambda: [0, 0, 0, 0, 0] ) @@ -79,6 +86,7 @@ def _record( c.visit(tree_root) record.cartesian_fors = c._fors record.cartesian_maps = c._maps + record.threeD_kernels = c._3D_kernels c = CountTransient() c.visit(tree_root) @@ -98,4 +106,5 @@ def report(self) -> str: msg += f" Cartesian maps [I, J, K]: {self._original_record.cartesian_maps} -> {self._optimized_record.cartesian_maps}\n" msg += f" Cartesian fors [I, J, K]: {self._original_record.cartesian_fors} -> {self._optimized_record.cartesian_fors}\n" msg += f" Transients [Scalarized Array, 1D, 2D, 3D, 4D+]: {self._original_record.transients} -> {self._optimized_record.transients}\n" + msg += f" Full 3D kernels: {self._original_record.threeD_kernels} -> {self._optimized_record.threeD_kernels}\n" return msg From 450da57586798f66f1775050550c59e8f7bd9026 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 22 Jun 2026 15:22:50 +0200 Subject: [PATCH 074/101] feat/fix: push maps to GPU if on GPU --- ndsl/dsl/dace/orchestration.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 03068abf..eaf5a072 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -203,6 +203,23 @@ def _build_sdfg( compress=True, ) + if config.is_gpu_backend(): + with DaCeProgress(config, "Configure maps to run on GPU"): + for this_sdfg in sdfg.all_sdfgs_recursive(): + for state in this_sdfg.states(): + for node in state.nodes(): + if ( + isinstance(node, nodes.EntryNode) + and node.schedule != ScheduleType.Sequential + ): + node.schedule = ScheduleType.GPU_Device + + ndsl_log.debug("saving 00-gpu-maps.sdfgz") + sdfg.save( + os.path.abspath(f"{sdfg.build_folder}/00-gpu-maps.sdfgz"), + compress=True, + ) + with DaCeProgress(config, "Simplify (1)"): _simplify(sdfg) if config.verbose_orchestration: From dce60afad1672629dcea87155b94ef9cab69ca45 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 22 Jun 2026 17:07:02 +0200 Subject: [PATCH 075/101] refactor: rename "kernalize" -> "kernelize" in optimization config Let's keep the naming of this consistent with the pass name, i.e. `KernelizeMaps` / `KernelizeMap`. --- ndsl/dsl/dace/stree/pipeline.py | 2 +- ndsl/dsl/optimization_config.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 25ecd4a3..fd9932d5 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -102,7 +102,7 @@ def __init__( ppl_passes.append( CartesianMerge(backend, overcompute=config.stree.merger.overcompute) ) - if config.stree.kernalize: + if config.stree.kernelize: ppl_passes.append(KernelizeMaps(backend)) # 🐞 Transient refine can't be used # because of bugs transients showing in code generation diff --git a/ndsl/dsl/optimization_config.py b/ndsl/dsl/optimization_config.py index 2049ff84..139ac0b5 100644 --- a/ndsl/dsl/optimization_config.py +++ b/ndsl/dsl/optimization_config.py @@ -18,9 +18,10 @@ class Merger: """When merging allow map of different size to merge by inserting an if guard""" enabled: bool = os.getenv("NDSL_STREE_OPT", "False").lower() == "true" - """Enable Schedule Tree transformations""" - kernalize: bool = True - """Enable maximizing 3-axis kernalization by duplicating maps (GPU only)""" + """Enable Schedule Tree transformations.""" + + kernelize: bool = True + """Enable maximizing 3-axis kernelization by duplicating maps (GPU only).""" merger: Merger = field(default_factory=Merger) From f14836a5c0ebaa9c5dc3f332e7b7679d7af4148a Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 22 Jun 2026 17:15:11 +0200 Subject: [PATCH 076/101] gt4py update: keep dace version in sync --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index e256ec5f..d568ac7d 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit e256ec5f2ae79e6240f7b5a4a29c9647877c12f4 +Subproject commit d568ac7de033b712ca3f86685f8beb9a254e3f05 From 944ee09a0c19a3c6f503684e07071fac862ab820 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 22 Jun 2026 18:08:46 +0200 Subject: [PATCH 077/101] refactor: leveage OptimizationConfig as pipeline config On GPU, we already used the OptimizationConfig to turn on/off certain pipeline passes. This commit extends that work to the `CPUPipeline` and adds OptimizationConfig flags where missing. --- ndsl/dsl/dace/stree/pipeline.py | 32 ++++++----- ndsl/dsl/optimization_config.py | 13 ++++- .../optimizations/test_kernelize_maps.py | 7 ++- .../dace/stree/optimizations/test_merge.py | 7 ++- .../test_offgrid_conditionals.py | 7 ++- .../dace/stree/optimizations/test_pipeline.py | 8 ++- .../stree/optimizations/test_remove_loops.py | 54 +++++-------------- .../optimizations/test_transient_refine.py | 7 ++- 8 files changed, 73 insertions(+), 62 deletions(-) diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index fd9932d5..52b1f5db 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -7,6 +7,7 @@ CartesianMerge, CartesianRefineTransients, CleanUpScheduleTree, + InlineVertical2DWrite, KernelizeMaps, TreeOptimizationStatistics, ) @@ -69,13 +70,15 @@ def __init__( cache_directory: Path | None = None, ) -> None: if passes is None: - ppl_passes = [ - CleanUpScheduleTree(), - # TODO: Is it safe? Deactivate for now - # InlineVertical2DWrite(), - CartesianMerge(backend, overcompute=config.stree.merger.overcompute), - CartesianRefineTransients(backend), - ] + ppl_passes = [CleanUpScheduleTree()] + if config.stree.inline_K_loops_size_one: + ppl_passes.append(InlineVertical2DWrite()) + if config.stree.merger.enabled: + ppl_passes.append( + CartesianMerge(backend, overcompute=config.stree.merger.overcompute) + ) + if config.stree.refine_transients: + ppl_passes.append(CartesianRefineTransients(backend)) else: ppl_passes = passes super().__init__( @@ -96,17 +99,22 @@ def __init__( ) -> None: if passes is None: ppl_passes = [CleanUpScheduleTree()] - # TODO: Is it safe? Deactivate for now - # ppl_passes.append(InlineVertical2DWrite()) + if config.stree.inline_K_loops_size_one: + ppl_passes.append(InlineVertical2DWrite()) if config.stree.merger.enabled: ppl_passes.append( CartesianMerge(backend, overcompute=config.stree.merger.overcompute) ) if config.stree.kernelize: ppl_passes.append(KernelizeMaps(backend)) - # 🐞 Transient refine can't be used - # because of bugs transients showing in code generation - # CartesianRefineTransients(backend), + if config.stree.refine_transients: + # TODO + # 🐞 Transient refine can't be used + # because of bugs transients showing in code generation + # ppl_passes.append(CartesianRefineTransients(backend)) + raise ValueError( + "Transient refinement is currently unavailable in the GPU pipeline." + ) else: ppl_passes = passes super().__init__( diff --git a/ndsl/dsl/optimization_config.py b/ndsl/dsl/optimization_config.py index 139ac0b5..d2e91fae 100644 --- a/ndsl/dsl/optimization_config.py +++ b/ndsl/dsl/optimization_config.py @@ -11,19 +11,28 @@ class Tree: @dataclass class Merger: enabled: bool = False - """Enable cartesian merging""" + """Enable cartesian axis merging.""" + overcompute: bool = ( os.getenv("NDSL_STREE_OVERCOMPUTE_MERGE", "True").lower() == "true" ) - """When merging allow map of different size to merge by inserting an if guard""" + """When merging allow maps of different sizes to merge by inserting an `if` guard.""" enabled: bool = os.getenv("NDSL_STREE_OPT", "False").lower() == "true" """Enable Schedule Tree transformations.""" + # TODO: Is it safe? Deactivate by default for now + inline_K_loops_size_one: bool = False + """"Remove serial for loops of size one in the K-axis.""" + kernelize: bool = True """Enable maximizing 3-axis kernelization by duplicating maps (GPU only).""" merger: Merger = field(default_factory=Merger) + """Configuration object for cartesian axis merging.""" + + refine_transients: bool = True + """Reduce dimensionality of transient arrays based on their usage.""" @dataclass class GPU: diff --git a/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py index 3343b3eb..5ddb9764 100644 --- a/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py +++ b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py @@ -46,7 +46,12 @@ def stencil_only_parallel_noop( class OrchestratedCode(NDSLRuntime): def __init__(self, stencil_factory: StencilFactory) -> None: - optimization_config = OptimizationConfig(OptimizationConfig.Tree(enabled=True)) + optimization_config = OptimizationConfig( + stree=OptimizationConfig.Tree( + enabled=True, + merger=OptimizationConfig.Tree.Merger(enabled=True), + ) + ) super().__init__(stencil_factory, optimization_config) methods_to_orchestrate = [ diff --git a/tests/dsl/dace/stree/optimizations/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py index d20e3c4c..36e366b3 100644 --- a/tests/dsl/dace/stree/optimizations/test_merge.py +++ b/tests/dsl/dace/stree/optimizations/test_merge.py @@ -54,7 +54,12 @@ def __init__( stencil_factory: StencilFactory, quantity_factory: QuantityFactory, ) -> None: - config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) + config = OptimizationConfig( + stree=OptimizationConfig.Tree( + enabled=True, + merger=OptimizationConfig.Tree.Merger(enabled=True), + ) + ) orchestratable_methods = [ "trivial_merge", "missing_merge_of_forscope_and_map", diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py index ce1446b2..6232903a 100644 --- a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py +++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py @@ -18,7 +18,12 @@ class OrchestratedCode(NDSLRuntime): def __init__(self, stencil_factory: StencilFactory) -> None: - config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) + config = OptimizationConfig( + stree=OptimizationConfig.Tree( + enabled=True, + merger=OptimizationConfig.Tree.Merger(enabled=True), + ) + ) super().__init__(stencil_factory, config) methods_to_orchestrate = [ diff --git a/tests/dsl/dace/stree/optimizations/test_pipeline.py b/tests/dsl/dace/stree/optimizations/test_pipeline.py index 89662b4e..545f5ffa 100644 --- a/tests/dsl/dace/stree/optimizations/test_pipeline.py +++ b/tests/dsl/dace/stree/optimizations/test_pipeline.py @@ -16,7 +16,11 @@ def double_map(in_field: FloatField, out_field: FloatField): class TriviallyMergeableCode: def __init__(self, stencil_factory: StencilFactory): - config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) + config = OptimizationConfig( + stree=OptimizationConfig.Tree( + enabled=True, merger=OptimizationConfig.Tree.Merger(enabled=True) + ) + ) orchestrate( obj=self, config=stencil_factory.config.dace_config, @@ -31,7 +35,7 @@ def __call__(self, in_field: FloatField, out_field: FloatField): self.stencil(in_field, out_field) -def test_stree_roundtrip_no_opt(): +def test_stree_roundtrip(): domain = (3, 3, 4) stencil_factory, quantity_factory = get_factories_single_tile_orchestrated( domain[0], domain[1], domain[2], 0, backend=Backend.cpu() diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py index 331fb44b..afb98023 100644 --- a/tests/dsl/dace/stree/optimizations/test_remove_loops.py +++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py @@ -6,12 +6,6 @@ from ndsl.boilerplate import get_factories_single_tile from ndsl.config import Backend, BackendLoopOrder from ndsl.constants import I_DIM, J_DIM, K_DIM, Float -from ndsl.dsl.dace.stree.optimizations import InlineVertical2DWrite -from ndsl.dsl.dace.stree.pipeline import ( - CartesianMerge, - CartesianRefineTransients, - CleanUpScheduleTree, -) from ndsl.dsl.gt4py import FORWARD, computation, interval from ndsl.dsl.typing import FloatField, FloatFieldIJ from ndsl.stencils import copy @@ -44,7 +38,13 @@ def stencil_forward_at_K(in_field: FloatField, out_field: FloatField) -> None: class OrchestratedCode: def __init__(self, stencil_factory: StencilFactory) -> None: - config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) + config = OptimizationConfig( + stree=OptimizationConfig.Tree( + enabled=True, + inline_K_loops_size_one=True, + merger=OptimizationConfig.Tree.Merger(enabled=True), + ) + ) methods_to_orchestrate = [ "write_at_0", "write_at_top", @@ -127,18 +127,12 @@ def factories(self, request: pytest.FixtureRequest) -> Factories: def test_common_2D_write(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) - pipeline = [ - CleanUpScheduleTree(), - InlineVertical2DWrite(), - CartesianMerge(stencil_factory.backend), - CartesianRefineTransients(stencil_factory.backend), - ] in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") in_qty.field[:, :, 0] = Float(32.0) - with StreePipeline(passes=pipeline): + with StreePipeline(): code.write_at_0(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -160,18 +154,12 @@ def test_common_2D_write(self, factories: Factories) -> None: def test_2D_write_K_top(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) - pipeline = [ - CleanUpScheduleTree(), - InlineVertical2DWrite(), - CartesianMerge(stencil_factory.backend), - CartesianRefineTransients(stencil_factory.backend), - ] in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM], "") in_qty.field[:, :, -1] = Float(32.0) - with StreePipeline(passes=pipeline): + with StreePipeline(): code.write_at_top(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -193,17 +181,11 @@ def test_2D_write_K_top(self, factories: Factories) -> None: def test_do_not_inline(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) - pipeline = [ - CleanUpScheduleTree(), - InlineVertical2DWrite(), - CartesianMerge(stencil_factory.backend), - CartesianRefineTransients(stencil_factory.backend), - ] in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - with StreePipeline(passes=pipeline): + with StreePipeline(): code.do_not_inline(in_qty, out_qty) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -225,18 +207,12 @@ def test_do_not_inline(self, factories: Factories) -> None: def test_combined_stencils(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) - pipeline = [ - CleanUpScheduleTree(), - InlineVertical2DWrite(), - CartesianMerge(stencil_factory.backend), - CartesianRefineTransients(stencil_factory.backend), - ] field = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") field_2 = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "") - with StreePipeline(passes=pipeline): + with StreePipeline(): code.combined_stencils(field, field_2, field_IJ) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) @@ -262,19 +238,13 @@ def test_combined_stencils(self, factories: Factories) -> None: def test_multiple_statements(self, factories: Factories) -> None: stencil_factory, quantity_factory = factories code = OrchestratedCode(stencil_factory) - pipeline = [ - CleanUpScheduleTree(), - InlineVertical2DWrite(), - CartesianMerge(stencil_factory.backend), - CartesianRefineTransients(stencil_factory.backend), - ] field = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "") field_IJ_2 = quantity_factory.zeros([I_DIM, J_DIM], "") field.field[:, :, 0] = Float(42.0) - with StreePipeline(passes=pipeline): + with StreePipeline(): code.multiple_statements(field, field_IJ, field_IJ_2) precompiled_sdfg = get_SDFG_and_purge(stencil_factory) diff --git a/tests/dsl/dace/stree/optimizations/test_transient_refine.py b/tests/dsl/dace/stree/optimizations/test_transient_refine.py index 3190f39d..d3c7604f 100644 --- a/tests/dsl/dace/stree/optimizations/test_transient_refine.py +++ b/tests/dsl/dace/stree/optimizations/test_transient_refine.py @@ -46,7 +46,12 @@ class TransientRefineableCode(NDSLRuntime): def __init__( self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory ) -> None: - config = OptimizationConfig(stree=OptimizationConfig.Tree(enabled=True)) + config = OptimizationConfig( + stree=OptimizationConfig.Tree( + enabled=True, + merger=OptimizationConfig.Tree.Merger(enabled=True), + ) + ) super().__init__(stencil_factory, optimization_config=config) orchestratable_methods = [ "refine_to_scalar", From 3842db55281a55c53057a408b31d51a3d872d472 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 23 Jun 2026 07:33:40 +0200 Subject: [PATCH 078/101] build: ignore non-version tags in setuptools-scm --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a9dbcc1a..03d6eb01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] build-backend = "setuptools.build_meta" -requires = ["setuptools >= 80", "setuptools-scm>=8"] +requires = ["setuptools >= 80", "setuptools-scm>=10"] [project] authors = [{name = "NOAA/NASA"}] @@ -147,3 +147,4 @@ include = ["ndsl", "ndsl.*"] [tool.setuptools_scm] version_scheme = "no-guess-dev" +tag.strict = "true" # only consider tags with at least one dot (e.g. v1.0, 2025.06.03) From 71efc5a2975d8a15453325553be5ef4239cebeff Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 23 Jun 2026 11:39:22 +0200 Subject: [PATCH 079/101] gt4py update: avoid `x+0` / `0+x` in horizontal regions --- external/gt4py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external/gt4py b/external/gt4py index d568ac7d..98e63aa9 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit d568ac7de033b712ca3f86685f8beb9a254e3f05 +Subproject commit 98e63aa945a7af0c5ed06bda88f00bb9304e4911 From 9b4d0be187784c4316b25d7b149f2023a9369641 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 23 Jun 2026 12:56:53 +0200 Subject: [PATCH 080/101] update dace/gt4py: gpu transformation fix --- external/dace | 2 +- external/gt4py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/external/dace b/external/dace index d186d86d..7c526886 160000 --- a/external/dace +++ b/external/dace @@ -1 +1 @@ -Subproject commit d186d86dea15f7852545dcde0c4f5b9e6d4f072b +Subproject commit 7c526886bfadeb9808a06a66fcbca1dbfa6b8ad4 diff --git a/external/gt4py b/external/gt4py index 98e63aa9..331c7bba 160000 --- a/external/gt4py +++ b/external/gt4py @@ -1 +1 @@ -Subproject commit 98e63aa945a7af0c5ed06bda88f00bb9304e4911 +Subproject commit 331c7bba9161b96cf94f6d5d9bda06161703db28 From 908b02e0256ceee7a049d799e096ab8d2f193ed4 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 23 Jun 2026 11:27:05 -0400 Subject: [PATCH 081/101] Push `matplolib` import into the plotting function --- ndsl/quantity/quantity.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py index 5d310674..d493fe5f 100644 --- a/ndsl/quantity/quantity.py +++ b/ndsl/quantity/quantity.py @@ -6,7 +6,6 @@ from typing import Any, cast import dace -import matplotlib.pyplot as plt import numpy as np import xarray as xr from gt4py import storage as gt_storage @@ -459,6 +458,8 @@ def transpose( return transposed def plot_k_level(self, k_index: int = 0) -> None: + import matplotlib.pyplot as plt + field = self._data plt.xlabel("I") plt.ylabel("J") From b8432e414618f6930d1bc9da0d35770f21ecab96 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 24 Jun 2026 17:37:15 -0400 Subject: [PATCH 082/101] Moved log to grab `_parse` call and move labeler so it is applied only once --- ndsl/dsl/dace/orchestration.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index eaf5a072..dfea89f8 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -464,7 +464,6 @@ def _call_sdfg( mode in [DaCeOrchestration.Build, DaCeOrchestration.BuildAndRun] and dace_program not in config.loaded_dace_executables # already cached ): - ndsl_log.info("Building DaCe orchestration") _build_sdfg(dace_program, sdfg, config, optimization_config, args, kwargs) if mode not in [DaCeOrchestration.BuildAndRun, DaCeOrchestration.Run]: @@ -510,6 +509,7 @@ def _call_sdfg( def _parse_sdfg( dace_program: DaceProgram, config: DaceConfig, + optimization: OptimizationConfig, *args: Any, **kwargs: Any, ) -> SDFG | CompiledSDFG | None: @@ -524,6 +524,8 @@ def _parse_sdfg( if dace_program in config.loaded_dace_executables: return config.loaded_dace_executables[dace_program].compiled_sdfg + ndsl_log.info(f"Building DaCe orchestration for {dace_program.f.__qualname__}") + # Build expected path sdfg_path = get_sdfg_path(dace_program.name, config) if sdfg_path is None: @@ -545,6 +547,11 @@ def _parse_sdfg( simplify=False, validate=False, # TODO: should we have a "debug flag" to turn this on? ) + + # Label the code (this is the topmost code) + if sdfg is not None and optimization.stree.enabled: + set_label(sdfg, dace_program.f.__qualname__, is_top_sdfg=True) + return sdfg if os.path.isfile(sdfg_path): @@ -589,6 +596,7 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] sdfg = _parse_sdfg( self.daceprog, self.config, + self.optimization_config, *args, **kwargs, ) @@ -657,12 +665,10 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] sdfg = _parse_sdfg( self.daceprog, self.lazy_method.config, + self.lazy_method.optimization_config, *args, **kwargs, ) - # Label the code (this is the topmost code) - if sdfg is not None and self.lazy_method.optimization_config.stree.enabled: - set_label(sdfg, type(self.obj_to_bind).__qualname__, is_top_sdfg=True) return _call_sdfg( self.daceprog, sdfg, From 633f73b1d9279d87c18405dddf75068fa46313ed Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 25 Jun 2026 08:23:09 +0200 Subject: [PATCH 083/101] fixup: consistently pass opt config to `_parse_sdfg()` --- ndsl/dsl/dace/orchestration.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index dfea89f8..085e42e3 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -618,7 +618,9 @@ def global_vars(self, value): # type: ignore[no-untyped-def] self.daceprog.global_vars = value def __sdfg__(self, *args, **kwargs): # type: ignore[no-untyped-def] - return _parse_sdfg(self.daceprog, self.config, *args, **kwargs) + return _parse_sdfg( + self.daceprog, self.config, self.optimization_config, *args, **kwargs + ) def __sdfg_closure__(self, *args, **kwargs): # type: ignore[no-untyped-def] return self.daceprog.__sdfg_closure__(*args, **kwargs) @@ -679,7 +681,13 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] ) def __sdfg__(self, *args, **kwargs): # type: ignore[no-untyped-def] - sdfg = _parse_sdfg(self.daceprog, self.lazy_method.config, *args, **kwargs) + sdfg = _parse_sdfg( + self.daceprog, + self.lazy_method.config, + self.lazy_method.optimization_config, + *args, + **kwargs, + ) # Label the code if sdfg is not None and self.lazy_method.optimization_config.stree.enabled: set_label(sdfg, type(self.obj_to_bind).__qualname__, is_top_sdfg=False) From 7d480480c0f5993f5c9e0fa2102e36b03786f15f Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 25 Jun 2026 08:25:08 +0200 Subject: [PATCH 084/101] refacor: use "dace import slang" in our memlet helper --- ndsl/dsl/dace/stree/optimizations/common/memlet.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/common/memlet.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py index 266be19d..edb4f3e0 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/memlet.py +++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py @@ -1,7 +1,7 @@ from enum import Enum -import dace.sdfg.analysis.schedule_tree.treenodes as stree from dace.memlet import Memlet +from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.symbolic import symbol from ndsl import ndsl_log @@ -35,8 +35,8 @@ def normalize_cartesian_indexation(index: symbol, axis: AxisIterator) -> symbol: def no_data_dependencies_on_cartesian_axis( - first: stree.MapScope, - second: stree.MapScope, + first: tn.MapScope, + second: tn.MapScope, axis: AxisIterator, ) -> bool: """Check for read after write and write after write with different offsets.""" @@ -94,7 +94,7 @@ def no_data_dependencies_on_cartesian_axis( return True -class MemletCollector(stree.ScheduleNodeVisitor): +class MemletCollector(tn.ScheduleNodeVisitor): """Gathers in_memlets and out_memlets of TaskNodes and LibraryCalls.""" in_memlets: list[Memlet] @@ -109,13 +109,13 @@ def __init__( self.in_memlets = [] self.out_memlets = [] - def visit_TaskletNode(self, node: stree.TaskletNode) -> None: + def visit_TaskletNode(self, node: tn.TaskletNode) -> None: if self._collect_reads: self.in_memlets.extend([memlet for memlet in node.in_memlets.values()]) if self._collect_writes: self.out_memlets.extend([memlet for memlet in node.out_memlets.values()]) - def visit_LibraryCall(self, node: stree.LibraryCall) -> None: + def visit_LibraryCall(self, node: tn.LibraryCall) -> None: if self._collect_reads: if isinstance(node.in_memlets, set): self.in_memlets.extend(node.in_memlets) @@ -133,7 +133,7 @@ def visit_LibraryCall(self, node: stree.LibraryCall) -> None: ) -def has_dynamic_memlets(first: stree.MapScope, second: stree.MapScope) -> bool: +def has_dynamic_memlets(first: tn.MapScope, second: tn.MapScope) -> bool: first_collector = MemletCollector() second_collector = MemletCollector() first_collector.visit(first) From de9763de52a53e67394e08e11feec9a7f68e3e55 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 25 Jun 2026 09:08:31 +0200 Subject: [PATCH 085/101] fix: account for map start in axis normalization --- .../dace/stree/optimizations/common/memlet.py | 39 ++++++--- tests/dsl/dace/stree/common/test_memlet.py | 80 +++++++++++++++++-- 2 files changed, 105 insertions(+), 14 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/common/memlet.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py index edb4f3e0..cefd96ac 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/memlet.py +++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py @@ -25,13 +25,34 @@ def is_equal(self, other: str) -> bool: return other == self.as_str() -def normalize_cartesian_indexation(index: symbol, axis: AxisIterator) -> symbol: - """Return a normalize indexation symbol for cartesian indexation.""" +def normalize_cartesian_indexation( + index: symbol, axis: AxisIterator, map_scope: tn.MapScope +) -> symbol: + """Return a normalized indexation symbol for cartesian indexation.""" + if len(map_scope.node.map.params) != 1: + raise ValueError( + f"Expected a map with only one parameter, got {map_scope.node.map.params}." + ) + + axis_name = axis.as_str() + if not map_scope.node.map.params[0].startswith(axis_name): + raise ValueError( + f"Mismatch of axis iterator {axis} and MapScope parameter {map_scope.node.map.params}." + ) + + # potentially rename rename_maps = {} - for symb in index.free_symbols: - if symb.name.startswith(axis.as_str()): - rename_maps[symb] = symbol(axis.as_str()) - return index.subs(rename_maps) + for sym in index.free_symbols: + if sym.name != axis_name and sym.name.startswith(axis_name): + rename_maps[sym] = symbol(axis_name) + renamed = index.subs(rename_maps) + + # handle potential map start + map_start = map_scope.node.map.range.min_element()[0] + if map_start != 0: + return renamed + map_start + + return renamed def no_data_dependencies_on_cartesian_axis( @@ -58,14 +79,14 @@ def no_data_dependencies_on_cartesian_axis( continue previous_axis_index = normalize_cartesian_indexation( - write.subset[axis_index][0], axis + write.subset[axis_index][0], axis, first ) # Write-after-write with an offset case for other_write in other_writes.out_memlets: if write.data == other_write.data: if previous_axis_index != normalize_cartesian_indexation( - other_write.subset[axis_index][0], axis + other_write.subset[axis_index][0], axis, second ): ndsl_log.debug( f"[{axis.name} Merge] Found write after write conflict " @@ -80,7 +101,7 @@ def no_data_dependencies_on_cartesian_axis( for read in read_collector.in_memlets: if write.data == read.data: if previous_axis_index != normalize_cartesian_indexation( - read.subset[axis_index][0], axis + read.subset[axis_index][0], axis, second ): ndsl_log.debug( f"[{axis.name} Merge] Found read after write conflict " diff --git a/tests/dsl/dace/stree/common/test_memlet.py b/tests/dsl/dace/stree/common/test_memlet.py index 44fe15e1..ebd6f94f 100644 --- a/tests/dsl/dace/stree/common/test_memlet.py +++ b/tests/dsl/dace/stree/common/test_memlet.py @@ -1,3 +1,6 @@ +import pytest +from dace import nodes, subsets +from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.symbolic import symbol from ndsl.dsl.dace.stree.optimizations.common import AxisIterator @@ -6,27 +9,94 @@ ) -def test_normalize_cartesian_index(): +@pytest.fixture +def k_map() -> tn.MapScope: + return tn.MapScope( + node=nodes.MapEntry( + nodes.Map("map", ["__k_123456789"], subsets.Range.from_string("0:5")) + ), + children=[], + ) + + +def test_normalize_cartesian_index(k_map: tn.MapScope) -> None: # Case of __k_id(node) - original case original_symbol = symbol("__k_12345678789") - norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + norm_symbol = normalize_cartesian_indexation( + original_symbol, AxisIterator._K, k_map + ) assert norm_symbol == symbol("__k") # Case of offset original_symbol = 1 + symbol("__k_12345678789") - norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + norm_symbol = normalize_cartesian_indexation( + original_symbol, AxisIterator._K, k_map + ) assert norm_symbol == symbol("__k") + 1 # Case of no-op (with offset) original_symbol = 1 + symbol("__k") - norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + norm_symbol = normalize_cartesian_indexation( + original_symbol, AxisIterator._K, k_map + ) assert norm_symbol == symbol("__k") + 1 # Case of index named with _k - so not a cartesian axis original_symbol = 1 + symbol("_kindex") - norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) + norm_symbol = normalize_cartesian_indexation( + original_symbol, AxisIterator._K, k_map + ) assert norm_symbol == symbol("_kindex") + 1 + + +def test_normalize_cartesian_index_map_two_params() -> None: + ij_map = tn.MapScope( + node=nodes.MapEntry( + nodes.Map("map", ["__i", "__j"], subsets.Range([(0, 3, 1), (0, 4, 2)])) + ), + children=[], + ) + with pytest.raises(ValueError, match="Expected a map with only one parameter"): + normalize_cartesian_indexation(symbol("__i"), AxisIterator._I, ij_map) + + +def test_normalize_cartesian_index_map_wrong_index(k_map) -> None: + with pytest.raises(ValueError, match="Mismatch of axis iterator"): + normalize_cartesian_indexation(symbol("__i"), AxisIterator._I, k_map) + + +def test_normalize_cartesian_index_map_start(k_map) -> None: + map_m1 = tn.MapScope( + node=nodes.MapEntry( + nodes.Map("map", ["__i"], subsets.Range.from_string("-1:3")) + ), + children=[], + ) + + original_symbol = symbol("__i") + normalized = normalize_cartesian_indexation( + original_symbol, AxisIterator._I, map_m1 + ) + assert normalized == original_symbol - 1 + + original_symbol = 1 + symbol("__i") + normalized = normalize_cartesian_indexation( + original_symbol, AxisIterator._I, map_m1 + ) + assert normalized == symbol("__i") + + original_symbol = symbol("__i") + 5 + normalized = normalize_cartesian_indexation( + original_symbol, AxisIterator._I, map_m1 + ) + assert normalized == symbol("__i") + 4 + + original_symbol = 1 + symbol("__i_1234") + normalized = normalize_cartesian_indexation( + original_symbol, AxisIterator._I, map_m1 + ) + assert normalized == symbol("__i") From b24e1fcc11e045a8d29a5fb726e7abd90550a392 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 25 Jun 2026 09:14:57 +0200 Subject: [PATCH 086/101] fixup: use normalized indices in debug message --- .../dace/stree/optimizations/common/memlet.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/common/memlet.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py index cefd96ac..01858807 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/memlet.py +++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py @@ -85,30 +85,32 @@ def no_data_dependencies_on_cartesian_axis( # Write-after-write with an offset case for other_write in other_writes.out_memlets: if write.data == other_write.data: - if previous_axis_index != normalize_cartesian_indexation( + current_axis_index = normalize_cartesian_indexation( other_write.subset[axis_index][0], axis, second - ): + ) + if previous_axis_index != current_axis_index: ndsl_log.debug( f"[{axis.name} Merge] Found write after write conflict " f"for {write.data} " f"w/ different offset to {axis.name} (" f"first write at {previous_axis_index}, " - f"second write at {other_write.subset[axis_index][0]})" + f"second write at {current_axis_index})" ) return False # Read-after-write with an offset case for read in read_collector.in_memlets: if write.data == read.data: - if previous_axis_index != normalize_cartesian_indexation( + current_axis_index = normalize_cartesian_indexation( read.subset[axis_index][0], axis, second - ): + ) + if previous_axis_index != current_axis_index: ndsl_log.debug( f"[{axis.name} Merge] Found read after write conflict " f"for {write.data} " f"w/ different offset to {axis.name} (" - f"write at {write.subset[axis_index][0]}, " - f"read at {read.subset[axis_index][0]})" + f"write at {previous_axis_index}, " + f"read at {current_axis_index})" ) return False From 0d9844581c441ab7679daf0c75b7045a90155978 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 25 Jun 2026 13:09:23 +0200 Subject: [PATCH 087/101] Make "dace" import non-ambiguous Because we have a "dace/" directory in `ndsl/dsl`, the previous import could be resolved as a local import. If that happened (depending on import order), then the DaCe's `Config` object would not be found there. Resolved by importing `Config` from `dace.config`, which is unambiguous to resolve. --- ndsl/dsl/stencil.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index 7b448944..96d00946 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -7,8 +7,8 @@ from collections.abc import Callable, Iterable, Mapping, Sequence from typing import Any, cast -import dace import numpy as np +from dace.config import Config as DaceConfig from gt4py.cartesian import config as gt_config from gt4py.cartesian import definitions as gt_definitions from gt4py.cartesian import gtscript @@ -321,7 +321,7 @@ def __init__( BackendFramework.DACE == self.stencil_config.compilation_config.backend.framework ): - dace.Config.set( + DaceConfig.set( "default_build_folder", value="{gt_root}/{gt_cache}/dacecache".format( gt_root=gt_config.cache_settings["root_path"], From 668334726d1d75dc24e7bf288193c5e2c65d15ba Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 25 Jun 2026 13:14:41 +0200 Subject: [PATCH 088/101] feat: support for custom merging oder in OptimizationConfig The default merging order for `CartesianMerge` is to follow the loop order of the given backend. This commit adds support for a custom merge oder override. --- .../stree/optimizations/cartesian_merge.py | 56 +++++++++++++++++-- ndsl/dsl/dace/stree/pipeline.py | 6 +- ndsl/dsl/optimization_config.py | 6 ++ 3 files changed, 63 insertions(+), 5 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py index 16d72380..d943e8df 100644 --- a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -18,26 +18,53 @@ class CartesianMerge(tn.ScheduleNodeTransformer): overcompute: Whether to merge at the cost of an if statement. Defaults to True. """ - def __init__(self, backend: Backend, *, overcompute: bool = True) -> None: + def __init__( + self, + backend: Backend, + *, + overcompute: bool = True, + merge_order: str = "default", + ) -> None: super().__init__() self._backend = backend self._overcompute = overcompute + self._merge_order = merge_order + + if self._merge_order not in ( + "default", + "IJK", + "IKJ", + "JIK", + "JKI", + "KIJ", + "KJI", + ): + raise ValueError(f"Unexpected merge order {self._merge_order}.") def __str__(self) -> str: return "CartesianMerge" def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: - for axis in self._backend_order(): + axis_merge_order = self._axis_merge_order() + for axis in axis_merge_order: InlineOffgridConditionals(axis).visit(node) MergeConditionals().visit(node) - for axis in self._backend_order(): + for axis in axis_merge_order: CartesianAxisMerge(axis, overcompute=self._overcompute).visit(node) ExtractOffgridConditionals().visit(node) MergeConditionals().visit(node) - def _backend_order(self) -> tuple[AxisIterator, AxisIterator, AxisIterator]: + def _axis_merge_order(self) -> tuple[AxisIterator, AxisIterator, AxisIterator]: + if self._merge_order == "default": + return self._axis_merge_order_default() + + return self._axis_merge_order_custom() + + def _axis_merge_order_default( + self, + ) -> tuple[AxisIterator, AxisIterator, AxisIterator]: if self._backend.loop_order == BackendLoopOrder.IJK: return (AxisIterator._I, AxisIterator._J, AxisIterator._K) @@ -55,3 +82,24 @@ def _backend_order(self) -> tuple[AxisIterator, AxisIterator, AxisIterator]: assert self._backend.loop_order == BackendLoopOrder.KJI return (AxisIterator._K, AxisIterator._J, AxisIterator._I) + + def _axis_merge_order_custom( + self, + ) -> tuple[AxisIterator, AxisIterator, AxisIterator]: + if self._merge_order == "IJK": + return (AxisIterator._I, AxisIterator._J, AxisIterator._K) + + if self._merge_order == "IKJ": + return (AxisIterator._I, AxisIterator._K, AxisIterator._J) + + if self._merge_order == "JIK": + return (AxisIterator._J, AxisIterator._I, AxisIterator._K) + + if self._merge_order == "JKI": + return (AxisIterator._J, AxisIterator._K, AxisIterator._I) + + if self._merge_order == "KIJ": + return (AxisIterator._K, AxisIterator._I, AxisIterator._J) + + assert self._merge_order == "KJI" + return (AxisIterator._K, AxisIterator._J, AxisIterator._I) diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index 52b1f5db..cb1bbb12 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -75,7 +75,11 @@ def __init__( ppl_passes.append(InlineVertical2DWrite()) if config.stree.merger.enabled: ppl_passes.append( - CartesianMerge(backend, overcompute=config.stree.merger.overcompute) + CartesianMerge( + backend, + overcompute=config.stree.merger.overcompute, + merge_order=config.stree.merger.order, + ) ) if config.stree.refine_transients: ppl_passes.append(CartesianRefineTransients(backend)) diff --git a/ndsl/dsl/optimization_config.py b/ndsl/dsl/optimization_config.py index d2e91fae..feb22abe 100644 --- a/ndsl/dsl/optimization_config.py +++ b/ndsl/dsl/optimization_config.py @@ -18,6 +18,12 @@ class Merger: ) """When merging allow maps of different sizes to merge by inserting an `if` guard.""" + order: str = "default" + """ + Allows to manually override the merging order (e.g. `KJI` will merge `K`, then `J`, then `I`). + The default follows loop order of the backend given to `CartesianMerge`. + """ + enabled: bool = os.getenv("NDSL_STREE_OPT", "False").lower() == "true" """Enable Schedule Tree transformations.""" From 74ff2026b884d39556cf5a476378117a4bddd247 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 26 Jun 2026 07:32:37 +0200 Subject: [PATCH 089/101] Revert accounting for map start in axis normalization Revert "fixup: use normalized indices in debug message" This reverts commit b24e1fcc11e045a8d29a5fb726e7abd90550a392. Revert "fix: account for map start in axis normalization" This reverts commit de9763de52a53e67394e08e11feec9a7f68e3e55. --- .../dace/stree/optimizations/common/memlet.py | 55 ++++--------- tests/dsl/dace/stree/common/test_memlet.py | 80 ++----------------- 2 files changed, 21 insertions(+), 114 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/common/memlet.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py index 01858807..edb4f3e0 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/memlet.py +++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py @@ -25,34 +25,13 @@ def is_equal(self, other: str) -> bool: return other == self.as_str() -def normalize_cartesian_indexation( - index: symbol, axis: AxisIterator, map_scope: tn.MapScope -) -> symbol: - """Return a normalized indexation symbol for cartesian indexation.""" - if len(map_scope.node.map.params) != 1: - raise ValueError( - f"Expected a map with only one parameter, got {map_scope.node.map.params}." - ) - - axis_name = axis.as_str() - if not map_scope.node.map.params[0].startswith(axis_name): - raise ValueError( - f"Mismatch of axis iterator {axis} and MapScope parameter {map_scope.node.map.params}." - ) - - # potentially rename +def normalize_cartesian_indexation(index: symbol, axis: AxisIterator) -> symbol: + """Return a normalize indexation symbol for cartesian indexation.""" rename_maps = {} - for sym in index.free_symbols: - if sym.name != axis_name and sym.name.startswith(axis_name): - rename_maps[sym] = symbol(axis_name) - renamed = index.subs(rename_maps) - - # handle potential map start - map_start = map_scope.node.map.range.min_element()[0] - if map_start != 0: - return renamed + map_start - - return renamed + for symb in index.free_symbols: + if symb.name.startswith(axis.as_str()): + rename_maps[symb] = symbol(axis.as_str()) + return index.subs(rename_maps) def no_data_dependencies_on_cartesian_axis( @@ -79,38 +58,36 @@ def no_data_dependencies_on_cartesian_axis( continue previous_axis_index = normalize_cartesian_indexation( - write.subset[axis_index][0], axis, first + write.subset[axis_index][0], axis ) # Write-after-write with an offset case for other_write in other_writes.out_memlets: if write.data == other_write.data: - current_axis_index = normalize_cartesian_indexation( - other_write.subset[axis_index][0], axis, second - ) - if previous_axis_index != current_axis_index: + if previous_axis_index != normalize_cartesian_indexation( + other_write.subset[axis_index][0], axis + ): ndsl_log.debug( f"[{axis.name} Merge] Found write after write conflict " f"for {write.data} " f"w/ different offset to {axis.name} (" f"first write at {previous_axis_index}, " - f"second write at {current_axis_index})" + f"second write at {other_write.subset[axis_index][0]})" ) return False # Read-after-write with an offset case for read in read_collector.in_memlets: if write.data == read.data: - current_axis_index = normalize_cartesian_indexation( - read.subset[axis_index][0], axis, second - ) - if previous_axis_index != current_axis_index: + if previous_axis_index != normalize_cartesian_indexation( + read.subset[axis_index][0], axis + ): ndsl_log.debug( f"[{axis.name} Merge] Found read after write conflict " f"for {write.data} " f"w/ different offset to {axis.name} (" - f"write at {previous_axis_index}, " - f"read at {current_axis_index})" + f"write at {write.subset[axis_index][0]}, " + f"read at {read.subset[axis_index][0]})" ) return False diff --git a/tests/dsl/dace/stree/common/test_memlet.py b/tests/dsl/dace/stree/common/test_memlet.py index ebd6f94f..44fe15e1 100644 --- a/tests/dsl/dace/stree/common/test_memlet.py +++ b/tests/dsl/dace/stree/common/test_memlet.py @@ -1,6 +1,3 @@ -import pytest -from dace import nodes, subsets -from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.symbolic import symbol from ndsl.dsl.dace.stree.optimizations.common import AxisIterator @@ -9,94 +6,27 @@ ) -@pytest.fixture -def k_map() -> tn.MapScope: - return tn.MapScope( - node=nodes.MapEntry( - nodes.Map("map", ["__k_123456789"], subsets.Range.from_string("0:5")) - ), - children=[], - ) - - -def test_normalize_cartesian_index(k_map: tn.MapScope) -> None: +def test_normalize_cartesian_index(): # Case of __k_id(node) - original case original_symbol = symbol("__k_12345678789") - norm_symbol = normalize_cartesian_indexation( - original_symbol, AxisIterator._K, k_map - ) + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) assert norm_symbol == symbol("__k") # Case of offset original_symbol = 1 + symbol("__k_12345678789") - norm_symbol = normalize_cartesian_indexation( - original_symbol, AxisIterator._K, k_map - ) + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) assert norm_symbol == symbol("__k") + 1 # Case of no-op (with offset) original_symbol = 1 + symbol("__k") - norm_symbol = normalize_cartesian_indexation( - original_symbol, AxisIterator._K, k_map - ) + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) assert norm_symbol == symbol("__k") + 1 # Case of index named with _k - so not a cartesian axis original_symbol = 1 + symbol("_kindex") - norm_symbol = normalize_cartesian_indexation( - original_symbol, AxisIterator._K, k_map - ) + norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K) assert norm_symbol == symbol("_kindex") + 1 - - -def test_normalize_cartesian_index_map_two_params() -> None: - ij_map = tn.MapScope( - node=nodes.MapEntry( - nodes.Map("map", ["__i", "__j"], subsets.Range([(0, 3, 1), (0, 4, 2)])) - ), - children=[], - ) - with pytest.raises(ValueError, match="Expected a map with only one parameter"): - normalize_cartesian_indexation(symbol("__i"), AxisIterator._I, ij_map) - - -def test_normalize_cartesian_index_map_wrong_index(k_map) -> None: - with pytest.raises(ValueError, match="Mismatch of axis iterator"): - normalize_cartesian_indexation(symbol("__i"), AxisIterator._I, k_map) - - -def test_normalize_cartesian_index_map_start(k_map) -> None: - map_m1 = tn.MapScope( - node=nodes.MapEntry( - nodes.Map("map", ["__i"], subsets.Range.from_string("-1:3")) - ), - children=[], - ) - - original_symbol = symbol("__i") - normalized = normalize_cartesian_indexation( - original_symbol, AxisIterator._I, map_m1 - ) - assert normalized == original_symbol - 1 - - original_symbol = 1 + symbol("__i") - normalized = normalize_cartesian_indexation( - original_symbol, AxisIterator._I, map_m1 - ) - assert normalized == symbol("__i") - - original_symbol = symbol("__i") + 5 - normalized = normalize_cartesian_indexation( - original_symbol, AxisIterator._I, map_m1 - ) - assert normalized == symbol("__i") + 4 - - original_symbol = 1 + symbol("__i_1234") - normalized = normalize_cartesian_indexation( - original_symbol, AxisIterator._I, map_m1 - ) - assert normalized == symbol("__i") From f37b54787f879c08bcee83f99433ed68d80e7b27 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 26 Jun 2026 12:11:11 +0200 Subject: [PATCH 090/101] feature: `LabledSection`s for use with local optimizations --- ndsl/dsl/dace/labeler.py | 52 +++++++++-- ndsl/dsl/dace/orchestration.py | 50 +++++----- ndsl/dsl/dace/stree/optimizations/__init__.py | 2 + .../dace/stree/optimizations/clean_tree.py | 8 ++ .../optimizations/local_optimizations.py | 93 +++++++++++++++++++ ndsl/dsl/dace/stree/pipeline.py | 5 +- ndsl/dsl/ndsl_runtime.py | 9 +- ndsl/dsl/optimization_config.py | 2 + 8 files changed, 180 insertions(+), 41 deletions(-) create mode 100644 ndsl/dsl/dace/stree/optimizations/local_optimizations.py diff --git a/ndsl/dsl/dace/labeler.py b/ndsl/dsl/dace/labeler.py index 08398ca0..b2bb102f 100644 --- a/ndsl/dsl/dace/labeler.py +++ b/ndsl/dsl/dace/labeler.py @@ -1,11 +1,11 @@ -from __future__ import annotations - from typing import Any -import dace.properties +import dace from dace import library, nodes from dace.transformation import transformation as xf +from ndsl import OptimizationConfig + @library.node class _Labeler(nodes.LibraryNode): @@ -13,9 +13,28 @@ class _Labeler(nodes.LibraryNode): default_implementation = "pure" unique_name = dace.properties.Property(dtype=str, desc="Unique name") - def __init__(self, unique_name: str, **kwargs: dict[str, Any]) -> None: + def __init__( + self, + unique_name: str, + local_optimization: OptimizationConfig | None, + **kwargs: dict[str, Any], + ) -> None: super().__init__(name="NDSLRuntime_Label", **kwargs) + # HACK to avoid state fusion of labeler states + # MPI WaitAll block state fusion, so we just pretend to be one 🐉. + # Keeping the labeler states non-fused is important to keep code flow consistent until we + # get to the schedule tree. + self.label = "_Waitall_" + self._unique_name = unique_name + self._local_optimizations = local_optimization + + def has_side_effects(self) -> bool: + # HACK + # LibraryNodes with side effects aren't touched by simplify. This + # keeps the library nodes alive until we get to the schedule tree + # where we can use the information. + return True @library.register_expansion(_Labeler, "pure") @@ -32,7 +51,10 @@ def expansion( def set_label( - sdfg: dace.SDFG | dace.CompiledSDFG, qualname: str, is_top_sdfg: bool + sdfg: dace.SDFG | dace.CompiledSDFG, + qualname: str, + is_top_sdfg: bool, + local_optimizations: OptimizationConfig | None, ) -> None: """Surround the SDFG with two state/library node combo labelling the code for future reference in further optimization. @@ -50,19 +72,29 @@ def set_label( # With the topmost SDFG we have to skip over the # "init" state if is_top_sdfg: - state = sdfg.add_state_after( + label_state = sdfg.add_state_after( state, label=f"__Label_Enter__{qualname}", ) else: - state = sdfg.add_state_before( + label_state = sdfg.add_state_before( state, label=f"__Label_Enter__{qualname}", ) - state.add_node(_Labeler(unique_name=f"Enter__{qualname}")) + label_state.add_node( + _Labeler( + unique_name=f"Enter__{qualname}", + local_optimization=local_optimizations, + ) + ) if sdfg.out_edges(state) == []: - state = sdfg.add_state_after( + label_state = sdfg.add_state_after( state, label=f"__Label_Exit__{qualname}", ) - state.add_node(_Labeler(unique_name=f"Exit__{qualname}")) + label_state.add_node( + _Labeler( + unique_name=f"Exit__{qualname}", + local_optimization=local_optimizations, + ) + ) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index 085e42e3..d267a98b 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -171,7 +171,7 @@ def _build_sdfg( dace_program: DaceProgram, sdfg: SDFG, config: DaceConfig, - optimization_config: OptimizationConfig, + optimization_config: OptimizationConfig | None, args: Any, kwargs: Any, ) -> None: @@ -181,6 +181,10 @@ def _build_sdfg( backend_name = config.get_backend() if is_compiling: + if optimization_config is None: + ndsl_log.debug(f"Using default optimization config for {sdfg.label}.") + optimization_config = OptimizationConfig() + ndsl_log.debug(f"Compiling config:\n{pformat(optimization_config, indent=2)}") # Fully specialize all known symbols and then propagate these changes in the simplify # pass that follows. This is not only a smart idea in general, but also simplifies (haha) @@ -451,7 +455,7 @@ def _call_sdfg( dace_program: DaceProgram, sdfg: SDFG, config: DaceConfig, - optimization_config: OptimizationConfig, + optimization_config: OptimizationConfig | None, args: Any, kwargs: Any, ) -> list | None: @@ -509,7 +513,7 @@ def _call_sdfg( def _parse_sdfg( dace_program: DaceProgram, config: DaceConfig, - optimization: OptimizationConfig, + optimization: OptimizationConfig | None, *args: Any, **kwargs: Any, ) -> SDFG | CompiledSDFG | None: @@ -549,8 +553,13 @@ def _parse_sdfg( ) # Label the code (this is the topmost code) - if sdfg is not None and optimization.stree.enabled: - set_label(sdfg, dace_program.f.__qualname__, is_top_sdfg=True) + if sdfg is not None and optimization is not None and optimization.stree.enabled: + set_label( + sdfg, + dace_program.f.__qualname__, + is_top_sdfg=True, + local_optimizations=optimization, + ) return sdfg @@ -583,7 +592,7 @@ def __init__( self, func: Callable, config: DaceConfig, - optimization_config: OptimizationConfig, + optimization_config: OptimizationConfig | None, ) -> None: self.func = func self.config = config @@ -689,8 +698,17 @@ def __sdfg__(self, *args, **kwargs): # type: ignore[no-untyped-def] **kwargs, ) # Label the code - if sdfg is not None and self.lazy_method.optimization_config.stree.enabled: - set_label(sdfg, type(self.obj_to_bind).__qualname__, is_top_sdfg=False) + if ( + sdfg is not None + and self.lazy_method.optimization_config is not None + and self.lazy_method.optimization_config.stree.enabled + ): + set_label( + sdfg, + type(self.obj_to_bind).__qualname__, + is_top_sdfg=False, + local_optimizations=self.lazy_method.optimization_config, + ) return sdfg def __sdfg_closure__(self, reevaluate=None): # type: ignore[no-untyped-def] @@ -708,7 +726,7 @@ def __init__( self, func: Callable, config: DaceConfig, - optimization_config: OptimizationConfig, + optimization_config: OptimizationConfig | None, ) -> None: self.func = func self.config = config @@ -762,11 +780,6 @@ def orchestrate( if dace_compiletime_args is None: dace_compiletime_args = [] - if optimization_config is None: - opt_config = OptimizationConfig() - else: - opt_config = optimization_config - func: Callable = type.__getattribute__(type(obj), method_to_orchestrate) # Flag argument as dace.constant @@ -789,7 +802,7 @@ def orchestrate( # Build DaCe orchestrated wrapper # This is a JIT object, e.g. DaCe compilation will happen on call - wrapped = _LazyComputepathMethod(func, config, opt_config).__get__(obj) + wrapped = _LazyComputepathMethod(func, config, optimization_config).__get__(obj) if method_to_orchestrate == "__call__": # Grab the function from the type of the child class @@ -856,16 +869,11 @@ def orchestrate_function( if dace_compiletime_args is None: dace_compiletime_args = [] - if optimization_config is None: - opt_config = OptimizationConfig() - else: - opt_config = optimization_config - def _decorator(func: Callable[..., Any]): # type: ignore[no-untyped-def] def _wrapper(*args, **kwargs): # type: ignore[no-untyped-def] for argument in dace_compiletime_args: func.__annotations__[argument] = DaceCompiletime - return _LazyComputepathFunction(func, config, opt_config) + return _LazyComputepathFunction(func, config, optimization_config) return _wrapper(func) if config.is_dace_orchestrated() else func diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py index b1f69aa1..210f7978 100644 --- a/ndsl/dsl/dace/stree/optimizations/__init__.py +++ b/ndsl/dsl/dace/stree/optimizations/__init__.py @@ -2,6 +2,7 @@ from .cartesian_merge import CartesianMerge from .clean_tree import CleanUpScheduleTree from .kernelize_maps import KernelizeMaps +from .local_optimizations import LocalOptimizations from .offgrid_conditionals import ( ExtractOffgridConditionals, InlineOffgridConditionals, @@ -17,6 +18,7 @@ "CartesianMerge", "CleanUpScheduleTree", "KernelizeMaps", + "LocalOptimizations", "ExtractOffgridConditionals", "InlineOffgridConditionals", "MergeConditionals", diff --git a/ndsl/dsl/dace/stree/optimizations/clean_tree.py b/ndsl/dsl/dace/stree/optimizations/clean_tree.py index 93798f42..acd7bd79 100644 --- a/ndsl/dsl/dace/stree/optimizations/clean_tree.py +++ b/ndsl/dsl/dace/stree/optimizations/clean_tree.py @@ -23,6 +23,14 @@ def _remove_state_boundaries_from_children( self._removed_state_boundaries += 1 node.children.remove(boundary) + def visit_LibraryCall(self, node: tn.LibraryCall) -> tn.LibraryCall | None: + # Filter duplicate labeled regions + # TODO: this shouldn't be necessary and needs to be cleaned up. + if node.node.unique_name.endswith("_patched"): + return None + + return node + def visit_WhileScope(self, node: tn.WhileScope) -> tn.WhileScope: self._remove_state_boundaries_from_children(node) diff --git a/ndsl/dsl/dace/stree/optimizations/local_optimizations.py b/ndsl/dsl/dace/stree/optimizations/local_optimizations.py new file mode 100644 index 00000000..c70765c2 --- /dev/null +++ b/ndsl/dsl/dace/stree/optimizations/local_optimizations.py @@ -0,0 +1,93 @@ +from dace.sdfg.analysis.schedule_tree import treenodes as tn + +from ndsl import OptimizationConfig +from ndsl.dsl.dace.stree.optimizations.common import list_index + + +class LabeledSection(tn.ScheduleTreeScope): + def __init__( + self, + *, + children: list[tn.ScheduleTreeNode], + parent: tn.ScheduleTreeScope, + label: str, + optimizations: OptimizationConfig, + ) -> None: + super().__init__(children=children, parent=parent) + self.label = label + self.optimizations = optimizations + + def as_string(self, indent: int = 0) -> str: + result = indent * tn.INDENTATION + f"section '{self.label}':\n" + return result + super().as_string(indent) + + +class _LabelSections(tn.ScheduleNodeVisitor): + _enter_labels: list[tn.LibraryCall] + + def __init__(self) -> None: + super().__init__() + + def __str__(self) -> str: + return "_LabelSections" + + def visit_LibraryCall(self, node: tn.LibraryCall) -> None: + # Only look at "our" label nodes + if node.node.name != "NDSLRuntime_Label": + return + + if node.node.unique_name.startswith("Enter__"): + # keep taps on where we start + self._enter_labels.append(node) + return + + if node.node.unique_name.startswith("Exit__"): + # find the matching start point + section_start = self._enter_labels.pop() + + # sanity checks + # - ensure we have the right section + name = section_start.node.unique_name.removeprefix("Enter__") + exit_name = node.node.unique_name.removeprefix("Exit__") + assert name == exit_name + # - ensure we have the same parent (if not something is screwed up) + parent = section_start.parent + assert parent == node.parent + + # grab all the nodes in-between and put them in a `LabeledSection` + start_index = list_index(parent.children, section_start) + end_index = list_index(parent.children, node) + new_node = LabeledSection( + children=parent.children[start_index + 1 : end_index], + parent=parent, + label=name, + optimizations=node.node._local_optimizations, + ) + + # overwrite the nodes (including the labels) with the new node + parent.children[start_index : end_index + 1] = [new_node] + + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: + # reset the stack of enter labels + self._enter_labels = [] + + # then, visit all the children + self.generic_visit(node) + + # make sure we have replaced everybody + assert len(self._enter_labels) == 0 + + +class LocalOptimizations(tn.ScheduleNodeVisitor): + def __init__(self) -> None: + super().__init__() + + def __str__(self) -> str: + return "LocalOptimizations" + + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: + # First, parse enter/exit labels into `LabeledSection`s. + _LabelSections().visit(node) + + # Then, apply local optimizations on children of `LabeledSection`s. + assert node diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index cb1bbb12..ed4f062f 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -9,6 +9,7 @@ CleanUpScheduleTree, InlineVertical2DWrite, KernelizeMaps, + LocalOptimizations, TreeOptimizationStatistics, ) @@ -70,7 +71,7 @@ def __init__( cache_directory: Path | None = None, ) -> None: if passes is None: - ppl_passes = [CleanUpScheduleTree()] + ppl_passes = [CleanUpScheduleTree(), LocalOptimizations()] if config.stree.inline_K_loops_size_one: ppl_passes.append(InlineVertical2DWrite()) if config.stree.merger.enabled: @@ -102,7 +103,7 @@ def __init__( cache_directory: Path | None = None, ) -> None: if passes is None: - ppl_passes = [CleanUpScheduleTree()] + ppl_passes = [CleanUpScheduleTree(), LocalOptimizations()] if config.stree.inline_K_loops_size_one: ppl_passes.append(InlineVertical2DWrite()) if config.stree.merger.enabled: diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py index 294f5711..492e62dc 100644 --- a/ndsl/dsl/ndsl_runtime.py +++ b/ndsl/dsl/ndsl_runtime.py @@ -30,14 +30,7 @@ def __init__( self._stencil_factory = stencil_factory # Use this flag to detect that the init wasn't done properly self._base_class_was_properly_super_init = True - if optimization_config is None: - # TODO - # - Decide where to put defaults. - # - For now, they are in the OptimizationConfig object itself. - # - We could have specialized defaults here for NDSLRuntime code. - self._optimization_config = OptimizationConfig() - else: - self._optimization_config = optimization_config + self._optimization_config = optimization_config def __init_subclass__(cls: type[NDSLRuntime], **kwargs: dict[str, Any]) -> None: # WARNING: no code outside the `init_decorator` this is cls diff --git a/ndsl/dsl/optimization_config.py b/ndsl/dsl/optimization_config.py index feb22abe..c4213771 100644 --- a/ndsl/dsl/optimization_config.py +++ b/ndsl/dsl/optimization_config.py @@ -49,3 +49,5 @@ class GPU: stree: Tree = field(default_factory=Tree) gpu: GPU = field(default_factory=GPU) + + name: str = "unset" From 55db71a08428dc0f6d9421368a5a76908a0d4566 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Sat, 27 Jun 2026 01:46:22 +0200 Subject: [PATCH 091/101] feat: local optimization --- ndsl/dsl/dace/orchestration.py | 2 +- .../stree/optimizations/kernelize_maps.py | 25 ++- .../optimizations/local_optimizations.py | 190 ++++++++++++++++-- ndsl/dsl/dace/stree/pipeline.py | 12 +- ndsl/dsl/optimization_config.py | 2 +- 5 files changed, 197 insertions(+), 34 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index d267a98b..bde5d8fe 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -152,7 +152,7 @@ def _optimization_pipeline( passes: list[tn.ScheduleNodeVisitor] | None = None, cache_directory: Path | None = None, ) -> StreePipeline: - if device_type == device_type.CPU: + if device_type == DeviceType.CPU: return CPUPipeline( config, backend, passes=passes, cache_directory=cache_directory ) diff --git a/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py b/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py index 11135ef6..03edb878 100644 --- a/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py +++ b/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py @@ -3,7 +3,7 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl import Backend -from ndsl.config import BackendLoopOrder, BackendTargetDevice +from ndsl.config import BackendLoopOrder from ndsl.dsl.dace.stree.optimizations.common import ( AxisIterator, is_axis_map, @@ -55,11 +55,12 @@ def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope | list[tn.MapScope]: class KernelizeMaps(tn.ScheduleNodeVisitor): - def __init__(self, backend: Backend) -> None: + def __init__(self, backend: Backend, *, apply_order: str = "default") -> None: super().__init__() self._backend = backend + self._apply_order = apply_order - if self._backend.device != BackendTargetDevice.GPU: + if not self._backend.is_gpu_backend(): raise ValueError( "The transformation `KernelizeMaps` is only intended to run on GPUs." ) @@ -72,6 +73,14 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: _KernelizeMap(axis).visit(node) def _axis_order(self) -> list[AxisIterator]: + if self._apply_order == "default": + # By default, follow the backend's axis order. + return self._axis_order_backend() + + # Allow custom order (e.g. for local optimizations). + return self._axis_order_custom() + + def _axis_order_backend(self) -> list[AxisIterator]: if self._backend.loop_order == BackendLoopOrder.IJK: return [AxisIterator._J, AxisIterator._I] if self._backend.loop_order == BackendLoopOrder.KJI: @@ -80,3 +89,13 @@ def _axis_order(self) -> list[AxisIterator]: raise NotImplementedError( f"KernelizeMaps is not configured for loop order {self._backend.loop_order}." ) + + def _axis_order_custom(self) -> list[AxisIterator]: + if self._apply_order == "JI": + return [AxisIterator._J, AxisIterator._I] + if self._apply_order == "JK": + return [AxisIterator._J, AxisIterator._K] + + raise NotImplementedError( + f"KernelizeMaps is not configured for custom apply order {self._apply_order}." + ) diff --git a/ndsl/dsl/dace/stree/optimizations/local_optimizations.py b/ndsl/dsl/dace/stree/optimizations/local_optimizations.py index c70765c2..189d4962 100644 --- a/ndsl/dsl/dace/stree/optimizations/local_optimizations.py +++ b/ndsl/dsl/dace/stree/optimizations/local_optimizations.py @@ -1,10 +1,16 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn -from ndsl import OptimizationConfig +from ndsl import Backend, OptimizationConfig, ndsl_log +from ndsl.dsl.dace.stree.optimizations.cartesian_merge import CartesianMerge from ndsl.dsl.dace.stree.optimizations.common import list_index +from ndsl.dsl.dace.stree.optimizations.kernelize_maps import KernelizeMaps +from ndsl.dsl.dace.stree.optimizations.refine_transients import ( + CartesianRefineTransients, +) +from ndsl.dsl.dace.stree.optimizations.remove_loops import InlineVertical2DWrite -class LabeledSection(tn.ScheduleTreeScope): +class _LabeledSection(tn.ScheduleTreeScope): def __init__( self, *, @@ -23,30 +29,70 @@ def as_string(self, indent: int = 0) -> str: class _LabelSections(tn.ScheduleNodeVisitor): - _enter_labels: list[tn.LibraryCall] + """ + Transform entry/exit labeler nodes into a `LabeledSection` (see above) + for easier later handling in case of local optimizations. Handles nested + labeled sections. + + Before + + ```none + # program before + + library_node("entry my_stencil") + map i in [...] + map j in [...] + map k in [...] + # contents of "my_stencil" + library node("exit my_stencil") + + # program continues + ``` + + After + + ```none + # program before + + labeled_section "my_stecil": + map i in [...] + map j in [...] + map k in [...] + # contents of "my_stencil + + # program continues + ``` + """ + + _entry_nodes: list[tn.LibraryCall] + """ + Stack of entry nodes for labeled sections. Nodes get pushed on entering the + labeled section and are removed again upon reaching the matching exit node. + """ def __init__(self) -> None: super().__init__() + self._entry_nodes = [] def __str__(self) -> str: return "_LabelSections" def visit_LibraryCall(self, node: tn.LibraryCall) -> None: - # Only look at "our" label nodes if node.node.name != "NDSLRuntime_Label": + # Only look at "our" label nodes. return if node.node.unique_name.startswith("Enter__"): - # keep taps on where we start - self._enter_labels.append(node) + # Keep taps on where we start. + self._entry_nodes.append(node) return if node.node.unique_name.startswith("Exit__"): - # find the matching start point - section_start = self._enter_labels.pop() + # Find the matching entry node. + section_start = self._entry_nodes.pop() # sanity checks - # - ensure we have the right section + # - ensure we have the right section (if not, something is screwed up) name = section_start.node.unique_name.removeprefix("Enter__") exit_name = node.node.unique_name.removeprefix("Exit__") assert name == exit_name @@ -54,40 +100,142 @@ def visit_LibraryCall(self, node: tn.LibraryCall) -> None: parent = section_start.parent assert parent == node.parent - # grab all the nodes in-between and put them in a `LabeledSection` + # Grab all the nodes in-between and put them in a `LabeledSection`. start_index = list_index(parent.children, section_start) end_index = list_index(parent.children, node) - new_node = LabeledSection( - children=parent.children[start_index + 1 : end_index], + new_node = _LabeledSection( + children=[ + child for child in parent.children[start_index + 1 : end_index] + ], parent=parent, label=name, optimizations=node.node._local_optimizations, ) - # overwrite the nodes (including the labels) with the new node + # Overwrite the nodes (including the labels) with the new node. parent.children[start_index : end_index + 1] = [new_node] def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: - # reset the stack of enter labels - self._enter_labels = [] + # Reset the stack of entry nodes. + self._entry_nodes = [] + + self.generic_visit(node) - # then, visit all the children + # If we have nodes left, something is screwed up. + assert len(self._entry_nodes) == 0 + + +class _ApplyLocalOptimizations(tn.ScheduleNodeVisitor): + """ + Applies local optimization in `LabeledSection`s in a "leaf first" approach. + + This work inline and replaces the `LabeledSection` with the results of the local + optimization as configured in the `OptimizationConfig` of the `LabeledSection`. + """ + + def __init__(self, backend: Backend) -> None: + super().__init__() + self._backend = backend + + def __str__(self) -> str: + return "_LabelSections" + + def visit_LabeledSection(self, node: _LabeledSection) -> None: + # Go down into children first such that we can apply local optimization "leaf first". self.generic_visit(node) - # make sure we have replaced everybody - assert len(self._enter_labels) == 0 + # TODO + # The code below is basically an `StreePipeline`. I've duplicated that + # pipeline because we need some clever engineering to not get into a + # hell of dependency circles (where the local optimizations are pipeline pass + # and in itself depend on the pipeline). + + config = node.optimizations + assert config.stree.enabled + + # HACK + # Below, we are calling `visit_ScheduleTreeRoot` with a `LabeledSection`. This works + # because python uses duck-typing. + # TODO + # Clean up pipeline passes and the pipeline itself such that they can work + # on any subtree (i.e. any `ScheduleTreeScope`). + + if self._backend.is_gpu_backend(): + if config.stree.inline_K_loops_size_one: + gpu_inliner = InlineVertical2DWrite() + gpu_inliner.visit_ScheduleTreeRoot(node) + + if config.stree.merger.enabled: + gpu_merger = CartesianMerge( + self._backend, + overcompute=config.stree.merger.overcompute, + merge_order=config.stree.merger.order, + ) + gpu_merger.visit_ScheduleTreeRoot(node) + + if config.stree.kernelize: + if config.stree.merger.order not in ("IJK", "KJI"): + ndsl_log.warning( + "Can't locally kernelize maps. Unknown apply oder. Skipping this pass." + ) + else: + # Follow the merge-order for kernelization + gpu_kernelizer = KernelizeMaps( + self._backend, + apply_order=( + "JI" if config.stree.merger.order == "IJK" else "JK" + ), + ) + gpu_kernelizer.visit_ScheduleTreeRoot(node) + + if config.stree.refine_transients: + # TODO + # 🐞 Transient refine can't be used because of bugs transients showing + # in code generation. + # gpu_refiner = CartesianRefineTransients(self._backend) + # gpu_refiner.visit_ScheduleTreeRoot(node) + raise ValueError( + "Transient refinement is currently unavailable in the GPU pipeline." + ) + else: + if config.stree.inline_K_loops_size_one: + cpu_inliner = InlineVertical2DWrite() + cpu_inliner.visit_ScheduleTreeRoot(node) + + if config.stree.merger.enabled: + cpu_merger = CartesianMerge( + self._backend, + overcompute=config.stree.merger.overcompute, + merge_order=config.stree.merger.order, + ) + cpu_merger.visit_ScheduleTreeRoot(node) + + if config.stree.refine_transients: + cpu_refiner = CartesianRefineTransients(self._backend) + cpu_refiner.visit_ScheduleTreeRoot(node) + + # Replace this `LabeledSection` with just the (now transformed) children. + for child in node.children: + # be sure to re-parent the children of this node to the new parent + child.parent = node.parent + node_index = list_index(node.parent.children, node) + node.parent.children[node_index : node_index + 1] = node.children class LocalOptimizations(tn.ScheduleNodeVisitor): - def __init__(self) -> None: + def __init__(self, backend: Backend) -> None: super().__init__() + self._backend = backend def __str__(self) -> str: return "LocalOptimizations" def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: - # First, parse enter/exit labels into `LabeledSection`s. + # First, parse enter/exit labels into `LabeledSection`s... _LabelSections().visit(node) - # Then, apply local optimizations on children of `LabeledSection`s. + # .. then, apply local optimizations on children of `LabeledSection`s. + _ApplyLocalOptimizations(self._backend).visit(node) + + # debug only assert node diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index ed4f062f..a44c7c08 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -17,7 +17,6 @@ class StreePipeline: def __init__( self, - config: OptimizationConfig, *, passes: list[tn.ScheduleNodeVisitor], cache_directory: Path | None = None, @@ -27,7 +26,6 @@ def __init__( self.cache_directory = cache_directory self.passes = passes - self.config = config def __hash__(self) -> int: return hash(repr(self)) @@ -37,9 +35,9 @@ def __repr__(self) -> str: def run( self, - stree: tn.ScheduleTreeRoot, + stree: tn.ScheduleTreeScope, verbose: bool = False, - ) -> tn.ScheduleTreeRoot: + ) -> tn.ScheduleTreeScope: tree_stats = TreeOptimizationStatistics() tree_stats.original(stree) @@ -71,7 +69,7 @@ def __init__( cache_directory: Path | None = None, ) -> None: if passes is None: - ppl_passes = [CleanUpScheduleTree(), LocalOptimizations()] + ppl_passes = [CleanUpScheduleTree(), LocalOptimizations(backend)] if config.stree.inline_K_loops_size_one: ppl_passes.append(InlineVertical2DWrite()) if config.stree.merger.enabled: @@ -87,7 +85,6 @@ def __init__( else: ppl_passes = passes super().__init__( - config=config, passes=ppl_passes, cache_directory=cache_directory, ) @@ -103,7 +100,7 @@ def __init__( cache_directory: Path | None = None, ) -> None: if passes is None: - ppl_passes = [CleanUpScheduleTree(), LocalOptimizations()] + ppl_passes = [CleanUpScheduleTree(), LocalOptimizations(backend)] if config.stree.inline_K_loops_size_one: ppl_passes.append(InlineVertical2DWrite()) if config.stree.merger.enabled: @@ -123,7 +120,6 @@ def __init__( else: ppl_passes = passes super().__init__( - config=config, passes=ppl_passes, cache_directory=cache_directory, ) diff --git a/ndsl/dsl/optimization_config.py b/ndsl/dsl/optimization_config.py index c4213771..fdcd7166 100644 --- a/ndsl/dsl/optimization_config.py +++ b/ndsl/dsl/optimization_config.py @@ -10,7 +10,7 @@ class Tree: @dataclass class Merger: - enabled: bool = False + enabled: bool = True """Enable cartesian axis merging.""" overcompute: bool = ( From e31d1d7517f9508b14eff31c2a44fbda3ca23f4b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Sun, 28 Jun 2026 22:08:10 +0200 Subject: [PATCH 092/101] fix: new algo for creating labeled sections --- .../optimizations/local_optimizations.py | 167 ++++++++++++++---- 1 file changed, 137 insertions(+), 30 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/local_optimizations.py b/ndsl/dsl/dace/stree/optimizations/local_optimizations.py index 189d4962..f5539d7e 100644 --- a/ndsl/dsl/dace/stree/optimizations/local_optimizations.py +++ b/ndsl/dsl/dace/stree/optimizations/local_optimizations.py @@ -28,7 +28,7 @@ def as_string(self, indent: int = 0) -> str: return result + super().as_string(indent) -class _LabelSections(tn.ScheduleNodeVisitor): +class _LabelSections(tn.ScheduleNodeTransformer): """ Transform entry/exit labeler nodes into a `LabeledSection` (see above) for easier later handling in case of local optimizations. Handles nested @@ -77,52 +77,159 @@ def __init__(self) -> None: def __str__(self) -> str: return "_LabelSections" - def visit_LibraryCall(self, node: tn.LibraryCall) -> None: - if node.node.name != "NDSLRuntime_Label": - # Only look at "our" label nodes. - return - - if node.node.unique_name.startswith("Enter__"): - # Keep taps on where we start. - self._entry_nodes.append(node) - return + def _label_marked_sections(self, scope: tn.ScheduleTreeScope) -> None: + """ + This is the function that actually does all the work by going over the children of a given schedule tree + scope and re-grouping them into labeled sections based on `NDSLRuntime_Label` entry/exit nodes. + """ + # The stack of entry nodes. They pop when the matching exit node is reached. Using a stack adds + # support for nested labeled sections. + entry_nodes_stack: list[tn.LibraryCall] = [] + + # The stack of children. Every new entry node pushes its children into a new stack entry. This allows + # one pass to gather nested children. + children_stack: list[list[tn.ScheduleTreeNode]] = [] + + # Top-level stack is for the current scope. + children_stack.append([]) + + for child in scope.children: + # Unless we are dealing with `tn.LibraryCall` nodes, we push all nodes to the stack of new children. + if not isinstance(child, tn.LibraryCall): + children_stack[-1].append(child) + continue + + if not child.node.name == "NDSLRuntime_Label": + # Leave other library call nodes alone. + children_stack[-1].append(child) + continue + + if child.node.unique_name.startswith("Enter__"): + # Keep taps on where we start and open a new list of children. + entry_nodes_stack.append(child) + children_stack.append([]) + continue + + # Expect to find an exit node now (matching the entry node that current on top of the stack). + if not child.node.unique_name.startswith("Exit__"): + raise RuntimeError( + f"Unexpected `NDSLRuntim_Label` '{child.node.unique_name}'." + ) - if node.node.unique_name.startswith("Exit__"): - # Find the matching entry node. - section_start = self._entry_nodes.pop() + # For exit nodes, find the matching entry node and the new children. + section_start = entry_nodes_stack.pop() + new_children = children_stack.pop() # sanity checks # - ensure we have the right section (if not, something is screwed up) name = section_start.node.unique_name.removeprefix("Enter__") - exit_name = node.node.unique_name.removeprefix("Exit__") - assert name == exit_name + assert name == child.node.unique_name.removeprefix("Exit__") # - ensure we have the same parent (if not something is screwed up) parent = section_start.parent - assert parent == node.parent + assert parent == child.parent + # - ensure the stack of children is not empty (it will at least contain the top-level scope) + assert len(children_stack) > 0 - # Grab all the nodes in-between and put them in a `LabeledSection`. - start_index = list_index(parent.children, section_start) - end_index = list_index(parent.children, node) + # Put all the new children in a `LabeledSection` and push that into the + # new children of the above stack of children. new_node = _LabeledSection( - children=[ - child for child in parent.children[start_index + 1 : end_index] - ], + children=new_children, parent=parent, label=name, - optimizations=node.node._local_optimizations, + optimizations=section_start.node._local_optimizations, ) + # re-parent new children to new node + for c in new_node.children: + c.parent = new_node + # push new node into enclosing stack of children + children_stack.append(new_node) - # Overwrite the nodes (including the labels) with the new node. - parent.children[start_index : end_index + 1] = [new_node] + # and - of course - the final book keeping + self._labeled_sections += 1 - def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: - # Reset the stack of entry nodes. - self._entry_nodes = [] + # set the new children on the current scope + scope.children = children_stack.pop() + + # some sanity checks + assert len(children_stack) == 0 # expect empty stack + for child in scope.children: + assert child.parent == scope # expect correct parent + + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> tn.ScheduleTreeRoot: + self._labeled_sections = 0 + + # recurse down first to label sections "leaf first" + self.generic_visit(node) + self._label_marked_sections(node) + + ndsl_log.debug(f"{self}: labeled {self._labeled_sections} sections.") + return node + + def visit_GBlock(self, node: tn.GBlock) -> tn.GBlock: + self.generic_visit(node) + self._label_marked_sections(node) + + return node + + def visit_LoopScope(self, node: tn.LoopScope) -> tn.LoopScope: + self.generic_visit(node) + self._label_marked_sections(node) + + return node + + def visit_ForScope(self, node: tn.ForScope) -> tn.ForScope: + self.generic_visit(node) + self._label_marked_sections(node) + + return node + + def visit_WhileScope(self, node: tn.WhileScope) -> tn.WhileScope: + self.generic_visit(node) + self._label_marked_sections(node) + + return node + + def visit_DoWhileScope(self, node: tn.DoWhileScope) -> tn.DoWhileScope: + self.generic_visit(node) + self._label_marked_sections(node) + + return node + + def visit_IfScope(self, node: tn.IfScope) -> tn.IfScope: + self.generic_visit(node) + self._label_marked_sections(node) + + return node + + def visit_StateIfScope(self, node: tn.StateIfScope) -> tn.StateIfScope: + self.generic_visit(node) + self._label_marked_sections(node) + + return node + + def visit_ElifScope(self, node: tn.ElifScope) -> tn.ElifScope: + self.generic_visit(node) + self._label_marked_sections(node) + + return node + + def visit_ElseScope(self, node: tn.ElseScope) -> tn.ElseScope: + self.generic_visit(node) + self._label_marked_sections(node) + + return node + + def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope: + self.generic_visit(node) + self._label_marked_sections(node) + + return node + def visit_ConsumeScope(self, node: tn.ConsumeScope) -> tn.ConsumeScope: self.generic_visit(node) + self._label_marked_sections(node) - # If we have nodes left, something is screwed up. - assert len(self._entry_nodes) == 0 + return node class _ApplyLocalOptimizations(tn.ScheduleNodeVisitor): From 699949b483d8c64e003df9ca878313a9d582227b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Sun, 28 Jun 2026 23:24:30 +0200 Subject: [PATCH 093/101] fix new algo --- ndsl/dsl/dace/stree/optimizations/local_optimizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/dsl/dace/stree/optimizations/local_optimizations.py b/ndsl/dsl/dace/stree/optimizations/local_optimizations.py index f5539d7e..ab2923d3 100644 --- a/ndsl/dsl/dace/stree/optimizations/local_optimizations.py +++ b/ndsl/dsl/dace/stree/optimizations/local_optimizations.py @@ -142,7 +142,7 @@ def _label_marked_sections(self, scope: tn.ScheduleTreeScope) -> None: for c in new_node.children: c.parent = new_node # push new node into enclosing stack of children - children_stack.append(new_node) + children_stack[-1].append(new_node) # and - of course - the final book keeping self._labeled_sections += 1 From 4f5804d9546da2efa4c7fcd1263f15721c2b48a2 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 29 Jun 2026 12:32:40 -0400 Subject: [PATCH 094/101] Clean up logging in orchestration --- ndsl/dsl/dace/orchestration.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index d267a98b..e110bc9c 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -247,7 +247,6 @@ def _build_sdfg( ) }, validate=True, - print_report=True, ) stree = sdfg.as_schedule_tree() if config.verbose_orchestration: @@ -326,7 +325,7 @@ def _build_sdfg( if isinstance(me, nodes.Tasklet) and "callback_" in me.label: exclude_taskslets_list.append(me.label) - sdfg.apply_transformations_repeated(AddThreadBlockMap) + sdfg.apply_transformations_repeated(AddThreadBlockMap, print_report=False) if optimization_config.gpu.common_gpu_xforms: with DaCeProgress(config, "Apply common GPU xforms"): From b2e282a158744f64957e9194b034832da67ccef5 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 29 Jun 2026 12:34:04 -0400 Subject: [PATCH 095/101] Add API to get an equivalent CPU and STENCIL backend from an existing backend --- ndsl/config/backend.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index 605b86d7..1a89a817 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -190,6 +190,18 @@ def is_stencil(self) -> bool: def is_gpu_backend(self) -> bool: return self._device == BackendTargetDevice.GPU + def equivalent_cpu_backend(self) -> "Backend": + """Return the equivalent backend (same strategy, framework and loop order) but for CPU device""" + if self._device == BackendTargetDevice.CPU: + return self + return Backend(f"{self._strategy}:{self._framework}:{BackendTargetDevice.CPU}:{self._loop_order}") + + def equivalent_stencil_backend(self) -> "Backend": + """Return the equivalent backend (same device, framework and loop order) but for Stencil strategy""" + if self._strategy == BackendStrategy.STENCIL: + return self + return Backend(f"{BackendStrategy.STENCIL}:{self._framework}:{self._device}:{self._loop_order}") + def is_fortran_aligned(self) -> bool: """Check that the standard 3D field on cartesian axis is memory-aligned with Fortran striding.""" From 1d242bcdd51acdbf60201141b305909597941d9c Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 29 Jun 2026 12:37:53 -0400 Subject: [PATCH 096/101] Lint --- ndsl/config/backend.py | 10 +++++++--- ndsl/dsl/dace/orchestration.py | 4 +++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py index 1a89a817..6f783e8c 100644 --- a/ndsl/config/backend.py +++ b/ndsl/config/backend.py @@ -191,16 +191,20 @@ def is_gpu_backend(self) -> bool: return self._device == BackendTargetDevice.GPU def equivalent_cpu_backend(self) -> "Backend": - """Return the equivalent backend (same strategy, framework and loop order) but for CPU device""" + """Return the equivalent backend (same strategy, framework and loop order) but for CPU device""" if self._device == BackendTargetDevice.CPU: return self - return Backend(f"{self._strategy}:{self._framework}:{BackendTargetDevice.CPU}:{self._loop_order}") + return Backend( + f"{self._strategy}:{self._framework}:{BackendTargetDevice.CPU}:{self._loop_order}" + ) def equivalent_stencil_backend(self) -> "Backend": """Return the equivalent backend (same device, framework and loop order) but for Stencil strategy""" if self._strategy == BackendStrategy.STENCIL: return self - return Backend(f"{BackendStrategy.STENCIL}:{self._framework}:{self._device}:{self._loop_order}") + return Backend( + f"{BackendStrategy.STENCIL}:{self._framework}:{self._device}:{self._loop_order}" + ) def is_fortran_aligned(self) -> bool: """Check that the standard 3D field on cartesian axis is memory-aligned with Fortran diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py index e110bc9c..2aa649f9 100644 --- a/ndsl/dsl/dace/orchestration.py +++ b/ndsl/dsl/dace/orchestration.py @@ -325,7 +325,9 @@ def _build_sdfg( if isinstance(me, nodes.Tasklet) and "callback_" in me.label: exclude_taskslets_list.append(me.label) - sdfg.apply_transformations_repeated(AddThreadBlockMap, print_report=False) + sdfg.apply_transformations_repeated( + AddThreadBlockMap, print_report=False + ) if optimization_config.gpu.common_gpu_xforms: with DaCeProgress(config, "Apply common GPU xforms"): From 10ab2a4ba2fffdfc6a746b17c550883b4aef2268 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 1 Jul 2026 10:07:43 +0200 Subject: [PATCH 097/101] fixup: clean the docs for stree opt --- .../dsl/dace/stree/optimizations/axis_merge.md | 2 +- .../dsl/dace/stree/optimizations/clean_tree.md | 2 +- .../{memlet_helpers.md => common/loops.md} | 4 ++-- .../{tree_common_op.md => common/memlet.md} | 4 ++-- .../dsl/dace/stree/optimizations/common/topology.md | 12 ++++++++++++ .../dsl/dace/stree/optimizations/kernelize_maps.md | 12 ++++++++++++ .../dace/stree/optimizations/local_optimizations.md | 12 ++++++++++++ .../dace/stree/optimizations/offgrid_conditionals.md | 12 ++++++++++++ .../dace/stree/optimizations/refine_transients.md | 2 +- .../dsl/dace/stree/optimizations/remove_loops.md | 12 ++++++++++++ .../dace/stree/optimizations/replace_axis_symbol.md | 12 ++++++++++++ .../dsl/dace/stree/optimizations/statistics.md | 12 ++++++++++++ 12 files changed, 91 insertions(+), 7 deletions(-) rename docs/docstrings/dsl/dace/stree/optimizations/{memlet_helpers.md => common/loops.md} (73%) rename docs/docstrings/dsl/dace/stree/optimizations/{tree_common_op.md => common/memlet.md} (73%) create mode 100644 docs/docstrings/dsl/dace/stree/optimizations/common/topology.md create mode 100644 docs/docstrings/dsl/dace/stree/optimizations/kernelize_maps.md create mode 100644 docs/docstrings/dsl/dace/stree/optimizations/local_optimizations.md create mode 100644 docs/docstrings/dsl/dace/stree/optimizations/offgrid_conditionals.md create mode 100644 docs/docstrings/dsl/dace/stree/optimizations/remove_loops.md create mode 100644 docs/docstrings/dsl/dace/stree/optimizations/replace_axis_symbol.md create mode 100644 docs/docstrings/dsl/dace/stree/optimizations/statistics.md diff --git a/docs/docstrings/dsl/dace/stree/optimizations/axis_merge.md b/docs/docstrings/dsl/dace/stree/optimizations/axis_merge.md index ab6794f7..53cf06df 100644 --- a/docs/docstrings/dsl/dace/stree/optimizations/axis_merge.md +++ b/docs/docstrings/dsl/dace/stree/optimizations/axis_merge.md @@ -1,6 +1,6 @@ # axis_merge -::: dsl.dace.stree.optimizations +::: dsl.dace.stree.optimizations.axis_merge diff --git a/docs/docstrings/dsl/dace/stree/optimizations/kernelize_maps.md b/docs/docstrings/dsl/dace/stree/optimizations/kernelize_maps.md new file mode 100644 index 00000000..974eb637 --- /dev/null +++ b/docs/docstrings/dsl/dace/stree/optimizations/kernelize_maps.md @@ -0,0 +1,12 @@ +# kernelize_maps + +::: dsl.dace.stree.optimizations.kernelize_maps + + diff --git a/docs/docstrings/dsl/dace/stree/optimizations/local_optimizations.md b/docs/docstrings/dsl/dace/stree/optimizations/local_optimizations.md new file mode 100644 index 00000000..540b9eba --- /dev/null +++ b/docs/docstrings/dsl/dace/stree/optimizations/local_optimizations.md @@ -0,0 +1,12 @@ +# local_optimizations + +::: dsl.dace.stree.optimizations.local_optimizations + + diff --git a/docs/docstrings/dsl/dace/stree/optimizations/offgrid_conditionals.md b/docs/docstrings/dsl/dace/stree/optimizations/offgrid_conditionals.md new file mode 100644 index 00000000..e5ab9ecf --- /dev/null +++ b/docs/docstrings/dsl/dace/stree/optimizations/offgrid_conditionals.md @@ -0,0 +1,12 @@ +# offgrid_conditionals + +::: dsl.dace.stree.optimizations.offgrid_conditionals + + diff --git a/docs/docstrings/dsl/dace/stree/optimizations/refine_transients.md b/docs/docstrings/dsl/dace/stree/optimizations/refine_transients.md index e207bb71..4d11671e 100644 --- a/docs/docstrings/dsl/dace/stree/optimizations/refine_transients.md +++ b/docs/docstrings/dsl/dace/stree/optimizations/refine_transients.md @@ -1,6 +1,6 @@ # refine_transients -::: dsl.dace.stree.optimizations +::: dsl.dace.stree.optimizations.refine_transients diff --git a/docs/docstrings/dsl/dace/stree/optimizations/replace_axis_symbol.md b/docs/docstrings/dsl/dace/stree/optimizations/replace_axis_symbol.md new file mode 100644 index 00000000..ccff4aa4 --- /dev/null +++ b/docs/docstrings/dsl/dace/stree/optimizations/replace_axis_symbol.md @@ -0,0 +1,12 @@ +# replace_axis_symbol + +::: dsl.dace.stree.optimizations.replace_axis_symbol + + diff --git a/docs/docstrings/dsl/dace/stree/optimizations/statistics.md b/docs/docstrings/dsl/dace/stree/optimizations/statistics.md new file mode 100644 index 00000000..c218d368 --- /dev/null +++ b/docs/docstrings/dsl/dace/stree/optimizations/statistics.md @@ -0,0 +1,12 @@ +# statistics + +::: dsl.dace.stree.optimizations.statistics + + From 2975f3f3c0c824e7a9ae09a2b09cdeafd2e086a9 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 1 Jul 2026 11:48:38 +0200 Subject: [PATCH 098/101] fix: support for plain numbers in index normalization When we normalize cartesian indices, add support for plain numbers as indices. Previously, we'd assume that each index is a symbolic expression. Now we have support for plain numbers too (e.g. from transient refinement). --- ndsl/dsl/dace/stree/optimizations/common/memlet.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ndsl/dsl/dace/stree/optimizations/common/memlet.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py index edb4f3e0..d52ca9f5 100644 --- a/ndsl/dsl/dace/stree/optimizations/common/memlet.py +++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py @@ -1,4 +1,5 @@ from enum import Enum +from numbers import Number from dace.memlet import Memlet from dace.sdfg.analysis.schedule_tree import treenodes as tn @@ -25,8 +26,14 @@ def is_equal(self, other: str) -> bool: return other == self.as_str() -def normalize_cartesian_indexation(index: symbol, axis: AxisIterator) -> symbol: +def normalize_cartesian_indexation( + index: Number | symbol, axis: AxisIterator +) -> symbol: """Return a normalize indexation symbol for cartesian indexation.""" + if isinstance(index, Number): + # Special case for refined cartesian indices, i.e. when `index` is 0. + return index + rename_maps = {} for symb in index.free_symbols: if symb.name.startswith(axis.as_str()): From d5b6b0dfb1a9c37355e8b8c399c1fb98e92cf81d Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 1 Jul 2026 11:50:33 +0200 Subject: [PATCH 099/101] fix: match whole words when replacing axis symbols This was causing issues because one symbol was contained in the other, e.g. when replacing `__k` with `__k_123` you'd get things like `__k_123_456` in case `__k` and `__k_456` were mixed. --- .../dace/stree/optimizations/replace_axis_symbol.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py b/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py index c04c2fc5..e4b83b03 100644 --- a/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py +++ b/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py @@ -1,4 +1,5 @@ import itertools +import re from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.symbolic import symbol @@ -16,14 +17,16 @@ def visit_TaskletNode(self, node: tn.TaskletNode) -> None: if node.node.label.startswith("masklet"): for old, new in self._axis_replacements.items(): - node.node.code.as_string = node.node.code.as_string.replace( - str(old), str(new) + # use regex to match word boundaries (with `\b`) + node.node.code.as_string = re.sub( + rf"\b{str(old)}\b", str(new), node.node.code.as_string ) def visit_IfScope(self, node: tn.IfScope) -> None: for old, new in self._axis_replacements.items(): - node.condition.as_string = node.condition.as_string.replace( - str(old), str(new) + # use regex to match word boundaries (with `\b`) + node.condition.as_string = re.sub( + rf"\b{str(old)}\b", str(new), node.condition.as_string ) for child in node.children: From 6d069875d11fa16746a77cd4005abb307023b608 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 1 Jul 2026 15:15:09 +0200 Subject: [PATCH 100/101] first version of local optimization pipeline This is working to the best of my knowledge. Test case is a `D_SW` in `orch:dace:cpu:IJK` with FvTp2d configured to have loops ordered as `K-J-I` i.e. the other way round, which allows to merge K-loops. --- .../stree/optimizations/cartesian_merge.py | 4 +- .../optimizations/local_optimizations.py | 364 ++++++++++-------- .../stree/optimizations/refine_transients.py | 56 +-- ndsl/dsl/dace/stree/pipeline.py | 2 +- 4 files changed, 227 insertions(+), 199 deletions(-) diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py index d943e8df..067eb5f2 100644 --- a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py +++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py @@ -51,7 +51,9 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: MergeConditionals().visit(node) for axis in axis_merge_order: - CartesianAxisMerge(axis, overcompute=self._overcompute).visit(node) + CartesianAxisMerge( + axis, overcompute=self._overcompute + ).visit_ScheduleTreeRoot(node) ExtractOffgridConditionals().visit(node) MergeConditionals().visit(node) diff --git a/ndsl/dsl/dace/stree/optimizations/local_optimizations.py b/ndsl/dsl/dace/stree/optimizations/local_optimizations.py index ab2923d3..b9796613 100644 --- a/ndsl/dsl/dace/stree/optimizations/local_optimizations.py +++ b/ndsl/dsl/dace/stree/optimizations/local_optimizations.py @@ -2,14 +2,105 @@ from ndsl import Backend, OptimizationConfig, ndsl_log from ndsl.dsl.dace.stree.optimizations.cartesian_merge import CartesianMerge -from ndsl.dsl.dace.stree.optimizations.common import list_index from ndsl.dsl.dace.stree.optimizations.kernelize_maps import KernelizeMaps -from ndsl.dsl.dace.stree.optimizations.refine_transients import ( - CartesianRefineTransients, -) from ndsl.dsl.dace.stree.optimizations.remove_loops import InlineVertical2DWrite +class ScheduleTreeScopeTransformer(tn.ScheduleNodeTransformer): + def __init__(self) -> None: + super().__init__() + + def _breadth_first_callback(self, node: tn.ScheduleTreeScope) -> None: + pass + + def _depth_first_callback(self, node: tn.ScheduleTreeScope) -> None: + pass + + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> tn.ScheduleTreeRoot: + self._breadth_first_callback(node) + self.generic_visit(node) + self._depth_first_callback(node) + + return node + + def visit_GBlock(self, node: tn.GBlock) -> tn.GBlock: + self._breadth_first_callback(node) + self.generic_visit(node) + self._depth_first_callback(node) + + return node + + def visit_LoopScope(self, node: tn.LoopScope) -> tn.LoopScope: + self._breadth_first_callback(node) + self.generic_visit(node) + self._depth_first_callback(node) + + return node + + def visit_ForScope(self, node: tn.ForScope) -> tn.ForScope: + self._breadth_first_callback(node) + self.generic_visit(node) + self._depth_first_callback(node) + + return node + + def visit_WhileScope(self, node: tn.WhileScope) -> tn.WhileScope: + self._breadth_first_callback(node) + self.generic_visit(node) + self._depth_first_callback(node) + + return node + + def visit_DoWhileScope(self, node: tn.DoWhileScope) -> tn.DoWhileScope: + self._breadth_first_callback(node) + self.generic_visit(node) + self._depth_first_callback(node) + + return node + + def visit_IfScope(self, node: tn.IfScope) -> tn.IfScope: + self._breadth_first_callback(node) + self.generic_visit(node) + self._depth_first_callback(node) + + return node + + def visit_StateIfScope(self, node: tn.StateIfScope) -> tn.StateIfScope: + self._breadth_first_callback(node) + self.generic_visit(node) + self._depth_first_callback(node) + + return node + + def visit_ElifScope(self, node: tn.ElifScope) -> tn.ElifScope: + self._breadth_first_callback(node) + self.generic_visit(node) + self._depth_first_callback(node) + + return node + + def visit_ElseScope(self, node: tn.ElseScope) -> tn.ElseScope: + self._breadth_first_callback(node) + self.generic_visit(node) + self._depth_first_callback(node) + + return node + + def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope: + self._breadth_first_callback(node) + self.generic_visit(node) + self._depth_first_callback(node) + + return node + + def visit_ConsumeScope(self, node: tn.ConsumeScope) -> tn.ConsumeScope: + self._breadth_first_callback(node) + self.generic_visit(node) + self._depth_first_callback(node) + + return node + + class _LabeledSection(tn.ScheduleTreeScope): def __init__( self, @@ -28,7 +119,7 @@ def as_string(self, indent: int = 0) -> str: return result + super().as_string(indent) -class _LabelSections(tn.ScheduleNodeTransformer): +class _LabelSections(ScheduleTreeScopeTransformer): """ Transform entry/exit labeler nodes into a `LabeledSection` (see above) for easier later handling in case of local optimizations. Handles nested @@ -54,7 +145,7 @@ class _LabelSections(tn.ScheduleNodeTransformer): ```none # program before - labeled_section "my_stecil": + labeled_section "my_stencil": map i in [...] map j in [...] map k in [...] @@ -64,20 +155,13 @@ class _LabelSections(tn.ScheduleNodeTransformer): ``` """ - _entry_nodes: list[tn.LibraryCall] - """ - Stack of entry nodes for labeled sections. Nodes get pushed on entering the - labeled section and are removed again upon reaching the matching exit node. - """ - def __init__(self) -> None: super().__init__() - self._entry_nodes = [] def __str__(self) -> str: return "_LabelSections" - def _label_marked_sections(self, scope: tn.ScheduleTreeScope) -> None: + def _depth_first_callback(self, scope: tn.ScheduleTreeScope) -> None: """ This is the function that actually does all the work by going over the children of a given schedule tree scope and re-grouping them into labeled sections based on `NDSLRuntime_Label` entry/exit nodes. @@ -113,7 +197,7 @@ def _label_marked_sections(self, scope: tn.ScheduleTreeScope) -> None: # Expect to find an exit node now (matching the entry node that current on top of the stack). if not child.node.unique_name.startswith("Exit__"): raise RuntimeError( - f"Unexpected `NDSLRuntim_Label` '{child.node.unique_name}'." + f"Unexpected `NDSLRuntime_Label` '{child.node.unique_name}'." ) # For exit nodes, find the matching entry node and the new children. @@ -160,79 +244,13 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> tn.ScheduleTreeRo # recurse down first to label sections "leaf first" self.generic_visit(node) - self._label_marked_sections(node) + self._depth_first_callback(node) ndsl_log.debug(f"{self}: labeled {self._labeled_sections} sections.") return node - def visit_GBlock(self, node: tn.GBlock) -> tn.GBlock: - self.generic_visit(node) - self._label_marked_sections(node) - - return node - - def visit_LoopScope(self, node: tn.LoopScope) -> tn.LoopScope: - self.generic_visit(node) - self._label_marked_sections(node) - - return node - - def visit_ForScope(self, node: tn.ForScope) -> tn.ForScope: - self.generic_visit(node) - self._label_marked_sections(node) - - return node - - def visit_WhileScope(self, node: tn.WhileScope) -> tn.WhileScope: - self.generic_visit(node) - self._label_marked_sections(node) - - return node - - def visit_DoWhileScope(self, node: tn.DoWhileScope) -> tn.DoWhileScope: - self.generic_visit(node) - self._label_marked_sections(node) - - return node - - def visit_IfScope(self, node: tn.IfScope) -> tn.IfScope: - self.generic_visit(node) - self._label_marked_sections(node) - - return node - def visit_StateIfScope(self, node: tn.StateIfScope) -> tn.StateIfScope: - self.generic_visit(node) - self._label_marked_sections(node) - - return node - - def visit_ElifScope(self, node: tn.ElifScope) -> tn.ElifScope: - self.generic_visit(node) - self._label_marked_sections(node) - - return node - - def visit_ElseScope(self, node: tn.ElseScope) -> tn.ElseScope: - self.generic_visit(node) - self._label_marked_sections(node) - - return node - - def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope: - self.generic_visit(node) - self._label_marked_sections(node) - - return node - - def visit_ConsumeScope(self, node: tn.ConsumeScope) -> tn.ConsumeScope: - self.generic_visit(node) - self._label_marked_sections(node) - - return node - - -class _ApplyLocalOptimizations(tn.ScheduleNodeVisitor): +class _ApplyLocalOptimizations(ScheduleTreeScopeTransformer): """ Applies local optimization in `LabeledSection`s in a "leaf first" approach. @@ -247,86 +265,121 @@ def __init__(self, backend: Backend) -> None: def __str__(self) -> str: return "_LabelSections" - def visit_LabeledSection(self, node: _LabeledSection) -> None: - # Go down into children first such that we can apply local optimization "leaf first". + def visit__LabeledSection(self, node: _LabeledSection) -> _LabeledSection: + # Recurse into labeled sections to support nested labeled sections. + self._breadth_first_callback(node) self.generic_visit(node) + self._depth_first_callback(node) - # TODO - # The code below is basically an `StreePipeline`. I've duplicated that - # pipeline because we need some clever engineering to not get into a - # hell of dependency circles (where the local optimizations are pipeline pass - # and in itself depend on the pipeline). - - config = node.optimizations - assert config.stree.enabled - - # HACK - # Below, we are calling `visit_ScheduleTreeRoot` with a `LabeledSection`. This works - # because python uses duck-typing. - # TODO - # Clean up pipeline passes and the pipeline itself such that they can work - # on any subtree (i.e. any `ScheduleTreeScope`). - - if self._backend.is_gpu_backend(): - if config.stree.inline_K_loops_size_one: - gpu_inliner = InlineVertical2DWrite() - gpu_inliner.visit_ScheduleTreeRoot(node) - - if config.stree.merger.enabled: - gpu_merger = CartesianMerge( - self._backend, - overcompute=config.stree.merger.overcompute, - merge_order=config.stree.merger.order, - ) - gpu_merger.visit_ScheduleTreeRoot(node) + return node + + def _depth_first_callback(self, scope: tn.ScheduleTreeScope) -> None: + new_children: list[tn.ScheduleTreeNode] = [] + + for child in scope.children: + # Any child that isn't a _LabeledSection gets directly added to the list of new children. + if not isinstance(child, _LabeledSection): + new_children.append(child) + continue - if config.stree.kernelize: - if config.stree.merger.order not in ("IJK", "KJI"): + # For labeled sections, apply the local optimizations to the sections' children, then + # append the possibly transformed children to the list of new children (without the + # labeled section). + + # TODO + # The code below is basically an `StreePipeline`. I've duplicated that + # pipeline because we need some clever engineering to not get into a + # hell of dependency circles (where the local optimizations are pipeline pass + # and in itself depend on the pipeline). + + config = child.optimizations + assert config.stree.enabled + + # HACK + # Below, we are calling `visit_ScheduleTreeRoot` with a `LabeledSection`. This works + # because python uses duck-typing. + # TODO + # Clean up pipeline passes and the pipeline itself such that they can work + # on any subtree (i.e. any `ScheduleTreeScope`). + + if self._backend.is_gpu_backend(): + if config.stree.inline_K_loops_size_one: + gpu_inliner = InlineVertical2DWrite() + gpu_inliner.visit_ScheduleTreeRoot(child) + + if config.stree.merger.enabled: + gpu_merger = CartesianMerge( + self._backend, + overcompute=config.stree.merger.overcompute, + merge_order=config.stree.merger.order, + ) + gpu_merger.visit_ScheduleTreeRoot(child) + + if config.stree.kernelize: + if config.stree.merger.order not in ("IJK", "KJI"): + ndsl_log.warning( + "Can't locally kernelize maps. Unknown apply oder. Skipping this pass." + ) + else: + # Follow the merge-order for kernelization + gpu_kernelizer = KernelizeMaps( + self._backend, + apply_order=( + "JI" if config.stree.merger.order == "IJK" else "JK" + ), + ) + gpu_kernelizer.visit_ScheduleTreeRoot(child) + + if config.stree.refine_transients: + # We can't know if transients are local to the scope that we are working in. + # In they are not, transient refinement can generate wrong results and refine + # too eagerly. Global transient refinement will also work in this section. ndsl_log.warning( - "Can't locally kernelize maps. Unknown apply oder. Skipping this pass." + "[Local-Opt]: Transient refinement can't e applied on a local scale " + "because it needs the global information on where/how transient data " + "is used. Please enable transient refinement on your global optimization " + "config and disable it here. No transients will be refined on the local " + "scale even if this option is turned on." ) - else: - # Follow the merge-order for kernelization - gpu_kernelizer = KernelizeMaps( + else: + if config.stree.inline_K_loops_size_one: + cpu_inliner = InlineVertical2DWrite() + cpu_inliner.visit_ScheduleTreeRoot(child) + + if config.stree.merger.enabled: + cpu_merger = CartesianMerge( self._backend, - apply_order=( - "JI" if config.stree.merger.order == "IJK" else "JK" - ), + overcompute=config.stree.merger.overcompute, + merge_order=config.stree.merger.order, + ) + cpu_merger.visit_ScheduleTreeRoot(child) + + if config.stree.refine_transients: + # We can't know if transients are local to the scope that we are working in. + # In they are not, transient refinement can generate wrong results and refine + # too eagerly. Global transient refinement will also work in this section. + ndsl_log.warning( + "[Local-Opt]: Transient refinement can't e applied on a local scale " + "because it needs the global information on where/how transient data " + "is used. Please enable transient refinement on your global optimization " + "config and disable it here. No transients will be refined on the local " + "scale even if this option is turned on." ) - gpu_kernelizer.visit_ScheduleTreeRoot(node) - - if config.stree.refine_transients: - # TODO - # 🐞 Transient refine can't be used because of bugs transients showing - # in code generation. - # gpu_refiner = CartesianRefineTransients(self._backend) - # gpu_refiner.visit_ScheduleTreeRoot(node) - raise ValueError( - "Transient refinement is currently unavailable in the GPU pipeline." - ) - else: - if config.stree.inline_K_loops_size_one: - cpu_inliner = InlineVertical2DWrite() - cpu_inliner.visit_ScheduleTreeRoot(node) - - if config.stree.merger.enabled: - cpu_merger = CartesianMerge( - self._backend, - overcompute=config.stree.merger.overcompute, - merge_order=config.stree.merger.order, - ) - cpu_merger.visit_ScheduleTreeRoot(node) - if config.stree.refine_transients: - cpu_refiner = CartesianRefineTransients(self._backend) - cpu_refiner.visit_ScheduleTreeRoot(node) + # Replace this `LabeledSection` with just the (now transformed) children. + for c in child.children: + # be sure to re-parent the children of this node to the new parent + c.parent = child.parent + new_children.append(c) - # Replace this `LabeledSection` with just the (now transformed) children. - for child in node.children: - # be sure to re-parent the children of this node to the new parent - child.parent = node.parent - node_index = list_index(node.parent.children, node) - node.parent.children[node_index : node_index + 1] = node.children + scope.children = new_children + + # sanity checks + for child in scope.children: + assert child.parent == scope # expect correct parent + assert not isinstance( + child, _LabeledSection + ) # no labeled sections should be left at this point class LocalOptimizations(tn.ScheduleNodeVisitor): @@ -343,6 +396,3 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: # .. then, apply local optimizations on children of `LabeledSection`s. _ApplyLocalOptimizations(self._backend).visit(node) - - # debug only - assert node diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py index 39d213fc..91bfb5e0 100644 --- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py +++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py @@ -4,7 +4,6 @@ from dace.sdfg.analysis.schedule_tree import treenodes as tn from ndsl import ndsl_log -from ndsl.config import Backend, BackendFramework from ndsl.dsl.dace.stree.optimizations.common import AxisIterator @@ -28,7 +27,6 @@ def _reduce_cartesian_axis_size_to_1( transient_map_reads: dace.subsets.Range | None, transient_map_writes: dace.subsets.Range | None, transient_data: dace.data.Data, - layout_map: tuple[int, ...], ) -> bool: """Reduce dimension size of transient to 1 if all access (reads and writes) are atomic""" @@ -55,21 +53,12 @@ def _reduce_cartesian_axis_size_to_1( # therefore this dimension can be removed. BUT we are not truly # removing it, we are reducing it to 1 to not have to deal # with different slicing. - transient_data.shape = _change_index_of_tuple( + new_shape = _change_index_of_tuple( transient_data.shape, axis.as_cartesian_index(), value=1, ) - - if len(transient_data.shape) == 3: - layout = [*layout_map] - else: - data_dim_count = len(transient_data.shape) - 3 - layout = [dim + data_dim_count for dim in layout_map] + [ - i - 1 for i in range(data_dim_count, 0, -1) - ] - - transient_data.set_strides_from_layout(*layout) + transient_data.set_shape(new_shape) transient_data.lifetime = dace.dtypes.AllocationLifetime.State return True @@ -91,9 +80,6 @@ def __init__(self) -> None: self.transients_range_writes: dict[str, dace.subsets.Range | None] = {} self.transients_range_reads: dict[str, dace.subsets.Range | None] = {} - def __str__(self) -> str: - return "CartesianCollectMaps" - def _find_first_map_or_loop( self, node: tn.TaskletNode, @@ -147,8 +133,8 @@ def _record_access( ].add(map_entry) def visit_TaskletNode(self, node: tn.TaskletNode) -> None: - self._record_access(node, node.input_memlets(), self.transients_range_writes) - self._record_access(node, node.output_memlets(), self.transients_range_reads) + self._record_access(node, node.input_memlets(), self.transients_range_reads) + self._record_access(node, node.output_memlets(), self.transients_range_writes) def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: self.containers = node.containers @@ -158,23 +144,20 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: self.transients_range_writes[name] = None self.transients_range_reads[name] = None - for child in node.children: - self.visit(child) + self.generic_visit(node) class RebuildMemletsFromContainers(tn.ScheduleNodeVisitor): - """Rebuild memlets from containers to ensure they are scope to the right size.""" + """Rebuild memlets from containers to ensure they are scoped to the right size.""" def __init__(self, refined_arrays: set[str]) -> None: self._refined_arrays = refined_arrays - def __str__(self) -> str: - return "RefineTransientAxis" - def visit_TaskletNode(self, node: tn.TaskletNode) -> None: for memlet in [*node.output_memlets(), *node.input_memlets()]: if memlet.data not in self._refined_arrays: continue + array = self.containers[memlet.data] if array.transient: if not isinstance(memlet.subset, dace.subsets.Range): @@ -190,8 +173,7 @@ def visit_TaskletNode(self, node: tn.TaskletNode) -> None: def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: self.containers = node.containers - for child in node.children: - self.visit(child) + self.generic_visit(node) class CartesianRefineTransients(tn.ScheduleNodeTransformer): @@ -237,7 +219,7 @@ class CartesianRefineTransients(tn.ScheduleNodeTransformer): memory (e.g. halo) for the `RebuildMemletsFromContainers`! """ - def __init__(self, backend: Backend) -> None: + def __init__(self) -> None: warnings.warn( "CartesianRefineTransients is a WIP. It's usage is *severely* limited " "and will most likely lead to bad numerics. Check the docs, check utest.", @@ -245,13 +227,6 @@ def __init__(self, backend: Backend) -> None: stacklevel=2, ) - if not backend.is_orchestrated() or backend.framework != BackendFramework.DACE: - raise NotImplementedError( - f"[Schedule Tree Opt] CartesianRefineTransient not implemented for backend {backend}" - ) - self.layout_map = backend.as_layout_map() - self.refined_array: set[str] = set() - def __str__(self) -> str: return "CartesianRefineTransients" @@ -260,10 +235,11 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: collect_map.visit(node) # Remove Axis - refined_transient = 0 + refined_arrays: set[str] = set() for name, data in node.containers.items(): if not (data.transient and isinstance(data, dace.data.Array)): continue + refined = False for axis in AxisIterator: # We do not refine multi-map transients @@ -276,18 +252,18 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None: > 1 ): continue + # Refine axis down to 1 refined |= _reduce_cartesian_axis_size_to_1( axis, collect_map.transients_range_reads[name], collect_map.transients_range_writes[name], data, - self.layout_map, ) - refined_transient += 1 if refined else 0 - self.refined_array.add(name) + if refined: + refined_arrays.add(name) - RebuildMemletsFromContainers(self.refined_array).visit(node) + RebuildMemletsFromContainers(refined_arrays).visit(node) - ndsl_log.debug(f"🚀 {refined_transient} Transient refined") + ndsl_log.debug(f"🚀 {len(refined_arrays)} Transient refined") diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py index a44c7c08..50f2d32e 100644 --- a/ndsl/dsl/dace/stree/pipeline.py +++ b/ndsl/dsl/dace/stree/pipeline.py @@ -81,7 +81,7 @@ def __init__( ) ) if config.stree.refine_transients: - ppl_passes.append(CartesianRefineTransients(backend)) + ppl_passes.append(CartesianRefineTransients()) else: ppl_passes = passes super().__init__( From 902c647969b06b07016023d97aea335b7ce6c4a5 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 1 Jul 2026 15:44:36 +0200 Subject: [PATCH 101/101] fix test case of no-overcompute merge --- .../dace/stree/optimizations/test_merge.py | 32 +++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/tests/dsl/dace/stree/optimizations/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py index 36e366b3..8dd82624 100644 --- a/tests/dsl/dace/stree/optimizations/test_merge.py +++ b/tests/dsl/dace/stree/optimizations/test_merge.py @@ -8,10 +8,9 @@ from ndsl.boilerplate import get_factories_single_tile_orchestrated from ndsl.config import Backend from ndsl.constants import I_DIM, J_DIM, K_DIM -from ndsl.dsl.dace.stree.pipeline import CartesianMerge, CleanUpScheduleTree from ndsl.dsl.gt4py import FORWARD, PARALLEL, K, computation, interval from ndsl.dsl.typing import FloatField -from tests.dsl.dace.stree import StreePipeline, get_SDFG_and_purge +from tests.dsl.dace.stree import get_SDFG_and_purge from tests.dsl.dace.stree.optimizations import Factories @@ -74,6 +73,19 @@ def __init__( method_to_orchestrate=method, optimization_config=config, ) + orchestrate( + obj=self, + config=stencil_factory.config.dace_config, + method_to_orchestrate="no_overcompute_merge", + optimization_config=OptimizationConfig( + stree=OptimizationConfig.Tree( + enabled=True, + merger=OptimizationConfig.Tree.Merger( + enabled=True, overcompute=False + ), + ) + ), + ) self.stencil = stencil_factory.from_dims_halo( func=stencil, @@ -127,6 +139,14 @@ def overcompute_merge( self.stencil(in_field, out_field) self.stencil_with_different_intervals(in_field, out_field) + def no_overcompute_merge( + self, + in_field: FloatField, + out_field: FloatField, + ) -> None: + self.stencil(in_field, out_field) + self.stencil_with_different_intervals(in_field, out_field) + def push_non_cartesian_for( self, in_field: FloatField, @@ -213,13 +233,7 @@ def test_no_overcompute_merge( in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "") out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "") - no_overcompute = [ - CleanUpScheduleTree(), - CartesianMerge(stencil_factory.backend, overcompute=False), - ] - - with StreePipeline(passes=no_overcompute): - code.overcompute_merge(in_qty, out_qty) + code.no_overcompute_merge(in_qty, out_qty) sdfg = get_SDFG_and_purge(stencil_factory).sdfg