diff --git a/dace/config.py b/dace/config.py index 53328a7536..ae38322461 100644 --- a/dace/config.py +++ b/dace/config.py @@ -402,7 +402,7 @@ def get_metadata(*key_hierarchy): @staticmethod def get_default(*key_hierarchy): """ Returns the default value of a given configuration entry. - Takes into accound current operating system. + Takes into account current operating system. :param key_hierarchy: A tuple of strings leading to the configuration entry. diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 3b9325c2f6..38b311f5ae 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -11,10 +11,11 @@ import warnings from dace import data, dtypes, hooks, symbolic +from dace.codegen.compiled_sdfg import CompiledSDFG from dace.config import Config +from dace.data import create_datadescriptor, Data from dace.frontend.python import (newast, common as pycommon, cached_program, preprocessing) from dace.sdfg import SDFG, utils as sdutils -from dace.data import create_datadescriptor, Data try: import mpi4py @@ -793,7 +794,7 @@ def load_sdfg(self, path: str, *args, **kwargs): return sdfg, cachekey - def load_precompiled_sdfg(self, path: str, *args, **kwargs) -> None: + def load_precompiled_sdfg(self, path: str, *args, **kwargs) -> tuple[CompiledSDFG, cached_program.ProgramCacheKey]: """ Loads an external compiled SDFG object that will be invoked when the function is called. diff --git a/dace/memlet.py b/dace/memlet.py index 674e19ae6b..95b6e432a2 100644 --- a/dace/memlet.py +++ b/dace/memlet.py @@ -469,7 +469,7 @@ def replace(self, repl_dict): for symbol in repl_dict: if str(symbol) != str(repl_dict[symbol]): intermediate = symbolic.symbol('__dacesym_' + str(symbol)) - repl_to_intermediate[symbolic.symbol(symbol)] = intermediate + repl_to_intermediate[symbol] = intermediate repl_to_final[intermediate] = repl_dict[symbol] if len(repl_to_intermediate) > 0: diff --git a/dace/runtime/include/dace/math.h b/dace/runtime/include/dace/math.h index 533056c9e4..634299a7b5 100644 --- a/dace/runtime/include/dace/math.h +++ b/dace/runtime/include/dace/math.h @@ -86,7 +86,7 @@ static DACE_CONSTEXPR DACE_HDFI T Modulo_float(const T& value, const T& modulus) return value - floor(value / modulus) * modulus; } -// Implement to support a match wtih Fortran's intrinsic EXPONENT +// Implement to support a match with Fortran's intrinsic EXPONENT template::value>* = nullptr> static DACE_CONSTEXPR DACE_HDFI int frexp(const T& a) { int exponent = 0; @@ -254,7 +254,7 @@ static DACE_CONSTEXPR DACE_HDFI std::complex np_float_pow(const std::com // Computes Python modulus (also NumPy remainder) // Formula: num - (num // den) * den // NOTE: This is different than Python math.remainder and C remainder, -// which are equaivalent to the IEEE remainder: num - round(num / den) * den +// which are equivalent to the IEEE remainder: num - round(num / den) * den template static DACE_CONSTEXPR DACE_HDFI T py_mod(const T& numerator, const T& denominator) { T quotient = py_floor(numerator, denominator); @@ -392,7 +392,7 @@ static DACE_CONSTEXPR DACE_HDFI std::complex reciprocal(const std::complex #if __cplusplus < 201703L -// Compute the greates common divisor of two integers +// Compute the greatest common divisor of two integers template static DACE_CONSTEXPR DACE_HDFI T gcd(T a, T b) { // Modern Euclidian algorithm @@ -417,7 +417,7 @@ static DACE_CONSTEXPR DACE_HDFI T lcm(T a, T b) { #else -// Compute the greates common divisor of two integers +// Compute the greatest common divisor of two integers template static DACE_CONSTEXPR DACE_HDFI T gcd(const T& a, const T& b) { return std::gcd(a, b); @@ -514,6 +514,16 @@ namespace dace { return (T)std::exp(a); } + template + DACE_CONSTEXPR DACE_HDFI T exp2(const T& a) + { + return (T)std::exp2(a); + } + template + DACE_CONSTEXPR DACE_HDFI T expm1(const T& a) + { + return (T)std::expm1(a); + } #ifdef __CUDACC__ template @@ -572,36 +582,76 @@ namespace dace return std::sin(a); } template + DACE_CONSTEXPR DACE_HDFI T asin(const T& a) + { + return std::asin(a); + } + template DACE_CONSTEXPR DACE_HDFI T sinh(const T& a) { return std::sinh(a); } template + DACE_CONSTEXPR DACE_HDFI T asinh(const T& a) + { + return std::asinh(a); + } + template DACE_CONSTEXPR DACE_HDFI T cos(const T& a) { return std::cos(a); } template + DACE_CONSTEXPR DACE_HDFI T acos(const T& a) + { + return std::acos(a); + } + template DACE_CONSTEXPR DACE_HDFI T cosh(const T& a) { return std::cosh(a); } template + DACE_CONSTEXPR DACE_HDFI T acosh(const T& a) + { + return std::acosh(a); + } + template DACE_CONSTEXPR DACE_HDFI T tan(const T& a) { return std::tan(a); } template + DACE_CONSTEXPR DACE_HDFI T atan(const T& a) + { + return std::atan(a); + } + template + DACE_CONSTEXPR DACE_HDFI T atan2(const T& a) + { + return std::atan2(a); + } + template DACE_CONSTEXPR DACE_HDFI T tanh(const T& a) { return std::tanh(a); } template + DACE_CONSTEXPR DACE_HDFI T atanh(const T& a) + { + return std::atanh(a); + } + template DACE_CONSTEXPR DACE_HDFI T sqrt(const T& a) { return std::sqrt(a); } template + DACE_CONSTEXPR DACE_HDFI T cbrt(const T& a) + { + return std::cbrt(a); + } + template DACE_CONSTEXPR DACE_HDFI T log(const T& a) { return std::log(a); @@ -611,6 +661,66 @@ namespace dace { return std::log10(a); } + template + DACE_CONSTEXPR DACE_HDFI T log1p(const T& a) + { + return std::log1p(a); + } + template + DACE_CONSTEXPR DACE_HDFI T log2(const T& a) + { + return std::log2(a); + } + template + DACE_CONSTEXPR DACE_HDFI T fmod(const T& a, const T& b) + { + return std::fmod(a, b); + } + template + DACE_CONSTEXPR DACE_HDFI T lgamma(const T& a) + { + return std::lgamma(a); + } + template + DACE_CONSTEXPR DACE_HDFI T tgamma(const T& a) + { + return std::tgamma(a); + } + template + DACE_CONSTEXPR DACE_HDFI T ceil(const T& a) + { + return std::ceil(a); + } + template + DACE_CONSTEXPR DACE_HDFI T trunc(const T& a) + { + return std::trunc(a); + } + template + DACE_CONSTEXPR DACE_HDFI T erf(const T& a) + { + return std::erf(a); + } + template + DACE_CONSTEXPR DACE_HDFI T erfc(const T& a) + { + return std::erfc(a); + } + template + DACE_CONSTEXPR DACE_HDFI T nearbyint(const T& a) + { + return std::nearbyint(a); + } + template + DACE_CONSTEXPR DACE_HDFI T round(const T& a) + { + return std::round(a); + } + template + DACE_CONSTEXPR DACE_HDFI T hypot(const T& a, const T& b) + { + return std::hypot(a, b); + } } namespace cmath diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index f8bfab9dde..c07db2646f 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from enum import Enum, auto from types import TracebackType -from typing import Final +from typing import Final, Sequence from dace import symbolic from dace.memlet import Memlet @@ -23,6 +23,7 @@ class StateBoundaryBehavior(Enum): PREFIX_PASSTHROUGH_IN: Final[str] = "IN_" PREFIX_PASSTHROUGH_OUT: Final[str] = "OUT_" +PREFIX_SINK_TASKLET: Final[str] = "__stree_sink_tasklet" @dataclass @@ -223,52 +224,80 @@ def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: if memlet.data not in sdfg.arrays: raise ValueError(f"Parsing AssignNode {node} failed. Can't find {memlet.data} in {sdfg}.") - def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: - current_state = self._current_state - assert current_state is not None - cf_region = current_state.parent_graph + def _loop_state_name_prefix(self, node: tn.ForScope | tn.WhileScope) -> str: + if isinstance(node, tn.ForScope): + return "for" - loop_region = LoopRegion(label=node.loop.label, - condition_expr=node.loop.loop_condition, - loop_var=node.loop.loop_variable, - initialize_expr=node.loop.init_statement, - update_expr=node.loop.update_statement, - unroll=node.loop.unroll, - unroll_factor=node.loop.unroll_factor) - cf_region.add_node(loop_region) - loop_state = loop_region.add_state(f"for_loop_state_{id(node)}", is_start_block=True) + if isinstance(node, tn.WhileScope): + return "while" - _insert_and_split_assignments(current_state, loop_region) + raise NotImplementedError(f"Loop state name prefix not implemented for loop of type {type(node)}.") - self._current_state = loop_state - self.visit(node.children, sdfg=sdfg) - - after_state = _insert_and_split_assignments(loop_region, label="loop_after") - self._current_state = after_state - - def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: + def _add_loop_region(self, node: tn.ForScope | tn.WhileScope, sdfg: SDFG) -> None: current_state = self._current_state - assert current_state is not None + assert current_state is not None # just to keep pyright happy cf_region = current_state.parent_graph - loop_region = node.loop - cf_region.add_node(loop_region) - loop_state = loop_region.add_state(f"while_loop_state_{id(node)}", is_start_block=True) + loop_region = LoopRegion( + label=node.loop.label, + condition_expr=node.loop.loop_condition, + loop_var=node.loop.loop_variable, + initialize_expr=node.loop.init_statement, + update_expr=node.loop.update_statement, + unroll=node.loop.unroll, + unroll_factor=node.loop.unroll_factor, + inverted=node.loop.inverted, + update_before_condition=node.loop.update_before_condition, + ) + + memlets = loop_region.get_meta_read_memlets(self._ctx.root.containers) + self._ensure_data_descriptors(memlets, sdfg) + + cf_region.add_node(loop_region, ensure_unique_name=True) + prefix = self._loop_state_name_prefix(node) + loop_state = loop_region.add_state(f"{prefix}_loop_state_{id(node)}", is_start_block=True) _insert_and_split_assignments(current_state, loop_region) self._current_state = loop_state self.visit(node.children, sdfg=sdfg) - after_state = _insert_and_split_assignments(loop_region, label="loop_after") + after_state = _insert_and_split_assignments(loop_region, label=f"{prefix}_loop_after") self._current_state = after_state + def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: + self._add_loop_region(node, sdfg) + + def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: + self._add_loop_region(node, sdfg) + def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_LoopScope(self, node: tn.LoopScope, sdfg: SDFG) -> None: raise NotImplementedError(f"Support for {type(node)} not yet implemented.") + def _ensure_data_descriptors(self, memlets: Sequence[Memlet], sdfg: SDFG) -> None: + scope_node, to_connect = self._dataflow_stack[-1] if self._dataflow_stack else (None, None) + if isinstance(scope_node, SDFG): + for memlet in memlets: + # Copy data descriptor from parent SDFG and add input connector + if memlet.data not in sdfg.arrays: + parent_sdfg = self._parent_sdfg_with_array(memlet.data, sdfg) + + # Support for NView nodes + use_nview = self._apply_nview_array_override(memlet.data, sdfg) + if not use_nview: + sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[memlet.data].transient: + sdfg.arrays[memlet.data].transient = False + + # Dev note: memlet.data and nview.target are identical + assert memlet.data not in to_connect["inputs"] + to_connect["inputs"].add(memlet.data) + def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: before_state = self._current_state assert before_state is not None @@ -285,6 +314,9 @@ def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: if_body = ControlFlowRegion("if_body", sdfg=sdfg) conditional_block.add_branch(node.condition, if_body) + memlets = conditional_block.get_meta_read_memlets(self._ctx.root.containers) + self._ensure_data_descriptors(memlets, sdfg) + if_state = if_body.add_state("if_state", is_start_block=True) self._current_state = if_state @@ -479,10 +511,12 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # connect to local access node (if available) if memlet_data in access_cache: cached_access = access_cache[memlet_data] - self._current_state.add_memlet_path(cached_access, - map_entry, - dst_conn=connector, - memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) + self._current_state.add_memlet_path( + cached_access, + map_entry, + dst_conn=connector, + memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data]), + ) continue if isinstance(outer_map_entry, nodes.EntryNode): @@ -495,8 +529,13 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: assert new_in_connector == True assert new_in_connector == new_out_connector + if memlet_data in sdfg.arrays: + data_descriptor = sdfg.arrays[memlet_data] + else: + _sdfg = self._parent_sdfg_with_array(memlet_data, sdfg) + data_descriptor = _sdfg.arrays[memlet_data] self._current_state.add_edge(outer_map_entry, connector_name, map_entry, connector, - Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) + Memlet.from_array(memlet_data, data_descriptor)) else: if isinstance(outer_map_entry, SDFG): # Copy data descriptor from parent SDFG and add input connector @@ -522,10 +561,12 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: assert memlet_data not in access_cache access_cache[memlet_data] = self._current_state.add_read(memlet_data) cached_access = access_cache[memlet_data] - self._current_state.add_memlet_path(cached_access, - map_entry, - dst_conn=connector, - memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) + self._current_state.add_memlet_path( + cached_access, + map_entry, + dst_conn=connector, + memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data]), + ) if isinstance(outer_map_entry, nodes.EntryNode) and self._current_state.out_degree(outer_map_entry) < 1: self._current_state.add_nedge(outer_map_entry, map_entry, Memlet()) @@ -537,6 +578,13 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # connect writes to map_exit node for name in to_connect: + # Special case: + # This tasklet is a sink node and just needs an (empty) memlet connection to the MapExit node. + if name.startswith(PREFIX_SINK_TASKLET): + tasklet_to_connect, empty_memlet = to_connect[name] + self._current_state.add_nedge(tasklet_to_connect, map_exit, empty_memlet) + continue + in_connector_name = f"{PREFIX_PASSTHROUGH_IN}{name}" out_connector_name = f"{PREFIX_PASSTHROUGH_OUT}{name}" new_in_connector = map_exit.add_in_connector(in_connector_name) @@ -593,13 +641,18 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: access_cache[name] = write_access_node access_node = access_cache[name] + if name in sdfg.arrays: + data_descriptor = sdfg.arrays[name] + else: + _sdfg = self._parent_sdfg_with_array(name, sdfg) + data_descriptor = _sdfg.arrays[name] self._current_state.add_memlet_path(map_exit, access_node, src_conn=out_connector_name, - memlet=Memlet.from_array(name, sdfg.arrays[name])) + memlet=Memlet.from_array(name, data_descriptor)) if isinstance(outer_map_entry, nodes.EntryNode): - outer_to_connect[name] = (access_node, Memlet.from_array(name, sdfg.arrays[name])) + outer_to_connect[name] = (access_node, Memlet.from_array(name, data_descriptor)) else: assert isinstance(outer_map_entry, SDFG) or outer_map_entry is None @@ -668,8 +721,8 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: cached_access = cache[memlet.data] self._current_state.add_memlet_path(cached_access, tasklet, dst_conn=name, memlet=memlet) - # Add empty memlet if map_entry has no out_connectors to connect to - if isinstance(scope_node, nodes.MapEntry) and self._current_state.out_degree(scope_node) < 1: + # Add empty memlet if this tasklet is a source node + if isinstance(scope_node, nodes.MapEntry) and not node.in_memlets: self._current_state.add_nedge(scope_node, tasklet, Memlet()) # Connect output memlets @@ -709,6 +762,10 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: else: assert scope_node is None + # Add empty memlet if this tasklet is a sink node + if isinstance(scope_node, nodes.MapEntry) and not node.out_memlets: + to_connect[f"{PREFIX_SINK_TASKLET}_{id(tasklet)}"] = (tasklet, Memlet()) + def visit_LibraryCall(self, node: tn.LibraryCall, sdfg: SDFG) -> None: raise NotImplementedError(f"Support for {type(node)} not yet implemented.") @@ -719,13 +776,23 @@ def visit_CopyNode(self, node: tn.CopyNode, sdfg: SDFG) -> None: self._ctx.access_cache[cache_key] = {} access_cache = self._ctx.access_cache[cache_key] - # assumption source access may or may not yet exist (in this state) + # both, source and target nodes, may or may not exist (in this state) src_name = node.memlet.data - source = access_cache[src_name] if src_name in access_cache else self._current_state.add_read(src_name) - - # assumption: target access node doesn't exist yet - assert node.target not in access_cache - target = self._current_state.add_write(node.target) + if src_name not in access_cache: + # cache new read access + source_access_node = self._current_state.add_read(src_name) + access_cache[src_name] = source_access_node + source = access_cache[src_name] + + target_name = node.target + # only re-use cached write-only nodes, e.g. don't create a cycle for + # field[5, 0, 0:73] = copy field[0, 0, 0:73] + if target_name not in access_cache or self._current_state.out_degree( + access_cache[target_name]) > 0 or src_name == target_name: + # cache new write access node + target_access_node = self._current_state.add_write(node.target) + access_cache[node.target] = target_access_node + target = access_cache[node.target] self._current_state.add_memlet_path(source, target, memlet=node.memlet) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 70297e6db9..3f969cd301 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -108,7 +108,25 @@ def _gather_memlets_in_scope(self, inputs: bool, root: Optional['ScheduleTreeRoo # Fast path, no propagation necessary if keep_locals: - return MemletSet().union(*(gather(c, root) for c in self.children)) + if not inputs: + # for outputs we don't need to care about read after write + return MemletSet().union(*(gather(c, root) for c in self.children)) + + # for inputs, make sure read-after-write doesn't show up in inputs + result = MemletSet() + previously_written = MemletSet() + + for child in self.children: + c_reads = child.input_memlets(root, **kwargs) + for read in c_reads: + if read not in previously_written: + result.add(read) + + # register writes + for c_write in child.output_memlets(root, **kwargs): + previously_written.add(c_write) + + return result root = root if root is not None else self.get_root() @@ -120,6 +138,7 @@ def _gather_memlets_in_scope(self, inputs: bool, root: Optional['ScheduleTreeRoo current_locals = set() current_locals |= disallow_propagation result = MemletSet() + previously_written = MemletSet() # Loop over children in order, if any new symbol is defined within this scope (e.g., symbol assignment, # dynamic map range), consider it as a new local @@ -133,6 +152,8 @@ def _gather_memlets_in_scope(self, inputs: bool, root: Optional['ScheduleTreeRoo internal_memlets: MemletSet = gather(c, root) if propagate: for memlet in internal_memlets: + if memlet in previously_written: + continue result.add( propagate_subset([memlet], root.containers[memlet.data], @@ -141,6 +162,11 @@ def _gather_memlets_in_scope(self, inputs: bool, root: Optional['ScheduleTreeRoo undefined_variables=current_locals, use_dst=not inputs)) + if inputs: + # register writes to keep track of read-after-write + for c_write in c.output_memlets(root, **kwargs): + previously_written.add(c_write) + return result def input_memlets(self, @@ -540,7 +566,7 @@ def input_memlets(self, **kwargs) -> MemletSet: root = root if root is not None else self.get_root() result = MemletSet() - result.update(memlets_in_ast(self.condition.code[0], root.containers)) + result.update(memlets_in_ast(self.condition.code[0], root.containers, include_scalars=True)) result.update(super().input_memlets(root, **kwargs)) return result @@ -623,7 +649,7 @@ def input_memlets(self, **kwargs) -> MemletSet: root = root if root is not None else self.get_root() result = MemletSet() - result.update(memlets_in_ast(self.condition.code[0], root.containers)) + result.update(memlets_in_ast(self.condition.code[0], root.containers, include_scalars=True)) result.update(super().input_memlets(root, **kwargs)) return result diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index eb0c5a99f4..0502f786dd 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1572,6 +1572,9 @@ def propagate_memlet(dfg_state, neighboring internal memlets within the same scope into account. """ + if memlet.is_empty(): + return Memlet() + use_dst = False if isinstance(scope_node, nodes.EntryNode): use_dst = False @@ -1587,9 +1590,8 @@ def propagate_memlet(dfg_state, neighboring_edges = [e for e in neighboring_edges if e.dst_conn and e.dst_conn[3:] == connector] else: raise TypeError('Trying to propagate through a non-scope node') - if memlet.is_empty(): - return Memlet() + assert entry_node is not None sdfg = dfg_state.parent scope_node_symbols = set(conn for conn in entry_node.in_connectors if not conn.startswith('IN_')) defined_vars = [ diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index aee67661ae..3737df78c5 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -119,13 +119,14 @@ def _replace_dict_values(d, old, new): d[k] = new -def memlets_in_ast(node: ast.AST, arrays: Dict[str, dt.Data]) -> List[mm.Memlet]: +def memlets_in_ast(node: ast.AST, arrays: Dict[str, dt.Data], *, include_scalars: bool = False) -> List[mm.Memlet]: """ Generates a list of memlets from each of the subscripts that appear in the Python AST. Assumes the subscript slice can be coerced to a symbolic expression (e.g., no indirect access). :param node: The AST node to find memlets in. :param arrays: A dictionary mapping array names to their data descriptors (a-la ``sdfg.arrays``) + :param include_scalars: If true, include Memlets for scalar accesses. Defaults to false to be backwards compatible. :return: A list of Memlet objects in the order they appear in the AST. """ result: List[mm.Memlet] = [] @@ -136,6 +137,10 @@ def memlets_in_ast(node: ast.AST, arrays: Dict[str, dt.Data]) -> List[mm.Memlet] data, slc = astutils.subscript_to_slice(subnode, arrays) subset = sbs.Range(slc) result.append(mm.Memlet(data=data, subset=subset)) + elif include_scalars and isinstance(subnode, ast.Name): + data = astutils.rname(subnode) + if data in arrays and isinstance(arrays[data], dace.data.Scalar): + result.append(mm.Memlet.from_array(data, arrays[data])) return result @@ -388,10 +393,10 @@ def get_read_memlets(self, arrays: Dict[str, dt.Data]) -> List[mm.Memlet]: :return: A list of Memlet objects for each read. """ result: List[mm.Memlet] = [] - result.extend(memlets_in_ast(self.condition.code[0], arrays)) + result.extend(memlets_in_ast(self.condition.code[0], arrays, include_scalars=True)) for assign in self.assignments.values(): vast = ast.parse(assign) - result.extend(memlets_in_ast(vast, arrays)) + result.extend(memlets_in_ast(vast, arrays, include_scalars=True)) return result diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index d6881305e2..9f566013c1 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -161,7 +161,7 @@ def entry_node(self, node: nd.Node) -> Optional[nd.EntryNode]: @abc.abstractmethod def exit_node(self, entry_node: nd.EntryNode) -> Optional[nd.ExitNode]: """ Returns the exit node leaving the context opened by the given entry node. """ - raise None + return None ################################################################### # Memlet-tracking methods @@ -3556,11 +3556,11 @@ def get_meta_read_memlets(self, arrays: Optional[Dict[str, dt.Data]] = None) -> arrays = arrays if arrays is not None else self.sdfg.arrays - read_memlets = memlets_in_ast(self.loop_condition.code[0], arrays) + read_memlets = memlets_in_ast(self.loop_condition.code[0], arrays, include_scalars=True) if self.init_statement: - read_memlets.extend(memlets_in_ast(self.init_statement.code[0], arrays)) + read_memlets.extend(memlets_in_ast(self.init_statement.code[0], arrays, include_scalars=True)) if self.update_statement: - read_memlets.extend(memlets_in_ast(self.update_statement.code[0], arrays)) + read_memlets.extend(memlets_in_ast(self.update_statement.code[0], arrays, include_scalars=True)) return read_memlets def replace_meta_accesses(self, replacements): @@ -3818,13 +3818,22 @@ def get_meta_codeblocks(self): codes.append(c) return codes - def get_meta_read_memlets(self) -> List[mm.Memlet]: + def get_meta_read_memlets(self, arrays: Optional[Dict[str, dt.Data]] = None) -> List[mm.Memlet]: + """ + Get a list of all (read) memlets in meta codeblocks. + + :param arrays: An optional dictionary mapping array names to their data descriptors. + If not not given defaults to ``self.sdfg.arrays``. + """ # Avoid cyclic imports. from dace.sdfg.sdfg import memlets_in_ast + + arrays = arrays if arrays is not None else self.sdfg.arrays + read_memlets = [] for c, _ in self.branches: if c is not None: - read_memlets.extend(memlets_in_ast(c.code[0], self.sdfg.arrays)) + read_memlets.extend(memlets_in_ast(c.code[0], arrays, include_scalars=True)) return read_memlets def propagate_memlets(self, border_memlets: Dict[str, Dict[str, Optional[mm.Memlet]]]) -> None: diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index 766899319e..ac9f607517 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -361,10 +361,9 @@ def apply(self, _, sdfg: sd.SDFG): sdict = state.scope_dict() for node in state.nodes(): if sdict[node] is None: - if isinstance(node, (nodes.LibraryNode, nodes.NestedSDFG)): - if node.guid: - if isinstance(node, nodes.LibraryNode): - node.schedule = dtypes.ScheduleType.GPU_Device + if isinstance(node, (nodes.LibraryNode)): + if not self._output_or_input_is_marked_host(state, node): + node.schedule = dtypes.ScheduleType.GPU_Device gpu_nodes.add((state, node)) elif isinstance(node, nodes.EntryNode): if node.guid not in self.host_maps and not self._output_or_input_is_marked_host(state, node): @@ -378,21 +377,24 @@ def apply(self, _, sdfg: sd.SDFG): if isinstance(nnode, (nodes.EntryNode, nodes.LibraryNode)): nnode.schedule = dtypes.ScheduleType.Sequential - # NOTE: The outputs of LibraryNodes, NestedSDFGs and Map that have GPU schedule must be moved to GPU memory. + # NOTE: The outputs of LibraryNodes and Maps that have GPU schedule must be moved to GPU memory. # TODO: Also use GPU-shared and GPU-register memory when appropriate. for state, node in gpu_nodes: - if isinstance(node, (nodes.LibraryNode, nodes.NestedSDFG)): + if isinstance(node, (nodes.LibraryNode)): for e in state.out_edges(node): dst = state.memlet_path(e)[-1].dst if isinstance(dst, nodes.AccessNode): desc = sdfg.arrays[dst.data] desc.storage = dtypes.StorageType.GPU_Global - if isinstance(node, nodes.EntryNode): + elif isinstance(node, nodes.EntryNode): for e in state.out_edges(state.exit_node(node)): dst = state.memlet_path(e)[-1].dst if isinstance(dst, nodes.AccessNode): desc = sdfg.arrays[dst.data] desc.storage = dtypes.StorageType.GPU_Global + else: + raise RuntimeError( + f"GPU node of unexpected type. Expected `LibraryNode` or `EntryNode`, found {type(node)}.") ####################################################### # Step 5: Collect free tasklets and check for scalars that have to be moved to the GPU diff --git a/tests/numpy/common.py b/tests/numpy/common.py index 5849269c99..c50f40188e 100644 --- a/tests/numpy/common.py +++ b/tests/numpy/common.py @@ -26,7 +26,7 @@ def compare_numpy_output(device=dace.dtypes.DeviceType.CPU, (including errors). `func` will be run once as a dace program, and once using python. - The inputs to the function will be randomly intialized arrays with + The inputs to the function will be randomly initialized arrays with shapes and dtypes according to the argument annotations. Note that this should be used *instead* of the `@dace.program` diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index e6ae4658a1..2a29d70487 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -501,6 +501,41 @@ def test_create_map_scope_write(): sdfg.validate() +def test_create_map_scope_read(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("print_i", {"read"}, {}, "print(read)"), {"read": dace.Memlet("A[i]")}, + {}) + ], + ) + ], + ) + + sdfg = stree.as_sdfg(simplify=False) + sdfg.validate() + + +def test_create_map_scope_hello_world(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:2"))), + children=[tn.TaskletNode(nodes.Tasklet("print_hi", {}, {}, "print('Hello world!')"), {}, {})], + ) + ], + ) + + sdfg = stree.as_sdfg(simplify=False) + sdfg.validate() + + def test_create_map_scope_read_after_write(): stree = tn.ScheduleTreeRoot( name="tester", @@ -586,6 +621,28 @@ def test_create_map_scope_double_memlet(): sdfg.validate() +def test_create_map_scope_write_in_two_tasklets(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={ + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float32, [20]), + }, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("assign_i", {}, {"out"}, "out = i"), {}, {"out": dace.Memlet("A[i]")}), + tn.TaskletNode(nodes.Tasklet("assign_i", {}, {"out"}, "out = i"), {}, {"out": dace.Memlet("B[i]")}), + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + sdfg.validate() + + def test_create_nested_map_scope(): stree = tn.ScheduleTreeRoot( name="tester", @@ -730,6 +787,50 @@ def test_triple_map_nested_if(): sdfg.validate() +def test_triple_map_if_condition_outside(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={ + 'A': data.Array(dace.float64, [60]), + 'tmp': data.Scalar(dace.float64, transient=True), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('assign', {'read'}, {'out'}, 'out = read'), {'read': dace.Memlet('A[1]')}, + {'out': dace.Memlet("tmp[0]")}), + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", "i", sbs.Range.from_string("0:4"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_j", "j", sbs.Range.from_string("0:5"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_k", "k", sbs.Range.from_string("0:3"))), + children=[ + tn.IfScope( + condition=CodeBlock("tmp + 1 > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 1"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], + ), + tn.ElseScope(children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 2"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], ), + ], + ) + ], + ) + ], + ) + ], + ) + + sdfg = stree.as_sdfg(simplify=False) + sdfg.validate() + sdfg.compile() + + def test_create_nested_map_scope_multi_read(): stree = tn.ScheduleTreeRoot( name="tester", @@ -886,6 +987,31 @@ def test_assign_nodes_avoid_duplicate_boundaries(): assert [type(child) for child in stree.children] == [tn.AssignNode, tn.StateBoundaryNode, tn.TaskletNode] +def test_multiple_copy_nodes() -> None: + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + }, + children=[ + tn.CopyNode('A', dace.Memlet("A[0] -> [10]")), + tn.CopyNode('A', dace.Memlet("A[1] -> [11]")), + tn.CopyNode('A', dace.Memlet("A[2] -> [12]")), + ], + ) + + sdfg = stree.as_sdfg() + + states = sdfg.states() + assert len(states) == 1, "expect one state" + + nodes = states[0].nodes() + assert len(nodes) == 4, "expect four access nodes" + for node in nodes: + assert isinstance(node, dace.nodes.AccessNode) + assert node.data == "A" + + if __name__ == '__main__': test_state_boundaries_none() test_state_boundaries_waw() @@ -903,14 +1029,29 @@ def test_assign_nodes_avoid_duplicate_boundaries(): # test_create_state_boundary_empty_memlet() test_create_tasklet_raw() test_create_tasklet_waw() + test_create_tasklet_war() test_create_loop_for() test_create_loop_while() test_create_if_else() + test_create_if_elif_else() test_create_if_without_else() test_create_map_scope_write() + test_create_map_scope_read() + test_create_map_scope_hello_world() + test_create_map_scope_read_after_write() + test_create_map_scope_write_after_read() test_create_map_scope_copy() test_create_map_scope_double_memlet() + test_create_map_scope_write_in_two_tasklets() test_create_nested_map_scope() + test_double_map_with_for_loop() + test_triple_map_flat_if() + test_triple_map_if_condition_outside() test_create_nested_map_scope_multi_read() test_map_with_state_boundary_inside() + test_map_calculate_temporary_in_two_loops() test_edge_assignment_read_after_write() + test_assign_nodes_force_state_transition() + test_assign_nodes_multiple_force_one_transition() + test_assign_nodes_avoid_duplicate_boundaries() + test_multiple_copy_nodes() diff --git a/tests/schedule_tree/treenodes_test.py b/tests/schedule_tree/treenodes_test.py index 31a0abbd21..bb3f05f62f 100644 --- a/tests/schedule_tree/treenodes_test.py +++ b/tests/schedule_tree/treenodes_test.py @@ -1,8 +1,10 @@ # Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. from dace.sdfg.analysis.schedule_tree import treenodes as tn -from dace import nodes +from dace import nodes, data, subsets +from dace import Memlet +import dace import pytest @@ -109,6 +111,35 @@ def test_dataflow_scope_children(DataflowScope: type[tn.DataflowScope], tasklet: assert child.parent == scope +def test_scope_inputs_outputs() -> None: + write_scalar = tn.TaskletNode( + nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), + {}, + {'out': Memlet('scalar[0]')}, + ) + read_scalar = tn.TaskletNode( + nodes.Tasklet('bla2', {'inp'}, {'out'}, 'out = inp + 1'), + {'inp': Memlet('scalar[0]')}, + {'out': Memlet('A[1]')}, + ) + map_scope = tn.MapScope( + node=nodes.MapEntry(nodes.Map('map', ['i'], subsets.Range.from_string("0:20"))), + children=[write_scalar, read_scalar], + ) + + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + 'scalar': data.Scalar(dace.float64) + }, + children=[map_scope], + ) + + assert len(map_scope.input_memlets()) == 0 + assert len(map_scope.output_memlets()) == 2 + + if __name__ == '__main__': test_schedule_tree_scope_children(tn.ScheduleTreeScope, tasklet) test_schedule_tree_scope_children(tn.ControlFlowScope, tasklet)