From c712dadad3c796ef8e8d12f811a452c9acaf635e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 11 May 2026 14:33:13 +0200 Subject: [PATCH 01/16] expose more math functions as dace.math.xyz --- dace/runtime/include/dace/math.h | 83 ++++++++++++++++++++++++++++++-- 1 file changed, 79 insertions(+), 4 deletions(-) diff --git a/dace/runtime/include/dace/math.h b/dace/runtime/include/dace/math.h index 533056c9e4..136aa87c72 100644 --- a/dace/runtime/include/dace/math.h +++ b/dace/runtime/include/dace/math.h @@ -86,7 +86,7 @@ static DACE_CONSTEXPR DACE_HDFI T Modulo_float(const T& value, const T& modulus) return value - floor(value / modulus) * modulus; } -// Implement to support a match wtih Fortran's intrinsic EXPONENT +// Implement to support a match with Fortran's intrinsic EXPONENT template::value>* = nullptr> static DACE_CONSTEXPR DACE_HDFI int frexp(const T& a) { int exponent = 0; @@ -254,7 +254,7 @@ static DACE_CONSTEXPR DACE_HDFI std::complex np_float_pow(const std::com // Computes Python modulus (also NumPy remainder) // Formula: num - (num // den) * den // NOTE: This is different than Python math.remainder and C remainder, -// which are equaivalent to the IEEE remainder: num - round(num / den) * den +// which are equivalent to the IEEE remainder: num - round(num / den) * den template static DACE_CONSTEXPR DACE_HDFI T py_mod(const T& numerator, const T& denominator) { T quotient = py_floor(numerator, denominator); @@ -392,7 +392,7 @@ static DACE_CONSTEXPR DACE_HDFI std::complex reciprocal(const std::complex #if __cplusplus < 201703L -// Compute the greates common divisor of two integers +// Compute the greatest common divisor of two integers template static DACE_CONSTEXPR DACE_HDFI T gcd(T a, T b) { // Modern Euclidian algorithm @@ -417,7 +417,7 @@ static DACE_CONSTEXPR DACE_HDFI T lcm(T a, T b) { #else -// Compute the greates common divisor of two integers +// Compute the greatest common divisor of two integers template static DACE_CONSTEXPR DACE_HDFI T gcd(const T& a, const T& b) { return std::gcd(a, b); @@ -572,36 +572,71 @@ namespace dace return std::sin(a); } template + DACE_CONSTEXPR DACE_HDFI T asin(const T& a) + { + return std::asin(a); + } + template DACE_CONSTEXPR DACE_HDFI T sinh(const T& a) { return std::sinh(a); } template + DACE_CONSTEXPR DACE_HDFI T asinh(const T& a) + { + return std::asinh(a); + } + template DACE_CONSTEXPR DACE_HDFI T cos(const T& a) { return std::cos(a); } template + DACE_CONSTEXPR DACE_HDFI T acos(const T& a) + { + return std::acos(a); + } + template DACE_CONSTEXPR DACE_HDFI T cosh(const T& a) { return std::cosh(a); } template + DACE_CONSTEXPR DACE_HDFI T acosh(const T& a) + { + return std::acosh(a); + } + template DACE_CONSTEXPR DACE_HDFI T tan(const T& a) { return std::tan(a); } template + DACE_CONSTEXPR DACE_HDFI T atan(const T& a) + { + return std::atan(a); + } + template DACE_CONSTEXPR DACE_HDFI T tanh(const T& a) { return std::tanh(a); } template + DACE_CONSTEXPR DACE_HDFI T atanh(const T& a) + { + return std::atanh(a); + } + template DACE_CONSTEXPR DACE_HDFI T sqrt(const T& a) { return std::sqrt(a); } template + DACE_CONSTEXPR DACE_HDFI T cbrt(const T& a) + { + return std::cbrt(a); + } + template DACE_CONSTEXPR DACE_HDFI T log(const T& a) { return std::log(a); @@ -611,6 +646,46 @@ namespace dace { return std::log10(a); } + template + DACE_CONSTEXPR DACE_HDFI T fmod(const T& a) + { + return std::fmod(a); + } + template + DACE_CONSTEXPR DACE_HDFI T tgamma(const T& a) + { + return std::tgamma(a); + } + template + DACE_CONSTEXPR DACE_HDFI T ceil(const T& a) + { + return std::ceil(a); + } + template + DACE_CONSTEXPR DACE_HDFI T trunc(const T& a) + { + return std::trunc(a); + } + template + DACE_CONSTEXPR DACE_HDFI T erf(const T& a) + { + return std::erf(a); + } + template + DACE_CONSTEXPR DACE_HDFI T erfc(const T& a) + { + return std::erfc(a); + } + template + DACE_CONSTEXPR DACE_HDFI T nearbyint(const T& a) + { + return std::nearbyint(a); + } + template + DACE_CONSTEXPR DACE_HDFI T round(const T& a) + { + return std::round(a); + } } namespace cmath From 5475d547b76cb44eb876125eb35d6fac08831fdc Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 11 May 2026 15:13:31 +0200 Subject: [PATCH 02/16] even more math functions exposed --- dace/runtime/include/dace/math.h | 35 ++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/dace/runtime/include/dace/math.h b/dace/runtime/include/dace/math.h index 136aa87c72..325193558f 100644 --- a/dace/runtime/include/dace/math.h +++ b/dace/runtime/include/dace/math.h @@ -514,6 +514,16 @@ namespace dace { return (T)std::exp(a); } + template + DACE_CONSTEXPR DACE_HDFI T exp2(const T& a) + { + return (T)std::exp2(a); + } + template + DACE_CONSTEXPR DACE_HDFI T expm1(const T& a) + { + return (T)std::expm1(a); + } #ifdef __CUDACC__ template @@ -617,6 +627,11 @@ namespace dace return std::atan(a); } template + DACE_CONSTEXPR DACE_HDFI T atan2(const T& a) + { + return std::atan2(a); + } + template DACE_CONSTEXPR DACE_HDFI T tanh(const T& a) { return std::tanh(a); @@ -647,11 +662,26 @@ namespace dace return std::log10(a); } template + DACE_CONSTEXPR DACE_HDFI T log1p(const T& a) + { + return std::log1p(a); + } + template + DACE_CONSTEXPR DACE_HDFI T log2(const T& a) + { + return std::log2(a); + } + template DACE_CONSTEXPR DACE_HDFI T fmod(const T& a) { return std::fmod(a); } template + DACE_CONSTEXPR DACE_HDFI T lgamma(const T& a) + { + return std::lgamma(a); + } + template DACE_CONSTEXPR DACE_HDFI T tgamma(const T& a) { return std::tgamma(a); @@ -686,6 +716,11 @@ namespace dace { return std::round(a); } + template + DACE_CONSTEXPR DACE_HDFI T hypot(const T& a) + { + return std::hypot(a); + } } namespace cmath From 3df061c8aeabcaeea966f79e39a4dbded2628df9 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 11 May 2026 16:15:53 +0200 Subject: [PATCH 03/16] fix fmod number of arguments --- dace/runtime/include/dace/math.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/runtime/include/dace/math.h b/dace/runtime/include/dace/math.h index 325193558f..f1ae7dc9f6 100644 --- a/dace/runtime/include/dace/math.h +++ b/dace/runtime/include/dace/math.h @@ -672,9 +672,9 @@ namespace dace return std::log2(a); } template - DACE_CONSTEXPR DACE_HDFI T fmod(const T& a) + DACE_CONSTEXPR DACE_HDFI T fmod(const T& a, const T& b) { - return std::fmod(a); + return std::fmod(a, b); } template DACE_CONSTEXPR DACE_HDFI T lgamma(const T& a) From 51cdbf35194a6b487e9a936a226ce9c5b41dc260 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 11 May 2026 16:23:28 +0200 Subject: [PATCH 04/16] fixup: hypot function arguments --- dace/runtime/include/dace/math.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/runtime/include/dace/math.h b/dace/runtime/include/dace/math.h index f1ae7dc9f6..634299a7b5 100644 --- a/dace/runtime/include/dace/math.h +++ b/dace/runtime/include/dace/math.h @@ -717,9 +717,9 @@ namespace dace return std::round(a); } template - DACE_CONSTEXPR DACE_HDFI T hypot(const T& a) + DACE_CONSTEXPR DACE_HDFI T hypot(const T& a, const T& b) { - return std::hypot(a); + return std::hypot(a, b); } } From f271b30bb983559306342ce2ff98c69e6662bb32 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 18 May 2026 10:16:36 +0200 Subject: [PATCH 05/16] fix memlet symbol replacement --- dace/memlet.py | 2 +- tests/numpy/common.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/memlet.py b/dace/memlet.py index 674e19ae6b..95b6e432a2 100644 --- a/dace/memlet.py +++ b/dace/memlet.py @@ -469,7 +469,7 @@ def replace(self, repl_dict): for symbol in repl_dict: if str(symbol) != str(repl_dict[symbol]): intermediate = symbolic.symbol('__dacesym_' + str(symbol)) - repl_to_intermediate[symbolic.symbol(symbol)] = intermediate + repl_to_intermediate[symbol] = intermediate repl_to_final[intermediate] = repl_dict[symbol] if len(repl_to_intermediate) > 0: diff --git a/tests/numpy/common.py b/tests/numpy/common.py index 5849269c99..c50f40188e 100644 --- a/tests/numpy/common.py +++ b/tests/numpy/common.py @@ -26,7 +26,7 @@ def compare_numpy_output(device=dace.dtypes.DeviceType.CPU, (including errors). `func` will be run once as a dace program, and once using python. - The inputs to the function will be randomly intialized arrays with + The inputs to the function will be randomly initialized arrays with shapes and dtypes according to the argument annotations. Note that this should be used *instead* of the `@dace.program` From bec508d401a5df004d0955267a056cfb5a92c86e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 19 May 2026 12:27:09 +0200 Subject: [PATCH 06/16] fix: stree -> sdfg, data descriptors of nested maps --- dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index f8bfab9dde..2ff5f1043c 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -495,8 +495,13 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: assert new_in_connector == True assert new_in_connector == new_out_connector + if memlet_data in sdfg.arrays: + data_descriptor = sdfg.arrays[memlet_data] + else: + _sdfg = self._parent_sdfg_with_array(memlet_data, sdfg) + data_descriptor = _sdfg.arrays[memlet_data] self._current_state.add_edge(outer_map_entry, connector_name, map_entry, connector, - Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) + Memlet.from_array(memlet_data, data_descriptor)) else: if isinstance(outer_map_entry, SDFG): # Copy data descriptor from parent SDFG and add input connector @@ -593,13 +598,18 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: access_cache[name] = write_access_node access_node = access_cache[name] + if name in sdfg.arrays: + data_descriptor = sdfg.arrays[name] + else: + _sdfg = self._parent_sdfg_with_array(name, sdfg) + data_descriptor = _sdfg.arrays[name] self._current_state.add_memlet_path(map_exit, access_node, src_conn=out_connector_name, - memlet=Memlet.from_array(name, sdfg.arrays[name])) + memlet=Memlet.from_array(name, data_descriptor)) if isinstance(outer_map_entry, nodes.EntryNode): - outer_to_connect[name] = (access_node, Memlet.from_array(name, sdfg.arrays[name])) + outer_to_connect[name] = (access_node, Memlet.from_array(name, data_descriptor)) else: assert isinstance(outer_map_entry, SDFG) or outer_map_entry is None From ec81b1a0c2a872da8dd315378ff6a9ac67d5458b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 20 May 2026 15:30:00 +0200 Subject: [PATCH 07/16] consider scalar memlets in stree -> sdfg --- .../analysis/schedule_tree/tree_to_sdfg.py | 83 +++++++++++++------ dace/sdfg/analysis/schedule_tree/treenodes.py | 4 +- dace/sdfg/sdfg.py | 11 ++- dace/sdfg/state.py | 19 +++-- tests/schedule_tree/to_sdfg_test.py | 55 ++++++++++++ 5 files changed, 136 insertions(+), 36 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 2ff5f1043c..9028963c32 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from enum import Enum, auto from types import TracebackType -from typing import Final +from typing import Final, Sequence from dace import symbolic from dace.memlet import Memlet @@ -223,52 +223,80 @@ def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: if memlet.data not in sdfg.arrays: raise ValueError(f"Parsing AssignNode {node} failed. Can't find {memlet.data} in {sdfg}.") - def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: - current_state = self._current_state - assert current_state is not None - cf_region = current_state.parent_graph - - loop_region = LoopRegion(label=node.loop.label, - condition_expr=node.loop.loop_condition, - loop_var=node.loop.loop_variable, - initialize_expr=node.loop.init_statement, - update_expr=node.loop.update_statement, - unroll=node.loop.unroll, - unroll_factor=node.loop.unroll_factor) - cf_region.add_node(loop_region) - loop_state = loop_region.add_state(f"for_loop_state_{id(node)}", is_start_block=True) + def _loop_state_name_prefix(self, node: tn.ForScope | tn.WhileScope) -> str: + if isinstance(node, tn.ForScope): + return "for" - _insert_and_split_assignments(current_state, loop_region) + if isinstance(node, tn.WhileScope): + return "while" - self._current_state = loop_state - self.visit(node.children, sdfg=sdfg) - - after_state = _insert_and_split_assignments(loop_region, label="loop_after") - self._current_state = after_state + raise NotImplementedError(f"Loop state name prefix not implemented for loop of type {type(node)}.") - def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: + def _add_loop_region(self, node: tn.ForScope | tn.WhileScope, sdfg: SDFG) -> None: current_state = self._current_state - assert current_state is not None + assert current_state is not None # just to keep pyright happy cf_region = current_state.parent_graph - loop_region = node.loop + loop_region = LoopRegion( + label=node.loop.label, + condition_expr=node.loop.loop_condition, + loop_var=node.loop.loop_variable, + initialize_expr=node.loop.init_statement, + update_expr=node.loop.update_statement, + unroll=node.loop.unroll, + unroll_factor=node.loop.unroll_factor, + inverted=node.loop.inverted, + update_before_condition=node.loop.update_before_condition, + ) + + memlets = loop_region.get_meta_read_memlets(self._ctx.root.containers) + self._ensure_data_descriptors(memlets, sdfg) + cf_region.add_node(loop_region) - loop_state = loop_region.add_state(f"while_loop_state_{id(node)}", is_start_block=True) + prefix = self._loop_state_name_prefix(node) + loop_state = loop_region.add_state(f"{prefix}_loop_state_{id(node)}", is_start_block=True) _insert_and_split_assignments(current_state, loop_region) self._current_state = loop_state self.visit(node.children, sdfg=sdfg) - after_state = _insert_and_split_assignments(loop_region, label="loop_after") + after_state = _insert_and_split_assignments(loop_region, label=f"{prefix}_loop_after") self._current_state = after_state + def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: + self._add_loop_region(node, sdfg) + + def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: + self._add_loop_region(node, sdfg) + def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_LoopScope(self, node: tn.LoopScope, sdfg: SDFG) -> None: raise NotImplementedError(f"Support for {type(node)} not yet implemented.") + def _ensure_data_descriptors(self, memlets: Sequence[Memlet], sdfg: SDFG) -> None: + scope_node, to_connect = self._dataflow_stack[-1] if self._dataflow_stack else (None, None) + if isinstance(scope_node, SDFG): + for memlet in memlets: + # Copy data descriptor from parent SDFG and add input connector + if memlet.data not in sdfg.arrays: + parent_sdfg = self._parent_sdfg_with_array(memlet.data, sdfg) + + # Support for NView nodes + use_nview = self._apply_nview_array_override(memlet.data, sdfg) + if not use_nview: + sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[memlet.data].transient: + sdfg.arrays[memlet.data].transient = False + + # Dev note: memlet.data and nview.target are identical + assert memlet.data not in to_connect["inputs"] + to_connect["inputs"].add(memlet.data) + def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: before_state = self._current_state assert before_state is not None @@ -285,6 +313,9 @@ def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: if_body = ControlFlowRegion("if_body", sdfg=sdfg) conditional_block.add_branch(node.condition, if_body) + memlets = conditional_block.get_meta_read_memlets(self._ctx.root.containers) + self._ensure_data_descriptors(memlets, sdfg) + if_state = if_body.add_state("if_state", is_start_block=True) self._current_state = if_state diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 70297e6db9..6a9e9dc38c 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -540,7 +540,7 @@ def input_memlets(self, **kwargs) -> MemletSet: root = root if root is not None else self.get_root() result = MemletSet() - result.update(memlets_in_ast(self.condition.code[0], root.containers)) + result.update(memlets_in_ast(self.condition.code[0], root.containers, include_scalars=True)) result.update(super().input_memlets(root, **kwargs)) return result @@ -623,7 +623,7 @@ def input_memlets(self, **kwargs) -> MemletSet: root = root if root is not None else self.get_root() result = MemletSet() - result.update(memlets_in_ast(self.condition.code[0], root.containers)) + result.update(memlets_in_ast(self.condition.code[0], root.containers, include_scalars=True)) result.update(super().input_memlets(root, **kwargs)) return result diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index aee67661ae..3737df78c5 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -119,13 +119,14 @@ def _replace_dict_values(d, old, new): d[k] = new -def memlets_in_ast(node: ast.AST, arrays: Dict[str, dt.Data]) -> List[mm.Memlet]: +def memlets_in_ast(node: ast.AST, arrays: Dict[str, dt.Data], *, include_scalars: bool = False) -> List[mm.Memlet]: """ Generates a list of memlets from each of the subscripts that appear in the Python AST. Assumes the subscript slice can be coerced to a symbolic expression (e.g., no indirect access). :param node: The AST node to find memlets in. :param arrays: A dictionary mapping array names to their data descriptors (a-la ``sdfg.arrays``) + :param include_scalars: If true, include Memlets for scalar accesses. Defaults to false to be backwards compatible. :return: A list of Memlet objects in the order they appear in the AST. """ result: List[mm.Memlet] = [] @@ -136,6 +137,10 @@ def memlets_in_ast(node: ast.AST, arrays: Dict[str, dt.Data]) -> List[mm.Memlet] data, slc = astutils.subscript_to_slice(subnode, arrays) subset = sbs.Range(slc) result.append(mm.Memlet(data=data, subset=subset)) + elif include_scalars and isinstance(subnode, ast.Name): + data = astutils.rname(subnode) + if data in arrays and isinstance(arrays[data], dace.data.Scalar): + result.append(mm.Memlet.from_array(data, arrays[data])) return result @@ -388,10 +393,10 @@ def get_read_memlets(self, arrays: Dict[str, dt.Data]) -> List[mm.Memlet]: :return: A list of Memlet objects for each read. """ result: List[mm.Memlet] = [] - result.extend(memlets_in_ast(self.condition.code[0], arrays)) + result.extend(memlets_in_ast(self.condition.code[0], arrays, include_scalars=True)) for assign in self.assignments.values(): vast = ast.parse(assign) - result.extend(memlets_in_ast(vast, arrays)) + result.extend(memlets_in_ast(vast, arrays, include_scalars=True)) return result diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index d6881305e2..df8995667a 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3556,11 +3556,11 @@ def get_meta_read_memlets(self, arrays: Optional[Dict[str, dt.Data]] = None) -> arrays = arrays if arrays is not None else self.sdfg.arrays - read_memlets = memlets_in_ast(self.loop_condition.code[0], arrays) + read_memlets = memlets_in_ast(self.loop_condition.code[0], arrays, include_scalars=True) if self.init_statement: - read_memlets.extend(memlets_in_ast(self.init_statement.code[0], arrays)) + read_memlets.extend(memlets_in_ast(self.init_statement.code[0], arrays, include_scalars=True)) if self.update_statement: - read_memlets.extend(memlets_in_ast(self.update_statement.code[0], arrays)) + read_memlets.extend(memlets_in_ast(self.update_statement.code[0], arrays, include_scalars=True)) return read_memlets def replace_meta_accesses(self, replacements): @@ -3818,13 +3818,22 @@ def get_meta_codeblocks(self): codes.append(c) return codes - def get_meta_read_memlets(self) -> List[mm.Memlet]: + def get_meta_read_memlets(self, arrays: Optional[Dict[str, dt.Data]] = None) -> List[mm.Memlet]: + """ + Get a list of all (read) memlets in meta codeblocks. + + :param arrays: An optional dictionary mapping array names to their data descriptors. + If not not given defaults to ``self.sdfg.arrays``. + """ # Avoid cyclic imports. from dace.sdfg.sdfg import memlets_in_ast + + arrays = arrays if arrays is not None else self.sdfg.arrays + read_memlets = [] for c, _ in self.branches: if c is not None: - read_memlets.extend(memlets_in_ast(c.code[0], self.sdfg.arrays)) + read_memlets.extend(memlets_in_ast(c.code[0], arrays, include_scalars=True)) return read_memlets def propagate_memlets(self, border_memlets: Dict[str, Dict[str, Optional[mm.Memlet]]]) -> None: diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index e6ae4658a1..218755f851 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -730,6 +730,50 @@ def test_triple_map_nested_if(): sdfg.validate() +def test_triple_map_if_condition_outside(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={ + 'A': data.Array(dace.float64, [60]), + 'tmp': data.Scalar(dace.float64, transient=True), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('assign', {'read'}, {'out'}, 'out = read'), {'read': dace.Memlet('A[1]')}, + {'out': dace.Memlet("tmp[0]")}), + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", "i", sbs.Range.from_string("0:4"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_j", "j", sbs.Range.from_string("0:5"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_k", "k", sbs.Range.from_string("0:3"))), + children=[ + tn.IfScope( + condition=CodeBlock("tmp + 1 > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 1"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], + ), + tn.ElseScope(children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 2"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], ), + ], + ) + ], + ) + ], + ) + ], + ) + + sdfg = stree.as_sdfg(simplify=False) + sdfg.validate() + sdfg.compile() + + def test_create_nested_map_scope_multi_read(): stree = tn.ScheduleTreeRoot( name="tester", @@ -903,14 +947,25 @@ def test_assign_nodes_avoid_duplicate_boundaries(): # test_create_state_boundary_empty_memlet() test_create_tasklet_raw() test_create_tasklet_waw() + test_create_tasklet_war() test_create_loop_for() test_create_loop_while() test_create_if_else() + test_create_if_elif_else() test_create_if_without_else() test_create_map_scope_write() + test_create_map_scope_read_after_write() + test_create_map_scope_write_after_read() test_create_map_scope_copy() test_create_map_scope_double_memlet() test_create_nested_map_scope() + test_double_map_with_for_loop() + test_triple_map_flat_if() + test_triple_map_if_condition_outside() test_create_nested_map_scope_multi_read() test_map_with_state_boundary_inside() + test_map_calculate_temporary_in_two_loops() test_edge_assignment_read_after_write() + test_assign_nodes_force_state_transition() + test_assign_nodes_multiple_force_one_transition() + test_assign_nodes_avoid_duplicate_boundaries() From 44657753cef3c0ce3ef9deef9d0c81e0e7314b1e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 22 May 2026 07:47:09 +0200 Subject: [PATCH 08/16] Connect source/sink nodes with empty Memlets --- .../analysis/schedule_tree/tree_to_sdfg.py | 36 +++++++---- dace/sdfg/propagation.py | 6 +- dace/sdfg/state.py | 2 +- tests/schedule_tree/to_sdfg_test.py | 60 +++++++++++++++++++ 4 files changed, 91 insertions(+), 13 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 9028963c32..34b9b08371 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -23,6 +23,7 @@ class StateBoundaryBehavior(Enum): PREFIX_PASSTHROUGH_IN: Final[str] = "IN_" PREFIX_PASSTHROUGH_OUT: Final[str] = "OUT_" +PREFIX_SINK_TASKLET: Final[str] = "__stree_sink_tasklet" @dataclass @@ -510,10 +511,12 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # connect to local access node (if available) if memlet_data in access_cache: cached_access = access_cache[memlet_data] - self._current_state.add_memlet_path(cached_access, - map_entry, - dst_conn=connector, - memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) + self._current_state.add_memlet_path( + cached_access, + map_entry, + dst_conn=connector, + memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data]), + ) continue if isinstance(outer_map_entry, nodes.EntryNode): @@ -558,10 +561,12 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: assert memlet_data not in access_cache access_cache[memlet_data] = self._current_state.add_read(memlet_data) cached_access = access_cache[memlet_data] - self._current_state.add_memlet_path(cached_access, - map_entry, - dst_conn=connector, - memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) + self._current_state.add_memlet_path( + cached_access, + map_entry, + dst_conn=connector, + memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data]), + ) if isinstance(outer_map_entry, nodes.EntryNode) and self._current_state.out_degree(outer_map_entry) < 1: self._current_state.add_nedge(outer_map_entry, map_entry, Memlet()) @@ -573,6 +578,13 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # connect writes to map_exit node for name in to_connect: + # Special case: + # This tasklet is a sink node and just needs an (empty) memlet connection to the MapExit node. + if name.startswith(PREFIX_SINK_TASKLET): + tasklet_to_connect, empty_memlet = to_connect[name] + self._current_state.add_nedge(tasklet_to_connect, map_exit, empty_memlet) + continue + in_connector_name = f"{PREFIX_PASSTHROUGH_IN}{name}" out_connector_name = f"{PREFIX_PASSTHROUGH_OUT}{name}" new_in_connector = map_exit.add_in_connector(in_connector_name) @@ -709,8 +721,8 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: cached_access = cache[memlet.data] self._current_state.add_memlet_path(cached_access, tasklet, dst_conn=name, memlet=memlet) - # Add empty memlet if map_entry has no out_connectors to connect to - if isinstance(scope_node, nodes.MapEntry) and self._current_state.out_degree(scope_node) < 1: + # Add empty memlet if this tasklet is a source node + if isinstance(scope_node, nodes.MapEntry) and not node.in_memlets: self._current_state.add_nedge(scope_node, tasklet, Memlet()) # Connect output memlets @@ -750,6 +762,10 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: else: assert scope_node is None + # Add empty memlet if this tasklet is a sink node + if isinstance(scope_node, nodes.MapEntry) and not node.out_memlets: + to_connect[f"{PREFIX_SINK_TASKLET}_{id(tasklet)}"] = (tasklet, Memlet()) + def visit_LibraryCall(self, node: tn.LibraryCall, sdfg: SDFG) -> None: raise NotImplementedError(f"Support for {type(node)} not yet implemented.") diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index eb0c5a99f4..0502f786dd 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1572,6 +1572,9 @@ def propagate_memlet(dfg_state, neighboring internal memlets within the same scope into account. """ + if memlet.is_empty(): + return Memlet() + use_dst = False if isinstance(scope_node, nodes.EntryNode): use_dst = False @@ -1587,9 +1590,8 @@ def propagate_memlet(dfg_state, neighboring_edges = [e for e in neighboring_edges if e.dst_conn and e.dst_conn[3:] == connector] else: raise TypeError('Trying to propagate through a non-scope node') - if memlet.is_empty(): - return Memlet() + assert entry_node is not None sdfg = dfg_state.parent scope_node_symbols = set(conn for conn in entry_node.in_connectors if not conn.startswith('IN_')) defined_vars = [ diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index df8995667a..9f566013c1 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -161,7 +161,7 @@ def entry_node(self, node: nd.Node) -> Optional[nd.EntryNode]: @abc.abstractmethod def exit_node(self, entry_node: nd.EntryNode) -> Optional[nd.ExitNode]: """ Returns the exit node leaving the context opened by the given entry node. """ - raise None + return None ################################################################### # Memlet-tracking methods diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 218755f851..66d74d8646 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -501,6 +501,41 @@ def test_create_map_scope_write(): sdfg.validate() +def test_create_map_scope_read(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("print_i", {"read"}, {}, "print(read)"), {"read": dace.Memlet("A[i]")}, + {}) + ], + ) + ], + ) + + sdfg = stree.as_sdfg(simplify=False) + sdfg.validate() + + +def test_create_map_scope_hello_world(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:2"))), + children=[tn.TaskletNode(nodes.Tasklet("print_hi", {}, {}, "print('Hello world!')"), {}, {})], + ) + ], + ) + + sdfg = stree.as_sdfg(simplify=False) + sdfg.validate() + + def test_create_map_scope_read_after_write(): stree = tn.ScheduleTreeRoot( name="tester", @@ -586,6 +621,28 @@ def test_create_map_scope_double_memlet(): sdfg.validate() +def test_create_map_scope_write_in_two_tasklets(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={ + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float32, [20]), + }, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("assign_i", {}, {"out"}, "out = i"), {}, {"out": dace.Memlet("A[i]")}), + tn.TaskletNode(nodes.Tasklet("assign_i", {}, {"out"}, "out = i"), {}, {"out": dace.Memlet("B[i]")}), + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + sdfg.validate() + + def test_create_nested_map_scope(): stree = tn.ScheduleTreeRoot( name="tester", @@ -954,10 +1011,13 @@ def test_assign_nodes_avoid_duplicate_boundaries(): test_create_if_elif_else() test_create_if_without_else() test_create_map_scope_write() + test_create_map_scope_read() + test_create_map_scope_hello_world() test_create_map_scope_read_after_write() test_create_map_scope_write_after_read() test_create_map_scope_copy() test_create_map_scope_double_memlet() + test_create_map_scope_write_in_two_tasklets() test_create_nested_map_scope() test_double_map_with_for_loop() test_triple_map_flat_if() From 99a9360d35f458c328b204860d59a365522484ab Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 22 May 2026 16:52:39 +0200 Subject: [PATCH 09/16] properly implement self-assigning copy nodes --- .../analysis/schedule_tree/tree_to_sdfg.py | 22 +++++++++++----- tests/schedule_tree/to_sdfg_test.py | 26 +++++++++++++++++++ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 34b9b08371..6ac58d530d 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -776,13 +776,23 @@ def visit_CopyNode(self, node: tn.CopyNode, sdfg: SDFG) -> None: self._ctx.access_cache[cache_key] = {} access_cache = self._ctx.access_cache[cache_key] - # assumption source access may or may not yet exist (in this state) + # both, source and target nodes, may or may not exist (in this state) src_name = node.memlet.data - source = access_cache[src_name] if src_name in access_cache else self._current_state.add_read(src_name) - - # assumption: target access node doesn't exist yet - assert node.target not in access_cache - target = self._current_state.add_write(node.target) + if src_name not in access_cache: + # cache new read access + source_access_node = self._current_state.add_read(src_name) + access_cache[src_name] = source_access_node + source = access_cache[src_name] + + target_name = node.target + # only re-use cached write-only nodes, e.g. don't create a cycle for + # field[5, 0, 0:73] = copy field[0, 0, 0:73] + if target_name not in access_cache or self._current_state.out_degree( + access_cache[target_name]) > 0 or src_name == target_name: + # cache new write access node + target_access_node = self._current_state.add_write(node.target) + access_cache[node.target] = target_access_node + target = access_cache[node.target] self._current_state.add_memlet_path(source, target, memlet=node.memlet) diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 66d74d8646..2a29d70487 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -987,6 +987,31 @@ def test_assign_nodes_avoid_duplicate_boundaries(): assert [type(child) for child in stree.children] == [tn.AssignNode, tn.StateBoundaryNode, tn.TaskletNode] +def test_multiple_copy_nodes() -> None: + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + }, + children=[ + tn.CopyNode('A', dace.Memlet("A[0] -> [10]")), + tn.CopyNode('A', dace.Memlet("A[1] -> [11]")), + tn.CopyNode('A', dace.Memlet("A[2] -> [12]")), + ], + ) + + sdfg = stree.as_sdfg() + + states = sdfg.states() + assert len(states) == 1, "expect one state" + + nodes = states[0].nodes() + assert len(nodes) == 4, "expect four access nodes" + for node in nodes: + assert isinstance(node, dace.nodes.AccessNode) + assert node.data == "A" + + if __name__ == '__main__': test_state_boundaries_none() test_state_boundaries_waw() @@ -1029,3 +1054,4 @@ def test_assign_nodes_avoid_duplicate_boundaries(): test_assign_nodes_force_state_transition() test_assign_nodes_multiple_force_one_transition() test_assign_nodes_avoid_duplicate_boundaries() + test_multiple_copy_nodes() From 4da9d096ed3454ffa6dcb7b5233c281dc90696c2 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Mon, 1 Jun 2026 16:21:53 +0200 Subject: [PATCH 10/16] unrelated: fixing typo --- dace/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/config.py b/dace/config.py index 53328a7536..ae38322461 100644 --- a/dace/config.py +++ b/dace/config.py @@ -402,7 +402,7 @@ def get_metadata(*key_hierarchy): @staticmethod def get_default(*key_hierarchy): """ Returns the default value of a given configuration entry. - Takes into accound current operating system. + Takes into account current operating system. :param key_hierarchy: A tuple of strings leading to the configuration entry. From a104ffeae8cee41e4f7a508f4ac8b8b8b90c73f7 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 9 Jun 2026 14:35:54 +0200 Subject: [PATCH 11/16] unrelated: fix typehint of load_precompiled_sdfg() --- dace/frontend/python/parser.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 3b9325c2f6..7cdee1a20b 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -16,6 +16,9 @@ from dace.sdfg import SDFG, utils as sdutils from dace.data import create_datadescriptor, Data +if TYPE_CHECKING: + from dace.codegen.compiled_sdfg import CompiledSDFG + try: import mpi4py from dace.sdfg.utils import distributed_compile @@ -793,7 +796,7 @@ def load_sdfg(self, path: str, *args, **kwargs): return sdfg, cachekey - def load_precompiled_sdfg(self, path: str, *args, **kwargs) -> None: + def load_precompiled_sdfg(self, path: str, *args, **kwargs) -> tuple[CompiledSDFG, cached_program.ProgramCacheKey]: """ Loads an external compiled SDFG object that will be invoked when the function is called. From ce6e6b18815560193254636b8b6abc2c489f3f13 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 9 Jun 2026 14:41:34 +0200 Subject: [PATCH 12/16] fixup: import missing `TYPE_CHECKING` --- dace/frontend/python/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 7cdee1a20b..a5312c04ea 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -6,7 +6,7 @@ import os import sympy import sys -from typing import Any, Callable, Dict, List, Optional, Set, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Sequence, Tuple, Union, TYPE_CHECKING from typing import get_origin, get_args import warnings From c8cb49cc29cc281a5c92bd84e231e31cb6dc561d Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 9 Jun 2026 14:53:01 +0200 Subject: [PATCH 13/16] apparently we need the full import ... Please enter the commit message for your changes. Lines starting --- dace/frontend/python/parser.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index a5312c04ea..38b311f5ae 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -6,18 +6,16 @@ import os import sympy import sys -from typing import Any, Callable, Dict, List, Optional, Set, Sequence, Tuple, Union, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Set, Sequence, Tuple, Union from typing import get_origin, get_args import warnings from dace import data, dtypes, hooks, symbolic +from dace.codegen.compiled_sdfg import CompiledSDFG from dace.config import Config +from dace.data import create_datadescriptor, Data from dace.frontend.python import (newast, common as pycommon, cached_program, preprocessing) from dace.sdfg import SDFG, utils as sdutils -from dace.data import create_datadescriptor, Data - -if TYPE_CHECKING: - from dace.codegen.compiled_sdfg import CompiledSDFG try: import mpi4py From c6bc57a3f23d2427da3cb23ece13255de4a9af47 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 17 Jun 2026 10:37:16 +0200 Subject: [PATCH 14/16] fix: ensure unique names of loop regions (stree -> sdfg) Names/Labels of LoopRegions need to be unique inside the CFG (validation checks that). So far, we didn't have problems. However, with "no simplify before stree", we are getting in places where this is an issue. This commit deploys a simple fix to ensure unique names of LoopRegions. --- dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 6ac58d530d..c07db2646f 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -253,7 +253,7 @@ def _add_loop_region(self, node: tn.ForScope | tn.WhileScope, sdfg: SDFG) -> Non memlets = loop_region.get_meta_read_memlets(self._ctx.root.containers) self._ensure_data_descriptors(memlets, sdfg) - cf_region.add_node(loop_region) + cf_region.add_node(loop_region, ensure_unique_name=True) prefix = self._loop_state_name_prefix(node) loop_state = loop_region.add_state(f"{prefix}_loop_state_{id(node)}", is_start_block=True) From d186d86dea15f7852545dcde0c4f5b9e6d4f072b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 18 Jun 2026 10:22:45 +0200 Subject: [PATCH 15/16] fix: consider read-after-write in input_memlets of tree Every tree node has functions to calculate input and output memlets (of that node). Tree scopes have a default implementation to gather all inputs/outputs of their children. That default implementation for scopes didn't consider read after write (within the same scope). This caused "too many inputs" to be returned, which - in turn - caused the dependency analysis of the state boundary inserter to generate wrong results. This could lead to **missing** state boundaries. We saw write/write races in D2A2C_Vect of pyFV3. --- dace/sdfg/analysis/schedule_tree/treenodes.py | 28 +++++++++++++++- tests/schedule_tree/treenodes_test.py | 33 ++++++++++++++++++- 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 6a9e9dc38c..3f969cd301 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -108,7 +108,25 @@ def _gather_memlets_in_scope(self, inputs: bool, root: Optional['ScheduleTreeRoo # Fast path, no propagation necessary if keep_locals: - return MemletSet().union(*(gather(c, root) for c in self.children)) + if not inputs: + # for outputs we don't need to care about read after write + return MemletSet().union(*(gather(c, root) for c in self.children)) + + # for inputs, make sure read-after-write doesn't show up in inputs + result = MemletSet() + previously_written = MemletSet() + + for child in self.children: + c_reads = child.input_memlets(root, **kwargs) + for read in c_reads: + if read not in previously_written: + result.add(read) + + # register writes + for c_write in child.output_memlets(root, **kwargs): + previously_written.add(c_write) + + return result root = root if root is not None else self.get_root() @@ -120,6 +138,7 @@ def _gather_memlets_in_scope(self, inputs: bool, root: Optional['ScheduleTreeRoo current_locals = set() current_locals |= disallow_propagation result = MemletSet() + previously_written = MemletSet() # Loop over children in order, if any new symbol is defined within this scope (e.g., symbol assignment, # dynamic map range), consider it as a new local @@ -133,6 +152,8 @@ def _gather_memlets_in_scope(self, inputs: bool, root: Optional['ScheduleTreeRoo internal_memlets: MemletSet = gather(c, root) if propagate: for memlet in internal_memlets: + if memlet in previously_written: + continue result.add( propagate_subset([memlet], root.containers[memlet.data], @@ -141,6 +162,11 @@ def _gather_memlets_in_scope(self, inputs: bool, root: Optional['ScheduleTreeRoo undefined_variables=current_locals, use_dst=not inputs)) + if inputs: + # register writes to keep track of read-after-write + for c_write in c.output_memlets(root, **kwargs): + previously_written.add(c_write) + return result def input_memlets(self, diff --git a/tests/schedule_tree/treenodes_test.py b/tests/schedule_tree/treenodes_test.py index 31a0abbd21..bb3f05f62f 100644 --- a/tests/schedule_tree/treenodes_test.py +++ b/tests/schedule_tree/treenodes_test.py @@ -1,8 +1,10 @@ # Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. from dace.sdfg.analysis.schedule_tree import treenodes as tn -from dace import nodes +from dace import nodes, data, subsets +from dace import Memlet +import dace import pytest @@ -109,6 +111,35 @@ def test_dataflow_scope_children(DataflowScope: type[tn.DataflowScope], tasklet: assert child.parent == scope +def test_scope_inputs_outputs() -> None: + write_scalar = tn.TaskletNode( + nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), + {}, + {'out': Memlet('scalar[0]')}, + ) + read_scalar = tn.TaskletNode( + nodes.Tasklet('bla2', {'inp'}, {'out'}, 'out = inp + 1'), + {'inp': Memlet('scalar[0]')}, + {'out': Memlet('A[1]')}, + ) + map_scope = tn.MapScope( + node=nodes.MapEntry(nodes.Map('map', ['i'], subsets.Range.from_string("0:20"))), + children=[write_scalar, read_scalar], + ) + + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + 'scalar': data.Scalar(dace.float64) + }, + children=[map_scope], + ) + + assert len(map_scope.input_memlets()) == 0 + assert len(map_scope.output_memlets()) == 2 + + if __name__ == '__main__': test_schedule_tree_scope_children(tn.ScheduleTreeScope, tasklet) test_schedule_tree_scope_children(tn.ControlFlowScope, tasklet) From 7c526886bfadeb9808a06a66fcbca1dbfa6b8ad4 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Tue, 23 Jun 2026 09:56:21 +0200 Subject: [PATCH 16/16] fix: gpu transform node selection (backport of #2413) --- .../interstate/gpu_transform_sdfg.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index 766899319e..ac9f607517 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -361,10 +361,9 @@ def apply(self, _, sdfg: sd.SDFG): sdict = state.scope_dict() for node in state.nodes(): if sdict[node] is None: - if isinstance(node, (nodes.LibraryNode, nodes.NestedSDFG)): - if node.guid: - if isinstance(node, nodes.LibraryNode): - node.schedule = dtypes.ScheduleType.GPU_Device + if isinstance(node, (nodes.LibraryNode)): + if not self._output_or_input_is_marked_host(state, node): + node.schedule = dtypes.ScheduleType.GPU_Device gpu_nodes.add((state, node)) elif isinstance(node, nodes.EntryNode): if node.guid not in self.host_maps and not self._output_or_input_is_marked_host(state, node): @@ -378,21 +377,24 @@ def apply(self, _, sdfg: sd.SDFG): if isinstance(nnode, (nodes.EntryNode, nodes.LibraryNode)): nnode.schedule = dtypes.ScheduleType.Sequential - # NOTE: The outputs of LibraryNodes, NestedSDFGs and Map that have GPU schedule must be moved to GPU memory. + # NOTE: The outputs of LibraryNodes and Maps that have GPU schedule must be moved to GPU memory. # TODO: Also use GPU-shared and GPU-register memory when appropriate. for state, node in gpu_nodes: - if isinstance(node, (nodes.LibraryNode, nodes.NestedSDFG)): + if isinstance(node, (nodes.LibraryNode)): for e in state.out_edges(node): dst = state.memlet_path(e)[-1].dst if isinstance(dst, nodes.AccessNode): desc = sdfg.arrays[dst.data] desc.storage = dtypes.StorageType.GPU_Global - if isinstance(node, nodes.EntryNode): + elif isinstance(node, nodes.EntryNode): for e in state.out_edges(state.exit_node(node)): dst = state.memlet_path(e)[-1].dst if isinstance(dst, nodes.AccessNode): desc = sdfg.arrays[dst.data] desc.storage = dtypes.StorageType.GPU_Global + else: + raise RuntimeError( + f"GPU node of unexpected type. Expected `LibraryNode` or `EntryNode`, found {type(node)}.") ####################################################### # Step 5: Collect free tasklets and check for scalars that have to be moved to the GPU