diff --git a/dace/codegen/CMakeLists.txt b/dace/codegen/CMakeLists.txt index 614f92a029..3c5d1609c3 100644 --- a/dace/codegen/CMakeLists.txt +++ b/dace/codegen/CMakeLists.txt @@ -58,9 +58,11 @@ include_directories(${DACE_RUNTIME_DIR}/include) # Global DaCe external dependencies find_package(Threads REQUIRED) find_package(OpenMP REQUIRED COMPONENTS CXX) +find_package(Python REQUIRED COMPONENTS Development) list(APPEND DACE_LIBS Threads::Threads) list(APPEND DACE_LIBS OpenMP::OpenMP_CXX) +list(APPEND DACE_LIBS Python::Python) add_definitions(-DDACE_BINARY_DIR=\"${CMAKE_BINARY_DIR}\") diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index 1fcd55302b..0deaf13a83 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -62,10 +62,12 @@ def copy_expr( name_override=None, ): data_desc = sdfg.arrays[data_name] + # NOTE: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs? # TODO: Study this when changing Structures to be (optionally?) non-pointers. tokens = data_name.split('.') - if len(tokens) > 1 and tokens[0] in sdfg.arrays and isinstance(sdfg.arrays[tokens[0]], data.Structure): + if (len(tokens) > 1 and tokens[0] in sdfg.arrays and isinstance(sdfg.arrays[tokens[0]], data.Structure) + and not isinstance(sdfg.arrays[tokens[0]], data.PythonClass)): name = data_name.replace('.', '->') else: name = data_name @@ -244,6 +246,8 @@ def ptr(name: str, desc: data.Data, sdfg: SDFG = None, framecode: 'DaCeCodeGener if '.' 
in name: root = name.split('.')[0] + if root in sdfg.arrays and isinstance(sdfg.arrays[root], data.PythonClass): + return pyobject_member_expr(root, name.split('.', 1)[1], desc) if root in sdfg.arrays and isinstance(sdfg.arrays[root], data.Structure): name = name.replace('.', '->') @@ -273,6 +277,14 @@ def ptr(name: str, desc: data.Data, sdfg: SDFG = None, framecode: 'DaCeCodeGener return name +def pyobject_member_expr(root_name: str, attr_path: str, desc: data.Data) -> str: + if isinstance(desc, data.Array): + return f'dace_get_pyobject_attr_ptr<{desc.dtype.ctype}>({root_name}, "{attr_path}")' + if isinstance(desc, data.Scalar): + return f'dace_get_pyobject_attr<{desc.dtype.ctype}>({root_name}, "{attr_path}")' + raise TypeError(f'Unsupported PythonClass member descriptor: {type(desc).__name__}') + + def emit_memlet_reference(dispatcher: 'TargetDispatcher', sdfg: SDFG, memlet: mmlt.Memlet, @@ -371,7 +383,8 @@ def make_const(expr: str) -> str: # NOTE: `expr` may only be a name or a sequence of names and dots. The latter indicates nested data and structures. # NOTE: Since structures are implemented as pointers, we replace dots with arrows. - expr = expr.replace('.', '->') + if 'dace_get_pyobject_attr' not in expr: + expr = expr.replace('.', '->') return (typedef + ref, pointer_name, expr) @@ -543,7 +556,8 @@ def cpp_array_expr(sdfg, # NOTE: Are there any cases where a mix of '.' and '->' is needed when traversing nested structs? # TODO: Study this when changing Structures to be (optionally?) non-pointers. 
tokens = memlet.data.split('.') - if len(tokens) > 1 and tokens[0] in sdfg.arrays and isinstance(sdfg.arrays[tokens[0]], data.Structure): + if (len(tokens) > 1 and tokens[0] in sdfg.arrays and isinstance(sdfg.arrays[tokens[0]], data.Structure) + and not isinstance(sdfg.arrays[tokens[0]], data.PythonClass)): name = memlet.data.replace('.', '->') else: name = memlet.data diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index ce0851c351..6a6e9a8f84 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -19,7 +19,7 @@ dynamic_map_inputs) from dace.sdfg.scope import is_devicelevel_gpu, is_in_scope from dace.sdfg.validation import validate_memlet_data -from typing import TYPE_CHECKING, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union if TYPE_CHECKING: from dace.codegen.targets.framecode import DaCeCodeGenerator @@ -44,10 +44,23 @@ def _visit_structure(struct: data.Structure, args: dict, prefix: str = ''): if isinstance(v, data.Data): args[f'{prefix}->{k}'] = v + def _visit_pythonclass_members(struct: data.Structure, args: dict, root_name: str, prefix: str = ''): + for k, v in struct.members.items(): + member_path = f'{prefix}.{k}' if prefix else k + if isinstance(v, data.Structure): + _visit_pythonclass_members(v, args, root_name, member_path) + elif isinstance(v, data.ContainerArray): + _visit_pythonclass_members(v.stype, args, root_name, member_path) + + if isinstance(v, (data.Array, data.Scalar)): + args[cpp.pyobject_member_expr(root_name, member_path, v)] = v + # Keeps track of generated connectors, so we know how to access them in nested scopes args = dict(arglist) for name, arg_type in arglist.items(): - if isinstance(arg_type, data.Structure): + if isinstance(arg_type, data.PythonClass): + _visit_pythonclass_members(arg_type, args, name) + elif isinstance(arg_type, data.Structure): desc = sdfg.arrays[name] _visit_structure(arg_type, args, name) elif isinstance(arg_type, 
data.ContainerArray): @@ -56,6 +69,15 @@ def _visit_structure(struct: data.Structure, args: dict, prefix: str = ''): if isinstance(desc, data.Structure): _visit_structure(desc, args, name) + for name in sdfg.arrays.keys(): + desc = sdfg.arrays[name] + if '.' not in name or not isinstance(desc, (data.Array, data.Scalar)): + continue + root_name, member_path = name.split('.', 1) + root_desc = sdfg.arrays.get(root_name) + if isinstance(root_desc, data.PythonClass): + args[cpp.pyobject_member_expr(root_name, member_path, desc)] = desc + for name, arg_type in args.items(): if isinstance(arg_type, data.Scalar): # GPU global memory is only accessed via pointers @@ -72,6 +94,8 @@ def _visit_structure(struct: data.Structure, args: dict, prefix: str = ''): self._dispatcher.defined_vars.add(name, DefinedType.StreamArray, arg_type.as_arg(name='')) else: self._dispatcher.defined_vars.add(name, DefinedType.Stream, arg_type.as_arg(name='')) + elif isinstance(arg_type, data.PythonClass): + self._dispatcher.defined_vars.add(name, DefinedType.Object, arg_type.dtype.ctype) elif isinstance(arg_type, data.Structure): self._dispatcher.defined_vars.add(name, DefinedType.Pointer, arg_type.dtype.ctype) else: @@ -705,6 +729,15 @@ def _emit_copy( src_nodedesc = src_node.desc(sdfg) dst_nodedesc = dst_node.desc(sdfg) + if (write and isinstance(dst_nodedesc, data.Scalar) and '.' in dst_node.data + and isinstance(sdfg.arrays[dst_node.data.split('.')[0]], data.PythonClass)): + self._emit_pythonclass_scalar_setter( + dst_node.data, dst_nodedesc.dtype.ctype, + self._pythonclass_scalar_source_expr(sdfg, + cfg.nodes()[state_id], edge, src_node, dst_node), stream, cfg, + state_id, [src_node, dst_node]) + return + if write: vconn = self.ptr(dst_node.data, dst_nodedesc, sdfg) ctype = dst_nodedesc.dtype.ctype @@ -714,6 +747,11 @@ def _emit_copy( # Setting a reference if isinstance(dst_nodedesc, data.Reference) and orig_vconn == 'set': + if '.' 
in dst_node.data and isinstance(sdfg.arrays[dst_node.data.split('.')[0]], data.PythonClass): + self._emit_pythonclass_array_reference_set(sdfg, + cfg.nodes()[state_id], edge, src_node, dst_node, + src_nodedesc, stream, cfg, state_id) + return srcptr = self.ptr(src_node.data, src_nodedesc, sdfg) defined_type, _ = self._dispatcher.defined_vars.get(srcptr) stream.write( @@ -795,6 +833,10 @@ def _emit_copy( state_dfg: SDFGState = cfg.nodes()[state_id] + if (isinstance(dst_nodedesc, data.Reference) and edge.dst_conn == 'set' and '.' in dst_node.data + and isinstance(sdfg.arrays[dst_node.data.split('.')[0]], data.PythonClass)): + return + copy_shape, src_strides, dst_strides, src_expr, dst_expr = cpp.memlet_copy_to_absolute_strides( self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node) @@ -1043,8 +1085,14 @@ def process_out_memlets(self, is_scalar = True # Pointer to pointer assignment is_stream = isinstance(sdfg.arrays[memlet.data], data.Stream) is_refset = isinstance(sdfg.arrays[memlet.data], data.Reference) and dst_edge.dst_conn == 'set' + is_pythonclass_scalar = (is_scalar and '.' in memlet.data + and isinstance(sdfg.arrays[memlet.data], data.Scalar) + and isinstance(sdfg.arrays[memlet.data.split('.')[0]], data.PythonClass)) if (is_scalar and not memlet.dynamic and not is_stream) or is_refset: + if (is_refset and '.' 
in memlet.data + and isinstance(sdfg.arrays[memlet.data.split('.')[0]], data.PythonClass)): + continue out_local_name = " __" + uconn in_local_name = uconn if not locals_defined: @@ -1063,6 +1111,11 @@ def process_out_memlets(self, # which we skip since the memlets are references continue desc = sdfg.arrays[memlet.data] + if is_pythonclass_scalar: + write_expr = self._emit_pythonclass_scalar_setter_expr(memlet.data, desc.dtype.ctype, + in_local_name) + result.write(write_expr, cfg, state_id, node) + continue ptrname = codegen.ptr(memlet.data, desc, sdfg) is_global = desc.lifetime in (dtypes.AllocationLifetime.Global, dtypes.AllocationLifetime.Persistent, @@ -1113,6 +1166,41 @@ def make_ptr_assignment(self, src_expr, src_dtype, dst_expr, dst_dtype, codegen= dst_expr = codegen.make_ptr_vector_cast(dst_expr, dst_dtype, src_dtype, True, DefinedType.Pointer) return f"{dst_expr} = {src_expr};" + def _emit_pythonclass_array_reference_set(self, sdfg: SDFG, state_dfg: SDFGState, + edge: MultiConnectorEdge[mmlt.Memlet], src_node: nodes.AccessNode, + dst_node: nodes.AccessNode, src_nodedesc: data.Data, stream: CodeIOStream, + cfg: ControlFlowRegion, state_id: int) -> None: + copy_shape, src_strides, _, src_expr, _ = cpp.memlet_copy_to_absolute_strides( + self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node) + attr_path = dst_node.data.split('.', 1)[1] + root_name = dst_node.data.split('.', 1)[0] + shape = ', '.join(cpp.sym2cpp(s) for s in copy_shape) + strides = ', '.join(cpp.sym2cpp(s * src_nodedesc.dtype.bytes) for s in src_strides) + stream.write( + f'''{{ + Py_ssize_t __shape[{len(copy_shape)}] = {{{shape}}}; + Py_ssize_t __strides[{len(copy_shape)}] = {{{strides}}}; + dace_set_pyobject_attr_array<{src_nodedesc.dtype.ctype}>({root_name}, "{attr_path}", {src_expr}, __shape, + __strides, {len(copy_shape)}); + }}''', cfg, state_id, [src_node, dst_node]) + + def _pythonclass_scalar_source_expr(self, sdfg: SDFG, state_dfg: SDFGState, edge: 
MultiConnectorEdge[mmlt.Memlet], + src_node: nodes.AccessNode, dst_node: nodes.AccessNode) -> str: + _, _, _, src_expr, _ = cpp.memlet_copy_to_absolute_strides(self._dispatcher, sdfg, state_dfg, edge, src_node, + dst_node) + return f'*({src_expr})' + + def _emit_pythonclass_scalar_setter(self, dst_data: str, ctype: str, value_expr: str, stream: CodeIOStream, + cfg: ControlFlowRegion, state_id: int, + nodes_to_track: Sequence[nodes.Node]) -> None: + stream.write(self._emit_pythonclass_scalar_setter_expr(dst_data, ctype, value_expr), cfg, state_id, + nodes_to_track) + + def _emit_pythonclass_scalar_setter_expr(self, dst_data: str, ctype: str, value_expr: str) -> str: + attr_path = dst_data.split('.', 1)[1] + root_name = dst_data.split('.', 1)[0] + return f'dace_set_pyobject_attr<{ctype}>({root_name}, "{attr_path}", static_cast<{ctype}>({value_expr}));' + def memlet_view_ctor(self, sdfg: SDFG, memlet: mmlt.Memlet, dtype, is_output: bool) -> str: memlet_params = [] @@ -1288,6 +1376,12 @@ def memlet_definition(self, memlet_type = ctypedef result += "{} &{} = {};".format(memlet_type, local_name, expr) defined = DefinedType.Stream + elif var_type == DefinedType.Object: + if output: + result += "{} {};".format(ctypedef, local_name) + else: + result += "{} &{} = {};".format(ctypedef, local_name, expr) + defined = DefinedType.Object else: raise TypeError("Unknown variable type: {}".format(var_type)) diff --git a/dace/config_schema.yml b/dace/config_schema.yml index 7cd8979d7a..8e2cc22a5b 100644 --- a/dace/config_schema.yml +++ b/dace/config_schema.yml @@ -593,6 +593,33 @@ required: Raise all errors out of nested function parsing contexts instead of trying to create a callback implicitly. + raise_statements: + type: str + title: Schedule-tree raise handling + default: support + description: > + Controls how ``raise`` statements are handled by the direct + schedule-tree frontend. 
``support`` lowers directly supported + exception classes to ``RaiseNode`` and falls back to Python + callbacks for unsupported dynamic cases. ``ignore_dynamic`` + keeps directly supported ``RaiseNode`` cases but skips + dynamic raises that would otherwise require a callback. + ``ignore_all`` skips all ``raise`` statements. + + runtime_negative_indices: + type: bool + title: Runtime negative-index wrapping + default: false + description: > + If enabled, array indices whose sign cannot be proven + nonnegative are normalized with symbolic runtime accessors + during frontend lowering. Uncertain scalar element indices + use ``pyindex(ind, size)``, while uncertain slice bounds use + conditional negative-offset normalization so positive bounds + such as ``stop == size`` keep Python semantics. This is + disabled by default; definitely negative cases are still + normalized statically during preprocessing. + verbose_errors: type: bool title: Show preprocessed AST on parsing errors diff --git a/dace/data/__init__.py b/dace/data/__init__.py index b87607c9fc..a9253dafa3 100644 --- a/dace/data/__init__.py +++ b/dace/data/__init__.py @@ -27,6 +27,7 @@ StructureReference, ContainerArrayReference, ) +from dace.data.pydata import PythonClass, PythonDict, PythonList, PythonTuple # Import prod from utils and expose as _prod for backward compatibility from dace.utils import prod as _prod @@ -86,6 +87,10 @@ 'ArrayReference', 'StructureReference', 'ContainerArrayReference', + 'PythonList', + 'PythonTuple', + 'PythonDict', + 'PythonClass', # Tensor support 'TensorIterationTypes', 'TensorAssemblyType', diff --git a/dace/data/core.py b/dace/data/core.py index c19a221b2c..de3239813e 100644 --- a/dace/data/core.py +++ b/dace/data/core.py @@ -11,7 +11,7 @@ import dataclasses from collections import OrderedDict -from typing import Any, Dict, List, Set, Tuple, Union +from typing import Any, ClassVar, Dict, List, Set, Tuple, Type, Union, get_origin, get_type_hints import numpy as np import sympy 
as sp @@ -43,6 +43,81 @@ def _arrays_from_json(obj, context=None): return OrderedDict((k, serialize.from_json(v, context)) for k, v in obj) +def infer_structured_class_members(cls: Type[Any], **overrides) -> Dict[str, Any]: + """Infer typed members for a dataclass-like Python class. + + This helper is intentionally conservative: it uses dataclass fields when + available, otherwise it requires class-level annotations that can be turned + into DaCe data descriptors. + """ + if not isinstance(cls, type): + raise TypeError(f'{cls} is not a class type') + + from dace.data.creation import create_datadescriptor # Avoid import cycle + + members: Dict[str, Any] = {} + if dataclasses.is_dataclass(cls): + for field in dataclasses.fields(cls): + if dataclasses.is_dataclass(field.type): + members[field.name] = Structure.from_dataclass(field.type) + else: + members[field.name] = create_datadescriptor(field.type) + else: + try: + annotations = get_type_hints(cls) + except Exception: + annotations = dict(getattr(cls, '__annotations__', {}) or {}) + + for field_name, annotation in annotations.items(): + if field_name == 'return' or field_name.startswith('__'): + continue + if get_origin(annotation) is ClassVar: + continue + members[field_name] = create_datadescriptor(annotation) + + members.update(overrides) + if not members: + raise TypeError(f'{cls} does not expose a supported typed field layout') + return members + + +def infer_structured_object_members(obj: Any, **overrides) -> Dict[str, Any]: + """Infer typed members for a live Python object instance. + + Class-level typed fields are used when available, then instance attributes + with values that can be converted to DaCe descriptors refine or extend the + result. 
+ """ + if isinstance(obj, type): + raise TypeError(f'{obj} is a class type, not an instance') + + from dace.data.creation import create_datadescriptor # Avoid import cycle + + members: Dict[str, Any] = {} + try: + members.update(infer_structured_class_members(type(obj))) + except TypeError: + pass + + try: + instance_items = vars(obj).items() + except TypeError: + instance_items = () + + for field_name, value in instance_items: + if field_name.startswith('__'): + continue + try: + members[field_name] = create_datadescriptor(value) + except (TypeError, ValueError): + continue + + members.update(overrides) + if not members: + raise TypeError(f'{obj} does not expose a supported typed field layout') + return members + + @make_properties class Data: """ Data type descriptors that can be used as references to memory. @@ -1010,6 +1085,11 @@ def from_dataclass(cls, **overrides) -> 'Structure': :param cls: The dataclass to convert. :param overrides: Optional overrides for the structure fields. :return: A Structure data descriptor. + + The resulting descriptor assumes a fixed field layout that can be + marshalled to code generation as a C struct. Frontends should keep + values on the ``PythonClass`` path instead when code may reassign + non-array fields or create new fields dynamically. """ members = {} for field in dataclasses.fields(cls): @@ -1022,6 +1102,18 @@ def from_dataclass(cls, **overrides) -> 'Structure': members.update(overrides) return Structure(members, name=cls.__name__) + @staticmethod + def from_class(cls, **overrides) -> 'Structure': + """Create a Structure descriptor from a conservatively typed Python class. + + This helper is for classes whose field layout is treated as a marshalled C struct. + If frontend behavior depends on the object remaining a Python reference, + such as non-array field reassignment or dynamic field creation, use + ``dace.data.PythonClass`` instead. 
+ """ + members = infer_structured_class_members(cls, **overrides) + return Structure(members, name=cls.__name__) + @property def total_size(self): return -1 diff --git a/dace/data/creation.py b/dace/data/creation.py index e8eab97916..4514a39975 100644 --- a/dace/data/creation.py +++ b/dace/data/creation.py @@ -18,7 +18,7 @@ ArrayLike = Any from dace import dtypes, symbolic -from dace.data.core import Array, Data, Scalar +from dace.data.core import Array, Data, Scalar, Structure def create_datadescriptor(obj, no_custom_desc=False): @@ -92,6 +92,15 @@ def create_datadescriptor(obj, no_custom_desc=False): shape=interface['shape'], strides=(tuple(s // itemsize for s in interface['strides']) if interface['strides'] else None), storage=storage) + elif not isinstance(obj, type) and hasattr(obj, '__array__'): + try: + return create_datadescriptor(np.asarray(obj), no_custom_desc=True) + except Exception: + pass + elif isinstance(obj, dict): + from dace.data.pydata import infer_python_dict_descriptor_from_value + return infer_python_dict_descriptor_from_value( + obj, lambda value: create_datadescriptor(value, no_custom_desc=no_custom_desc), transient=False) elif isinstance(obj, (list, tuple)): # Lists and tuples are cast to numpy obj = np.array(obj) @@ -122,6 +131,11 @@ def create_datadescriptor(obj, no_custom_desc=False): return Scalar(dtypes.pointer(dtypes.typeclass(None))) elif isinstance(obj, str) or obj is str: return Scalar(dtypes.string) + elif isinstance(obj, type): + try: + return Structure.from_class(obj) + except TypeError: + pass elif callable(obj): # Cannot determine return value/argument types from function object return Scalar(dtypes.callback(None)) diff --git a/dace/data/pydata.py b/dace/data/pydata.py new file mode 100644 index 0000000000..8b47d7d894 --- /dev/null +++ b/dace/data/pydata.py @@ -0,0 +1,300 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +"""Python-native data descriptors. 
+ +These descriptors are primarily used by the direct schedule-tree Python +frontend. They intentionally distinguish between: + +- ``Structure`` for classes whose fully known layout can be marshalled to code + generation as a by-value C struct, and +- ``PythonClass`` for analyzable Python objects that stay on the Python-object + path and are passed by reference, currently through the nanobind object + boundary. + +That distinction is semantic, not cosmetic. Once frontend behavior depends on +Python object identity, such as reassigning a field to a new object or +creating new fields dynamically, the value must remain a ``PythonClass``. +""" +# TODO: These classes are incomplete, they require more support on the SDFG/connector level and in code generation +# (bindings, C++ codegen, etc.). They are currently only used for the Schedule Tree-based Python frontend. + +import copy +from dataclasses import is_dataclass +from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Type + +from dace import dtypes +from dace.data.core import Array, Data, Scalar, Structure, infer_structured_class_members +from dace.properties import NestedDataClassProperty, make_properties + + +def _clone_descriptor(descriptor: Data) -> Data: + return copy.deepcopy(descriptor) + + +def _pyobject_descriptor(*, transient: bool) -> Scalar: + return Scalar(dtypes.pyobject(), transient=transient) + + +def _normalize_descriptor(descriptor: Optional[Any], *, transient: bool) -> Data: + if descriptor is None: + return _pyobject_descriptor(transient=transient) + if isinstance(descriptor, dtypes.typeclass): + descriptor = Scalar(descriptor, transient=transient) + elif not isinstance(descriptor, Data): + raise TypeError(f'Unsupported nested data descriptor type: {type(descriptor)}') + else: + descriptor = _clone_descriptor(descriptor) + descriptor.transient = transient + return descriptor + + +def descriptors_equivalent(left: Data, right: Data) -> bool: + if type(left) is not type(right): + 
return False + try: + return left.is_equivalent(right) + except Exception: + return left == right + + +def merge_python_dict_component_descriptors(descriptors: Iterable[Optional[Data]], *, transient: bool) -> Data: + merged: Optional[Data] = None + for descriptor in descriptors: + if descriptor is None: + return _pyobject_descriptor(transient=transient) + candidate = _normalize_descriptor(descriptor, transient=transient) + if merged is None: + merged = candidate + continue + if not descriptors_equivalent(merged, candidate): + return _pyobject_descriptor(transient=transient) + return merged or _pyobject_descriptor(transient=transient) + + +def infer_python_dict_descriptor_from_value(value: Mapping[Any, Any], + descriptor_factory: Callable[[Any], Data], + *, + transient: bool = False) -> 'PythonDict': + key_descriptors = [] + value_descriptors = [] + for key, mapped_value in value.items(): + try: + key_descriptor = descriptor_factory(key) + except Exception: + key_descriptor = None + try: + value_descriptor = descriptor_factory(mapped_value) + except Exception: + value_descriptor = None + key_descriptors.append(key_descriptor) + value_descriptors.append(value_descriptor) + return PythonDict(merge_python_dict_component_descriptors(key_descriptors, transient=transient), + merge_python_dict_component_descriptors(value_descriptors, transient=transient), + transient=transient) + + +def python_dataclass_descriptor(cls: Type[Any], *, by_value: bool = False, **overrides) -> Data: + """Create the canonical descriptor for a Python dataclass. + + ``by_value=True`` selects ``Structure`` and is intended for classes whose + full field layout is known and safe to marshal as a by-value C struct. + ``by_value=False`` keeps the descriptor on the ``PythonClass`` path, which + preserves analyzable member information while still treating the value as a + Python object reference. 
+ + In particular, frontends should prefer ``PythonClass`` whenever code may + rebind a non-array field to a different object or create new fields at + runtime, because those operations cannot be represented as mutation of a + fixed by-value struct layout. + """ + if by_value: + return Structure.from_dataclass(cls, **overrides) + return PythonClass.from_dataclass(cls, **overrides) + + +@make_properties +class PythonList(Array): + """Represents a native Python list argument.""" + + def __init__(self, dtype: Any = dtypes.pyobject(), shape: Sequence[Any] = (1, ), **kwargs): + super().__init__(dtype=dtype, shape=shape, **kwargs) + + def as_arg(self, with_types: bool = True, for_call: bool = False, name: str = None): + if not with_types or for_call: + return name + return f'nb::list {name}' + + +@make_properties +class PythonTuple(Array): + """Represents a native Python tuple argument.""" + + def __init__(self, dtype: Any = dtypes.pyobject(), shape: Sequence[Any] = (1, ), **kwargs): + super().__init__(dtype=dtype, shape=shape, **kwargs) + + def as_arg(self, with_types: bool = True, for_call: bool = False, name: str = None): + if not with_types or for_call: + return name + return f'nb::tuple {name}' + + +@make_properties +class PythonDict(Data): + """Represents a native Python dictionary with a uniform key and value type. + + This is intentionally a frontend-facing stub descriptor for now: it carries + typed key/value metadata for analysis and schedule-tree lowering, but it is + not a promise of full first-class mapping code generation support. 
+ """ + + key_type = NestedDataClassProperty(default=None, allow_none=True) + value_type = NestedDataClassProperty(default=None, allow_none=True) + + def _transient_setter(self, value): + self._transient = value + if self.key_type is not None: + self.key_type.transient = value + if self.value_type is not None: + self.value_type.transient = value + + def __init__(self, + key_type: Optional[Data] = None, + value_type: Optional[Data] = None, + transient: bool = False, + storage=dtypes.StorageType.Default, + location=None, + lifetime=dtypes.AllocationLifetime.Scope, + debuginfo=None): + self.key_type = _normalize_descriptor(key_type, transient=transient) + self.value_type = _normalize_descriptor(value_type, transient=transient) + super().__init__(dtypes.pyobject(), (1, ), transient, storage, location, lifetime, debuginfo) + + @staticmethod + def from_json(json_obj, context=None): + if json_obj['type'] != 'PythonDict': + raise TypeError('Invalid data type') + + ret = PythonDict() + from dace import serialize + serialize.set_properties_from_json(ret, json_obj, context=context) + return ret + + def __repr__(self): + return f'PythonDict(key_type={self.key_type}, value_type={self.value_type})' + + def clone(self): + return PythonDict(self.key_type, + self.value_type, + transient=self.transient, + storage=self.storage, + location=self.location, + lifetime=self.lifetime, + debuginfo=self.debuginfo) + + @property + def strides(self): + return [1] + + @property + def total_size(self): + return 1 + + @property + def offset(self): + return [0] + + @property + def start_offset(self): + return 0 + + @property + def alignment(self): + return 0 + + @property + def optional(self) -> bool: + return False + + @property + def pool(self) -> bool: + return False + + @property + def may_alias(self) -> bool: + return False + + @property + def free_symbols(self): + result = set() + result |= self.key_type.free_symbols + result |= self.value_type.free_symbols + return result + + def 
is_equivalent(self, other): + return isinstance(other, PythonDict) and descriptors_equivalent( + self.key_type, other.key_type) and descriptors_equivalent(self.value_type, other.value_type) + + def as_arg(self, with_types: bool = True, for_call: bool = False, name: str = None): + if not with_types or for_call: + return name + return f'nb::dict {name}' + + def as_python_arg(self, with_types: bool = True, for_call: bool = False, name: str = None): + if not with_types or for_call: + return name + return f'{name}: dict' + + +@make_properties +class PythonClass(Structure): + """Represents an analyzable Python class with typed fields. + + Unlike ``Structure``, this descriptor keeps the value on the Python-object + path even when member types are known. Code generation therefore treats it + as a referenced Python object rather than a by-value C struct. + + This is the correct representation when the program may rely on Python + object identity, such as scalar or other non-array field rebinding, + descriptor-backed behavior, or dynamic field creation. 
+ """ + + def __init__(self, members, name: str = 'PythonClass', **kwargs): + super().__init__(members=members, name=name, **kwargs) + self.dtype = dtypes.pyobject() + + @staticmethod + def from_json(json_obj, context=None): + """Deserialize PythonClass from JSON, handling both 'PythonClass' and 'Structure' types.""" + if json_obj['type'] not in ('Structure', 'PythonClass'): + raise TypeError("Invalid data type") + + # Create dummy object + ret = PythonClass({}) + from dace import serialize + serialize.set_properties_from_json(ret, json_obj, context=context) + + return ret + + @classmethod + def from_dataclass(cls, dataclass_type: Type[Any], **overrides) -> 'PythonClass': + if not is_dataclass(dataclass_type): + raise TypeError(f'{dataclass_type} is not a dataclass') + return cls.from_class(dataclass_type, **overrides) + + @classmethod + def from_class(cls, class_type: Type[Any], **overrides) -> 'PythonClass': + members = infer_structured_class_members(class_type, **overrides) + return cls(members, name=class_type.__name__) + + def clone(self): + return PythonClass(self.members, + self.name, + transient=self.transient, + storage=self.storage, + location=self.location, + lifetime=self.lifetime, + debuginfo=self.debuginfo) + + def as_python_arg(self, with_types: bool = True, for_call: bool = False, name: str = None): + if not with_types or for_call: + return name + return f'{name}: object' diff --git a/dace/dtypes.py b/dace/dtypes.py index bc4c35cc4b..5d43c22f67 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -298,6 +298,8 @@ def __init__(self, wrapped_type, typename=None): try: if wrapped_type == "bool": wrapped_type = numpy.bool_ + elif wrapped_type == "object": + wrapped_type = numpy.object_ else: wrapped_type = getattr(numpy, wrapped_type) except AttributeError: diff --git a/dace/frontend/common/op_repository.py b/dace/frontend/common/op_repository.py index 74924cba7b..32f4ccaf42 100644 --- a/dace/frontend/common/op_repository.py +++ 
b/dace/frontend/common/op_repository.py @@ -1,9 +1,15 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. import itertools +from numbers import Number from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +import numpy as np + +from dace import symbolic from dace.dtypes import paramdec MethodType = Callable[..., Tuple[str]] +_INFERENCE_MISSING = object() def _get_all_bases(class_or_name: Union[str, Type]) -> List[str]: @@ -33,6 +39,12 @@ class Replacements(object): _ufunc_rep: Dict[str, MethodType] = {} _method_rep: Dict[Tuple[str, str], MethodType] = {} _attr_rep: Dict[Tuple[str, str], MethodType] = {} + _dtype_rep: Dict[str, Callable] = {} # Lightweight descriptor inference (free functions) + _dtype_method_rep: Dict[Tuple[str, str], Callable] = {} # (classname, method) -> fn(self_desc, *a, **kw) + _dtype_method_self_rep: Dict[Tuple[str, str], Callable] = {} # (classname, method) -> fn(self_desc, *a, **kw) for methods that mutate self + _dtype_attr_rep: Dict[Tuple[str, str], Callable] = {} # (classname, attr) -> fn(self_desc) + _dtype_ufunc_rep: Dict[str, Callable] = {} # 'ufunc' or ufunc method name -> fn(input_descs, ufunc_name, *a, **kw) + _dtype_op_rep: Dict[Tuple[Optional[str], Optional[str], str], Callable] = {} # (left category, right category, optype) -> fn(operand descriptors) @staticmethod def get(name: str): @@ -83,6 +95,81 @@ def get_attribute(class_or_name: Union[str, Type], attr_name: str): return Replacements._attr_rep[(classname, attr_name)] return None + @staticmethod + def get_descriptor_inference(name: str): + """Returns a lightweight descriptor-inference function for a named call, or None.""" + return Replacements._dtype_rep.get(name, None) + + @staticmethod + def get_method_descriptor_inference(class_or_name: Union[str, Type], method_name: str): + """Returns a descriptor-inference function for a method call, or None.""" + for classname in _get_all_bases(class_or_name): + if (classname, method_name) in Replacements._dtype_method_rep: + return Replacements._dtype_method_rep[(classname, method_name)] + return
None + + @staticmethod + def get_method_self_descriptor_inference(class_or_name: Union[str, Type], method_name: str): + """Returns a self-mutating inference function for a method call, or None.""" + for classname in _get_all_bases(class_or_name): + if (classname, method_name) in Replacements._dtype_method_self_rep: + return Replacements._dtype_method_self_rep[(classname, method_name)] + return None + + @staticmethod + def get_attribute_descriptor_inference(class_or_name: Union[str, Type], attr_name: str): + """Returns a descriptor-inference function for an attribute access, or None.""" + for classname in _get_all_bases(class_or_name): + if (classname, attr_name) in Replacements._dtype_attr_rep: + return Replacements._dtype_attr_rep[(classname, attr_name)] + return None + + @staticmethod + def get_ufunc_descriptor_inference(ufunc_method: Optional[str] = None): + """Returns a descriptor-inference function for a NumPy ufunc call or method, or None.""" + key = ufunc_method or 'ufunc' + return Replacements._dtype_ufunc_rep.get(key, None) + + @staticmethod + def get_operator_descriptor_inference(optype: str, + left_operand: Any = _INFERENCE_MISSING, + right_operand: Any = _INFERENCE_MISSING): + """Returns a descriptor-inference function for an operator, or None.""" + if left_operand is _INFERENCE_MISSING and right_operand is _INFERENCE_MISSING: + return Replacements._dtype_op_rep.get((None, None, optype), None) + + left_types = _get_inference_operand_types(left_operand) + if right_operand is _INFERENCE_MISSING: + for left_type in left_types: + if (left_type, None, optype) in Replacements._dtype_op_rep: + return Replacements._dtype_op_rep[(left_type, None, optype)] + return Replacements._dtype_op_rep.get((None, None, optype), None) + + right_types = _get_inference_operand_types(right_operand) + for left_type, right_type in itertools.product(left_types, right_types): + if (left_type, right_type, optype) in Replacements._dtype_op_rep: + return Replacements._dtype_op_rep[(left_type,
right_type, optype)] + + return Replacements._dtype_op_rep.get((None, None, optype), None) + + +def _get_inference_operand_types(operand: Any) -> List[Optional[str]]: + if operand is _INFERENCE_MISSING: + return [None] + if isinstance(operand, (bool, np.bool_)): + return ['BoolConstant'] + if isinstance(operand, Number): + return ['NumConstant'] + if symbolic.issymbolic(operand): + return ['symbol'] + if isinstance(operand, list): + return ['ListLiteral'] + if isinstance(operand, tuple): + return ['TupleLiteral'] + if isinstance(operand, str): + return ['StringLiteral'] + return _get_all_bases(type(operand)) + @paramdec def replaces(func: Callable[..., Tuple[str]], name: str): @@ -162,3 +249,116 @@ def replaces_attribute(func: Callable[..., Tuple[str]], classname: str, attr_nam """ Replacements._attr_rep[(classname, attr_name)] = func return func + + +@paramdec +def infers_descriptor(func: Callable, name: str): + """ + Registers a lightweight descriptor-inference function for a named call. + + The function receives ``(input_descriptors, *args, **kwargs)`` where + *input_descriptors* maps array-argument names to their + :class:`dace.data.Data` descriptors and the remaining arguments are + compile-time values (numbers, symbolic expressions, strings, or + ``None`` when static evaluation failed). It may return a single + :class:`dace.data.Data` descriptor, a tuple or list of descriptors + for structured multi-result calls, or ``None`` if inference is not + possible. Empty tuples or lists denote a successful zero-output + inference. + + :param func: The inference function. + :param name: Fully-qualified function name (e.g. ``'numpy.sum'``). + """ + Replacements._dtype_rep[name] = func + return func + + +@paramdec +def infers_method_descriptor(func: Callable, classname: str, method_name: str): + """ + Registers descriptor inference for a method call (e.g. ``a.sum()``).
+ + The function receives ``(self_descriptor, *args, **kwargs)`` where + *self_descriptor* is the :class:`dace.data.Data` descriptor of the + object the method is called on. It may return a single + :class:`dace.data.Data` descriptor, a tuple or list of descriptors, + or ``None``. + + :param func: The inference function. + :param classname: Data-descriptor class name (e.g. ``'Array'``). + :param method_name: Method name (e.g. ``'sum'``). + """ + Replacements._dtype_method_rep[(classname, method_name)] = func + return func + + +@paramdec +def infers_method_self_descriptor(func: Callable, classname: str, method_name: str): + """ + Registers descriptor inference for a method call that mutates ``self``. + + The function receives ``(self_descriptor, *args, **kwargs)`` where + *self_descriptor* is the :class:`dace.data.Data` descriptor of the + object the method is called on. It may return a single + :class:`dace.data.Data` descriptor, a tuple or list of descriptors, + or ``None``. + + :param func: The inference function. + :param classname: Data-descriptor class name (e.g. ``'Array'``). + :param method_name: Method name (e.g. ``'sum'``). + """ + Replacements._dtype_method_self_rep[(classname, method_name)] = func + return func + + +@paramdec +def infers_attribute_descriptor(func: Callable, classname: str, attr_name: str): + """ + Registers descriptor inference for an attribute access (e.g. ``a.T``). + + The function receives ``(self_descriptor,)`` and returns either a + single :class:`dace.data.Data` descriptor, a tuple or list of + descriptors, or ``None``. + + :param func: The inference function. + :param classname: Data-descriptor class name (e.g. ``'Array'``). + :param attr_name: Attribute name (e.g. ``'T'``). + """ + Replacements._dtype_attr_rep[(classname, attr_name)] = func + return func + + +@paramdec +def infers_ufunc_descriptor(func: Callable, name: str): + """ + Registers lightweight descriptor inference for a NumPy ufunc call or ufunc method. 
+ + The function receives ``(input_descriptors, ufunc_name, *args, **kwargs)`` and may return a + single :class:`dace.data.Data` descriptor, a tuple or list of descriptors, or ``None``. + + :param func: The inference function. + :param name: ``'ufunc'`` for a direct ufunc call or the ufunc method name, such as ``'reduce'``. + """ + Replacements._dtype_ufunc_rep[name] = func + return func + + +@paramdec +def infers_operator_descriptor(func: Callable, + optype: str, + classname: Optional[str] = None, + otherclass: Optional[str] = None): + """ + Registers descriptor inference for an operator (e.g. ``A @ B`` or ``-A``). + + The function receives one or more operand descriptors, depending on + the AST operator form being inferred, and returns a + :class:`dace.data.Data` descriptor for the result, or ``None``. + + :param func: The inference function. + :param optype: AST operator name (e.g. ``'MatMult'``). + :param classname: Optional left operand category name. + :param otherclass: Optional right operand category name. 
+ """ + Replacements._dtype_op_rep[(classname, otherclass, optype)] = func + return func diff --git a/dace/frontend/python/astutils.py b/dace/frontend/python/astutils.py index 82206de183..012f3e2280 100644 --- a/dace/frontend/python/astutils.py +++ b/dace/frontend/python/astutils.py @@ -456,6 +456,13 @@ def generic_visit(self, node): setattr(node, field, new_node) return node + if isinstance(node, list): + return [copy_tree(n) for n in node] + if not isinstance(node, ast.AST): + import warnings + warnings.warn(f'copy_tree expected an AST node or list of AST nodes, got {type(node).__name__}', stacklevel=2) + return copy.deepcopy(node) + return Copier().visit(node) diff --git a/dace/frontend/python/common.py b/dace/frontend/python/common.py index 2dc77b48b3..d54000b346 100644 --- a/dace/frontend/python/common.py +++ b/dace/frontend/python/common.py @@ -2,11 +2,14 @@ import ast import collections from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, TYPE_CHECKING, Union from dace import data from dace.sdfg.sdfg import SDFG +if TYPE_CHECKING: + from dace.sdfg.analysis.schedule_tree.treenodes import ScheduleTreeRoot + class DaceSyntaxError(Exception): @@ -58,6 +61,18 @@ def __gt__(self, other) -> bool: return self.value > str(other) +@dataclass(frozen=True) +class ListLiteral: + """A list literal found in a parsed DaCe program.""" + value: Tuple[Any, ...] + + +@dataclass(frozen=True) +class TupleLiteral: + """A tuple literal found in a parsed DaCe program.""" + value: Tuple[Any, ...] + + class SDFGConvertible(object): """ A mixin that defines the interface to annotate SDFG-convertible objects. @@ -119,6 +134,44 @@ def closure_resolver(self, return SDFGClosure() +class ScheduleTreeConvertible: + """ + A mixin that defines the interface to annotate schedule-tree-convertible + objects. 
+ """ + + def __schedule_tree__(self, + *args, + lambda_bindings: Optional[Dict[str, ast.AST]] = None, + callable_bindings: Optional[Dict[str, Any]] = None, + **kwargs) -> 'ScheduleTreeRoot': + """ + Returns a schedule-tree representation of this object. + + :param args: Arguments or argument types that can be used for + specialization. + :param lambda_bindings: Optional lambda specializations propagated from + the caller. + :param callable_bindings: Optional callable specializations propagated + from the caller. + :param kwargs: Keyword arguments or argument types that can be used for + specialization. + :return: A schedule-tree root representing this object. + """ + raise NotImplementedError + + def __schedule_tree_signature__(self) -> Tuple[Sequence[str], Sequence[str]]: + """ + Returns the schedule-tree call signature represented by this object as + a sequence of all argument names that will be found in a call to this + object and a sequence of the constant argument names from the first + sequence. + + :return: A 2-tuple of (all arguments, constant arguments). 
+ """ + raise NotImplementedError + + @dataclass class SDFGClosure: """ diff --git a/dace/frontend/python/memlet_parser.py b/dace/frontend/python/memlet_parser.py index 1be7d0d79b..d7fc9e9f31 100644 --- a/dace/frontend/python/memlet_parser.py +++ b/dace/frontend/python/memlet_parser.py @@ -5,11 +5,13 @@ from dataclasses import dataclass from dace import data, dtypes, subsets +from dace.config import Config from dace.frontend.python import astutils from dace.frontend.python.astutils import rname from dace.memlet import Memlet -from dace.symbolic import pystr_to_symbolic, SymbolicType +from dace.symbolic import IfExpr, SymbolicType, pyindex, pystr_to_symbolic from dace.frontend.python.common import DaceSyntaxError +from sympy.core.relational import Relational MemletType = Union[ast.Call, ast.Attribute, ast.Subscript, ast.Name] @@ -79,10 +81,59 @@ def _parse_dim_atom(das, atom): return result -def _fill_missing_slices(das, ast_ndslice, array, indices): +def _wrap_scalar_index(index_expr, extent): + # Scalar element access uses Python wraparound semantics, which codegen + # implements via pyindex(...)->py_mod(...). + try: + if (index_expr < 0) == True: + return index_expr + extent + except (TypeError, ValueError): + pass + + if not Config.get_bool('frontend', 'runtime_negative_indices'): + return index_expr + + if getattr(index_expr, 'is_Boolean', False): + return index_expr + + is_integer = getattr(index_expr, 'is_integer', None) + if is_integer is False: + return index_expr + + is_nonnegative = getattr(index_expr, 'is_nonnegative', None) + if is_nonnegative is True: + return index_expr + return pyindex(index_expr, extent) + + +def _wrap_slice_bound(index_expr, extent, *, inclusive_stop: bool): + # Slice bounds are different: positive stop=size must remain size rather + # than wrapping to zero, so we normalize only negative values. 
+ try: + if (index_expr < 0) == True: + wrapped = index_expr + extent + else: + wrapped = index_expr + except (TypeError, ValueError): + wrapped = index_expr + + if wrapped is index_expr and Config.get_bool('frontend', 'runtime_negative_indices'): + if not getattr(index_expr, 'is_Boolean', False): + is_integer = getattr(index_expr, 'is_integer', None) + if is_integer is not False: + is_nonnegative = getattr(index_expr, 'is_nonnegative', None) + if is_nonnegative is not True: + wrapped = IfExpr(index_expr < 0, index_expr + extent, index_expr) + + if inclusive_stop: + return wrapped - 1 + return wrapped + + +def _fill_missing_slices(das, ast_ndslice, shape): # Filling ndslice with default values from array dimensions # if ranges not specified (e.g., of the form "A[:]") - ndslice = [None] * len(array.shape) + ndslice = [None] * len(shape) offsets = [] new_axes = [] arrdims: Dict[int, str] = {} @@ -99,20 +150,10 @@ def _fill_missing_slices(das, ast_ndslice, array, indices): dim = ast.Name(id=dim) if isinstance(dim, tuple): - rb = _parse_dim_atom(das, dim[0] or 0) - re = _parse_dim_atom(das, dim[1] or array.shape[indices[idx]]) - 1 + dim_extent = shape[idx] + rb = _wrap_slice_bound(_parse_dim_atom(das, dim[0] or 0), dim_extent, inclusive_stop=False) + re = _wrap_slice_bound(_parse_dim_atom(das, dim[1] or dim_extent), dim_extent, inclusive_stop=True) rs = _parse_dim_atom(das, dim[2] or 1) - # NOTE: try/except for cases where rb/re are not symbols/numbers - try: - if (rb < 0) == True: - rb += array.shape[indices[idx]] - except (TypeError, ValueError): - pass - try: - if (re < 0) == True: - re += array.shape[indices[idx]] - except (TypeError, ValueError): - pass ndslice[idx] = (rb, re, rs) offsets.append(idx) idx += 1 @@ -124,30 +165,48 @@ def _fill_missing_slices(das, ast_ndslice, array, indices): has_ellipsis = True remaining_dims = len(ast_ndslice) - num_new_axes - idx - 1 for j in range(idx, len(ndslice) - remaining_dims): - ndslice[j] = (0, array.shape[j] - 1, 1) + 
ndslice[j] = (0, shape[j] - 1, 1) idx += 1 new_idx += 1 - elif (dim is None or (isinstance(dim, ast.Constant) and dim.value is None)): + elif (dim is None or (isinstance(dim, ast.Constant) and dim.value is None) or inner_eval_ast(das, dim) is None): new_axes.append(new_idx) new_idx += 1 # NOTE: Do not increment idx here elif isinstance(dim, ast.Name) and isinstance(dim.id, (list, tuple)): # List/tuple literal - ndslice[idx] = (0, array.shape[idx] - 1, 1) - arrdims[indices[idx]] = dim.id + ndslice[idx] = (0, shape[idx] - 1, 1) + arrdims[idx] = dim.id idx += 1 new_idx += 1 - elif isinstance(dim, ast.Name) and isinstance(dim.id, slice): + elif isinstance(inner_eval_ast(das, dim), slice): # slice literal - rb, re, rs = dim.id.start, dim.id.stop, dim.id.step + resolved = inner_eval_ast(das, dim) + rb, re, rs = resolved.start, resolved.stop, resolved.step + if rb is None: + rb = 0 + if re is None: + re = shape[idx] + if rs is None: + rs = 1 + + dim_extent = shape[idx] + ndslice[idx] = (_wrap_slice_bound(rb, dim_extent, inclusive_stop=False), + _wrap_slice_bound(re, dim_extent, inclusive_stop=True), rs) + idx += 1 + new_idx += 1 + elif isinstance(dim, ast.Name) and dim.id in das and isinstance(das[dim.id], slice): + # compile-time slice object + rb, re, rs = das[dim.id].start, das[dim.id].stop, das[dim.id].step if rb is None: rb = 0 if re is None: - re = array.shape[indices[idx]] + re = shape[idx] if rs is None: rs = 1 - ndslice[idx] = (rb, re - 1, rs) + dim_extent = shape[idx] + ndslice[idx] = (_wrap_slice_bound(rb, dim_extent, inclusive_stop=False), + _wrap_slice_bound(re, dim_extent, inclusive_stop=True), rs) idx += 1 new_idx += 1 elif (isinstance(dim, ast.Name) and dim.id in das and isinstance(das[dim.id], data.Array)): @@ -157,7 +216,7 @@ def _fill_missing_slices(das, ast_ndslice, array, indices): # Boolean array indexing if len(ast_ndslice) > 1: raise IndexError(f'Invalid indexing into array "{dim.id}". 
Only one boolean array is allowed.') - if tuple(desc.shape) != tuple(array.shape): + if tuple(desc.shape) != tuple(shape): raise IndexError(f'Invalid indexing into array "{dim.id}". ' 'Shape of boolean index must match original array.') elif desc.dtype in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, @@ -170,30 +229,33 @@ def _fill_missing_slices(das, ast_ndslice, array, indices): if data._prod(desc.shape) == 1: # Special case: one-element array treated as scalar - ndslice[idx] = (dim.id, dim.id, 1) + scalar_expr = _wrap_scalar_index(pystr_to_symbolic(dim.id), shape[idx]) + ndslice[idx] = (scalar_expr, scalar_expr, 1) else: - ndslice[idx] = (0, array.shape[idx] - 1, 1) - arrdims[indices[idx]] = dim.id + ndslice[idx] = (0, shape[idx] - 1, 1) + arrdims[idx] = dim.id idx += 1 new_idx += 1 elif (isinstance(dim, ast.Name) and dim.id in das and isinstance(das[dim.id], data.Scalar)): - ndslice[idx] = (dim.id, dim.id, 1) + scalar_expr = _wrap_scalar_index(pystr_to_symbolic(dim.id), shape[idx]) + ndslice[idx] = (scalar_expr, scalar_expr, 1) idx += 1 new_idx += 1 else: r = pyexpr_to_symbolic(das, dim) - if (r < 0) == True: - r += array.shape[indices[idx]] + if getattr(r, 'is_Boolean', False) or getattr(r, 'is_Relational', False) or isinstance(r, Relational): + raise IndexError('Boolean expressions are not supported as scalar memlet indices') + r = _wrap_scalar_index(r, shape[idx]) ndslice[idx] = r idx += 1 new_idx += 1 # Extend slices to unspecified dimensions - for i in range(idx, len(array.shape)): - # ndslice[i] = (0, array.shape[idx] - 1, 1) + for i in range(idx, len(shape)): + # ndslice[i] = (0, shape[idx] - 1, 1) # idx += 1 - ndslice[i] = (0, array.shape[i] - 1, 1) + ndslice[i] = (0, shape[i] - 1, 1) offsets.append(i) return ndslice, offsets, new_axes, arrdims @@ -226,22 +288,26 @@ def parse_memlet_subset(array: data.Data, cnode = node ast_ndslices = astutils.subscript_to_ast_slice_recursive(cnode) offsets = 
list(range(len(array.shape))) + current_shape = list(array.shape) # Loop over nd-slices (A[i][j][k]...) subset_array = [] for idx, ast_ndslice in enumerate(ast_ndslices): - # Cut out dimensions that were indexed in the previous slice - narray = copy.deepcopy(array) - narray.shape = [s for i, s in enumerate(array.shape) if i in offsets] + # Each nested slice indexes into the surviving dimensions of the + # previous slice, not the original array extents. + current_offsets = offsets # Loop over the N dimensions - ndslice, offsets, new_extra_dims, arrdims = _fill_missing_slices(das, ast_ndslice, narray, offsets) + ndslice, local_offsets, new_extra_dims, arrdims = _fill_missing_slices(das, ast_ndslice, current_shape) + local_subset = _ndslice_to_subset(ndslice) + offsets = [current_offsets[i] for i in local_offsets] + current_shape = [local_subset.size()[i] for i in local_offsets] if new_extra_dims and idx != (len(ast_ndslices) - 1): raise NotImplementedError('New axes only implemented for last slice') if arrdims and len(ast_ndslices) != 1: raise NotImplementedError('Array dimensions not implemented for consecutive subscripts') extra_dims = new_extra_dims - subset_array.append(_ndslice_to_subset(ndslice)) + subset_array.append(local_subset) subset = subset_array[0] diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 30fb537247..4606884e2a 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -17,8 +17,8 @@ from dace.config import Config from dace.frontend.common import op_repository as oprepo from dace.frontend.python import astutils -from dace.frontend.python.common import (DaceSyntaxError, SDFGClosure, SDFGConvertible, inverse_dict_lookup, - StringLiteral) +from dace.frontend.python.common import (DaceSyntaxError, ListLiteral, SDFGClosure, SDFGConvertible, TupleLiteral, + inverse_dict_lookup, StringLiteral) from dace.frontend.python.astutils import ExtNodeVisitor, ExtNodeTransformer from
dace.frontend.python.astutils import rname from dace.frontend.python import nested_call, replacements, preprocessing @@ -1378,7 +1378,7 @@ def defined(self): try: from mpi4py import MPI result.update({k: v for k, v in self.globals.items() if isinstance(v, MPI.Comm)}) - except (ImportError, ModuleNotFoundError): + except (ImportError, ModuleNotFoundError, RuntimeError): pass return result @@ -1410,8 +1410,19 @@ def get_target_name(self, output_index: Optional[int] = None, default: Optional[ # return the name of the left-hand side of the assignment. if len(self.current_ast_stack) > 1 and isinstance(self.current_ast_stack[-2], ast.Assign): target = self.current_ast_stack[-2].targets[0] - if isinstance(target, ast.Tuple) and len(target.elts) > output_index: - candidate = self._get_name_from_node(target.elts[output_index]) + candidate = None + if isinstance(target, (ast.Tuple, ast.List)): + flat_targets = [] + pending_targets = list(target.elts) + while pending_targets: + current_target = pending_targets.pop(0) + if isinstance(current_target, (ast.Tuple, ast.List)): + pending_targets = list(current_target.elts) + pending_targets + else: + flat_targets.append(current_target) + + if len(flat_targets) > output_index: + candidate = self._get_name_from_node(flat_targets[output_index]) elif isinstance(target, (ast.Name, ast.Subscript, ast.Attribute)): candidate = self._get_name_from_node(target) @@ -3465,7 +3476,8 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): # Get targets (elts) and results elts = None results = None - if isinstance(node_target, (ast.Tuple, ast.List)): + unpack_target = isinstance(node_target, (ast.Tuple, ast.List)) + if unpack_target: elts = list(node_target.elts) else: elts = [node_target] @@ -3481,11 +3493,27 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): tuple_found = True break + assignment_output_index = 0 + + def _collect_assignment_results(value: ast.AST) -> List[Tuple[str, str]]: + 
nonlocal assignment_output_index + if isinstance(value, (ast.Tuple, ast.List)): + unpacked = [] + for element in value.elts: + unpacked.extend(_collect_assignment_results(element)) + return unpacked + + old_output_index = self.default_output_index + self.default_output_index = assignment_output_index + assignment_output_index += 1 + try: + return self._gettype(value) + finally: + self.default_output_index = old_output_index + results = [] - if isinstance(node.value, (ast.Tuple, ast.List)): - for i, n in enumerate(node.value.elts): - self.default_output_index = i - results.extend(self._gettype(n)) + if unpack_target and isinstance(node.value, (ast.Tuple, ast.List)): + results.extend(_collect_assignment_results(node.value)) self.default_output_index = 0 else: rval = self._gettype(node.value) @@ -3520,6 +3548,13 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): true_array = None visited_target = False + if isinstance(target, ast.Attribute) and name in defined_vars: + root_name = defined_vars[name] + root_desc = defined_arrays.get(root_name) + if isinstance(root_desc, data.PythonClass): + self._emit_pythonclass_attribute_assignment(node, root_name, '.'.join(tokens), result) + continue + if name in defined_vars: # Handle complex object assignment (e.g., A.flat[:]) if isinstance(target, ast.Subscript): # In case of nested subscripts, find the root AST node @@ -4660,6 +4695,88 @@ def parse_target(t: Union[ast.Name, ast.Subscript]): else: return return_names + def _emit_pythonclass_attribute_assignment(self, node: ast.AST, object_name: str, attribute_name: str, + value: Any) -> None: + if not attribute_name: + raise DaceSyntaxError(self, node, 'Expected a PythonClass attribute assignment target') + value_name = None + value_desc = None + member_desc = None + + if isinstance(value, str) and value in self.sdfg.arrays: + value_name = value + value_desc = self.sdfg.arrays[value] + if isinstance(value_desc, data.Array): + member_desc = 
data.Reference.view(value_desc) + elif isinstance(value_desc, data.Scalar): + member_desc = data.Scalar(value_desc.dtype) + elif value in self.sdfg.symbols or symbolic.issymbolic(value): + code_value = str(value) + symdtype = symbolic.symtype(value) if symbolic.issymbolic(value) else self.sdfg.symbols[value] + member_desc = data.Scalar(symdtype) + elif isinstance(value, tuple(dtypes.dtype_to_typeclass().keys())): + code_value = repr(value.item() if hasattr(value, 'item') else value) + member_desc = data.create_datadescriptor(value.item() if hasattr(value, 'item') else value) + else: + raise DaceSyntaxError( + self, node, f'Unsupported PythonClass assignment value "{value}". ' + 'Please assign a scalar or symbol.') + + state = self._add_state(f'pythonclass_attr_{node.lineno}') + self.last_block.set_default_lineinfo(self.current_lineinfo) + + self._ensure_pythonclass_member(object_name, attribute_name, member_desc) + + if isinstance(value_desc, data.Array): + ref_name = f'{object_name}.{attribute_name}' + value_read = state.add_read(value_name, debuginfo=self.current_lineinfo) + ref_write = state.add_write(ref_name, debuginfo=self.current_lineinfo) + ref_edge = state.add_edge(value_read, None, ref_write, 'set', Memlet.from_array(value_name, value_desc)) + ref_edge.data = align_memlet(state, ref_edge, dst=False) + return + + if value_desc is not None: + field_name = f'{object_name}.{attribute_name}' + value_read = state.add_read(value_name, debuginfo=self.current_lineinfo) + field_write = state.add_write(field_name, debuginfo=self.current_lineinfo) + field_edge = state.add_edge(value_read, None, field_write, None, Memlet.from_array(value_name, value_desc)) + field_edge.data = align_memlet(state, field_edge, dst=False) + return + + field_name = f'{object_name}.{attribute_name}' + value_tasklet = state.add_tasklet(f'pythonclass_attr_{node.lineno}', {}, {'__out'}, + f'__out = {code_value}', + language=dtypes.Language.Python, + side_effects=False, + 
debuginfo=self.current_lineinfo) + value_tasklet.add_out_connector('__out', member_desc.dtype, force=True) + + field_write = state.add_write(field_name, debuginfo=self.current_lineinfo) + field_edge = state.add_edge(value_tasklet, '__out', field_write, None, + Memlet.from_array(field_name, member_desc)) + field_edge.data = align_memlet(state, field_edge, dst=False) + + def _ensure_pythonclass_member(self, object_name: str, attribute_name: str, member_value_desc: data.Data) -> None: + root_desc = self.sdfg.arrays[object_name] + if not isinstance(root_desc, data.PythonClass): + raise DaceSyntaxError(self, None, f'Expected PythonClass root for "{object_name}"') + + member_desc = root_desc + tokens = attribute_name.split('.') + for token in tokens[:-1]: + if token not in member_desc.members: + raise DaceSyntaxError(self, None, f'Unknown PythonClass attribute path "{attribute_name}"') + member_desc = member_desc.members[token] + if isinstance(member_desc, data.ContainerArray): + member_desc = member_desc.stype + if not isinstance(member_desc, data.Structure): + raise DaceSyntaxError(self, None, f'Cannot create nested PythonClass field under "{token}"') + + leaf_name = tokens[-1] + if leaf_name not in member_desc.members: + member_desc.members[leaf_name] = copy.deepcopy(member_value_desc) + member_desc.members[leaf_name].transient = False + def _connect_pystate(self, tasklet: nodes.CodeNode, state: SDFGState, @@ -5147,6 +5264,11 @@ def visit_TypeAlias(self, node: TypeAlias): def _gettype(self, opnode: ast.AST) -> List[Tuple[str, str]]: """ Returns an operand and its type as a 2-tuple of strings. 
""" + if isinstance(opnode, ast.List): + return [(ListLiteral(tuple(self.visit(opnode))), ListLiteral)] + if isinstance(opnode, ast.Tuple): + return [(TupleLiteral(tuple(self.visit(opnode))), TupleLiteral)] + if isinstance(opnode, ast.AST): operands = self.visit(opnode) else: @@ -5190,6 +5312,7 @@ def _visit_op(self, node: Union[ast.UnaryOp, ast.BinOp, ast.BoolOp], op1: ast.AS if len(op1_parsed) > 1: raise DaceSyntaxError(self, op1, 'Operand cannot be a tuple') operand1, op1type = op1_parsed[0] + if op2 is not None: op2_parsed = self._gettype(op2) if len(op2_parsed) > 1: diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 3b9325c2f6..87c7a083a0 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -6,16 +6,19 @@ import os import sympy import sys -from typing import Any, Callable, Dict, List, Optional, Set, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Sequence, Tuple, Union from typing import get_origin, get_args import warnings from dace import data, dtypes, hooks, symbolic from dace.config import Config -from dace.frontend.python import (newast, common as pycommon, cached_program, preprocessing) +from dace.frontend.python import (newast, common as pycommon, cached_program, preprocessing, schedule_tree_frontend) from dace.sdfg import SDFG, utils as sdutils from dace.data import create_datadescriptor, Data +if TYPE_CHECKING: + from dace.sdfg.analysis.schedule_tree import treenodes as tn + try: import mpi4py from dace.sdfg.utils import distributed_compile @@ -45,6 +48,27 @@ def _get_cell_contents_or_none(cell): return None +def _collect_annotation_class_globals(annotation: Any) -> Dict[str, type]: + result: Dict[str, type] = {} + + origin = get_origin(annotation) + if origin is not None: + for arg in get_args(annotation): + result.update(_collect_annotation_class_globals(arg)) + return result + + if not isinstance(annotation, type): + return result + + 
try: + data.Structure.from_class(annotation) + except (TypeError, ValueError): + return result + + result[annotation.__name__] = annotation + return result + + def _get_locals_and_globals(f): """ Retrieves a list of local and global variables for the function ``f``. This is used to retrieve variables around and defined before @dace.programs for adding symbols and constants. @@ -75,6 +99,14 @@ def _get_locals_and_globals(f): [_get_cell_contents_or_none(x) for x in annotate_func.__closure__]) }) + # Python 3.10-3.13 do not expose annotation-only local names through + # ``__annotate__``. Recover direct class annotations from resolved + # annotations so schedule-tree lowering can still identify PythonClass + # promotion candidates. + for annotation in getattr(f, '__annotations__', {}).values(): + for name, value in _collect_annotation_class_globals(annotation).items(): + result.setdefault(name, value) + return result @@ -151,7 +183,7 @@ def infer_symbols_from_datadescriptor(sdfg: SDFG, return {str(k)[8:]: v for k, v in result.items()} -class DaceProgram(pycommon.SDFGConvertible): +class DaceProgram(pycommon.SDFGConvertible, pycommon.ScheduleTreeConvertible): """ A data-centric program object, obtained by decorating a function with ``@dace.program``. 
""" @@ -170,11 +202,14 @@ def __init__(self, use_explicit_cf: bool = True, ignore_type_hints: bool = False): + signature_source = f.__func__ if method and inspect.ismethod(f) and getattr(f, '__self__', + None) is not None else f + self.f = f self.dec_args = args self.dec_kwargs = kwargs self.resolve_functions = constant_functions - self.argnames = _get_argnames(f) + self.argnames = _get_argnames(signature_source) if method: self.objname = self.argnames[0] self.argnames = self.argnames[1:] @@ -192,7 +227,7 @@ def __init__(self, self.ignore_type_hints = ignore_type_hints self.global_vars = _get_locals_and_globals(f) - self.signature = inspect.signature(f) + self.signature = inspect.signature(signature_source) self.default_args = { pname: pval.default for pname, pval in self.signature.parameters.items() if not _is_empty(pval.default) @@ -217,6 +252,9 @@ def __init__(self, # Cache SDFGs with last used arguments self._cache = cached_program.DaceProgramCache(self._eval_closure) + self._schedule_tree_cache: Dict[cached_program.ProgramCacheKey, + 'tn.ScheduleTreeRoot'] = (cached_program.LimitedSizeDict( + size_limit=self._cache.size)) # These sets fill up after the first parsing of the program and stay # the same unless the argument types change self.closure_array_keys: Set[str] = set() @@ -237,6 +275,11 @@ def __deepcopy__(self, memo): setattr(result, k, copy.deepcopy(v, memo)) return result + def _reject_async_program(self) -> None: + if inspect.iscoroutinefunction(self.f) or inspect.isasyncgenfunction(self.f): + raise SyntaxError('Async @dace.program functions are unsupported. ' + 'Use a synchronous @dace.program and call async helpers as callbacks.') + def auto_optimize(self, sdfg: SDFG, symbols: Dict[str, int] = None) -> SDFG: """ Invoke automatic optimization heuristics on internal program. 
""" # Avoid import loop @@ -288,6 +331,7 @@ def to_sdfg(self, *args, simplify=None, save=False, validate=False, use_cache=Fa if self._cache.has(cachekey): entry = self._cache.get(cachekey) + self._run_parallel_schedule_tree_lowering_checks(args, kwargs, entry.sdfg) return entry.sdfg sdfg = self._parse(args, kwargs, simplify=simplify, save=save, validate=validate) @@ -298,6 +342,48 @@ def to_sdfg(self, *args, simplify=None, save=False, validate=False, use_cache=Fa return sdfg + def to_schedule_tree(self, *args, use_cache: bool = False, **kwargs) -> 'tn.ScheduleTreeRoot': + """ + Creates a schedule tree directly from the DaCe Python frontend. + + :param args: JIT argument examples. + :param kwargs: JIT keyword argument examples. + :param use_cache: If True, reuses a cached schedule tree when possible. + :return: A schedule-tree root object. + """ + self._reject_async_program() + self.global_vars = _get_locals_and_globals(self.f) + + if self.methodobj is not None: + self.global_vars[self.objname] = self.methodobj + + argtypes, _, constant_args, specified = self._get_type_annotations(args, kwargs) + self.global_vars.update(constant_args) + + cachekey = None + if use_cache: + cachekey = self._cache.make_key(argtypes, specified, self.closure_array_keys, self.closure_constant_keys, + constant_args) + if cachekey in self._schedule_tree_cache: + return copy.deepcopy(self._schedule_tree_cache[cachekey]) + + stree = self._generate_schedule_tree(args, kwargs) + + if use_cache: + self._schedule_tree_cache[cachekey] = copy.deepcopy(stree) + + return stree + + def __schedule_tree__(self, + *args, + lambda_bindings: Optional[Dict[str, ast.Lambda]] = None, + callable_bindings: Optional[Dict[str, Any]] = None, + **kwargs) -> 'tn.ScheduleTreeRoot': + return self._generate_schedule_tree(tuple(args), + dict(kwargs), + lambda_bindings=lambda_bindings, + callable_bindings=callable_bindings) + def __sdfg__(self, *args, **kwargs) -> SDFG: return self._parse(args, kwargs, simplify=None, 
save=False, validate=False) @@ -334,6 +420,9 @@ def name(self) -> str: def __sdfg_signature__(self) -> Tuple[Sequence[str], Sequence[str]]: return self.argnames, self.constant_args + def __schedule_tree_signature__(self) -> Tuple[Sequence[str], Sequence[str]]: + return self.__sdfg_signature__() + def __sdfg_closure__(self, reevaluate: Optional[Dict[str, str]] = None) -> Dict[str, Any]: """ Returns the closure arrays of the SDFG represented by the dace @@ -405,6 +494,8 @@ def _eval_closure(self, arg: str, extra_constants: Optional[Dict[str, Any]] = No return eval(arg, self.global_vars, extra_constants) def _create_sdfg_args(self, sdfg: SDFG, args: Tuple[Any], kwargs: Dict[str, Any]) -> Dict[str, Any]: + annotated_argtypes, _, _, _ = self._get_type_annotations(args, kwargs) + # Start with default arguments, then add other arguments result = {**self.default_args} # Reconstruct keyword arguments @@ -417,10 +508,23 @@ def _create_sdfg_args(self, sdfg: SDFG, args: Tuple[Any], kwargs: Dict[str, Any] # Update closure with respect to callback mapping result.update({k: result[v] for k, v in sdfg.callback_mapping.items()}) + def _try_create_datadescriptor(key: str, val: Any) -> data.Data: + """ + Tries to create a data descriptor from the given argument. If this fails but the argument has a type + annotation, uses the annotation to create the data descriptor instead. This allows users to pass in + ``PythonClass`` arguments. + """ + try: + return create_datadescriptor(val) + except TypeError: + if key in annotated_argtypes: + return annotated_argtypes[key] + raise + # Update arguments with symbols in data shapes result.update( infer_symbols_from_datadescriptor(sdfg, { - k: create_datadescriptor(v) + k: _try_create_datadescriptor(k, v) for k, v in result.items() if k not in self.constant_args })) return result @@ -502,6 +606,8 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF :return: The generated SDFG object. 
""" + self._reject_async_program() + # Obtain DaCe program as SDFG sdfg, cached = self._generate_pdp(args, kwargs, simplify=simplify) @@ -537,6 +643,36 @@ def _evaluate_annotation(self, ann): # Evaluating arbitrary code - anything can happen. Good luck. return dtypes.compiletime + def _resolved_schedule_tree_arg_annotations(self) -> Dict[str, Any]: + result: Dict[str, Any] = {} + + for aname, sig_arg in self.signature.parameters.items(): + if self.objname is not None and aname == self.objname: + continue + + ann = sig_arg.annotation + if self.ignore_type_hints or _is_empty(ann) or ann is dtypes.compiletime: + continue + + try: + if get_origin(ann) is Union: + hint_args = get_args(ann) + if len(hint_args) == 1: + ann = hint_args[0] + elif len(hint_args) == 2 and (hint_args[0] is type(None) or hint_args[1] is type(None)): + ann = hint_args[1] if hint_args[0] is type(None) else hint_args[0] + + ann = self._evaluate_annotation(ann) + except (TypeError, ValueError): + continue + + if ann is dtypes.compiletime: + continue + + result[aname] = ann + + return result + def _get_type_annotations( self, given_args: Tuple[Any], given_kwargs: Dict[str, Any]) -> Tuple[ArgTypes, Dict[str, Any], Dict[str, Any], Set[str]]: @@ -824,6 +960,203 @@ def get_program_hash(self, *args, **kwargs) -> cached_program.ProgramCacheKey: _, key = self._load_sdfg(None, *args, **kwargs) return key + def _generate_schedule_tree(self, + args: Tuple[Any], + kwargs: Dict[str, Any], + *, + lambda_bindings: Optional[Dict[str, ast.Lambda]] = None, + callable_bindings: Optional[Dict[str, Any]] = None, + update_program_state: bool = True) -> 'tn.ScheduleTreeRoot': + """Generates a schedule tree directly from the preprocessed frontend AST.""" + dace_func = self.f + + argtypes, _, gvars, specified = self._get_type_annotations(args, kwargs) + runtime_args = self._bind_schedule_tree_arguments(args, kwargs) + + if self.methodobj is not None: + self.global_vars[self.objname] = self.methodobj + + for name, 
descriptor in argtypes.items(): + if isinstance(descriptor, data.View): + argtypes[name] = descriptor.as_array() + argtypes[name].transient = False + else: + descriptor_copy = copy.deepcopy(descriptor) + if descriptor_copy.transient: + descriptor_copy.transient = False + argtypes[name] = descriptor_copy + + global_vars = copy.copy(self.global_vars) + + removed_args = set() + for name, descriptor in argtypes.items(): + if descriptor.dtype.type is None: + global_vars[name] = None + removed_args.add(name) + + modules = {k: v.__name__ for k, v in global_vars.items() if dtypes.ismodule(v)} + modules['builtins'] = '' + + global_vars.update({v.name: v for _, v in global_vars.items() if isinstance(v, symbolic.symbol)}) + + unspecified_default_args = {k: v for k, v in self.default_args.items() if k not in specified} + removed_args.update(unspecified_default_args) + gvars.update(unspecified_default_args) + + global_vars.update(gvars) + + argtypes = {k: v for k, v in argtypes.items() if k not in removed_args} + for argtype in argtypes.values(): + global_vars.update({v.name: v for v in argtype.free_symbols}) + + parsed_ast, closure = preprocessing.preprocess_dace_program( + dace_func, + argtypes, + global_vars, + modules, + resolve_functions=self.resolve_functions, + default_args=unspecified_default_args.keys(), + normalize_generic_for_loops=True, + preserve_object_attributes=True, + prefer_resolved_object_attributes=({self.objname} + if self.methodobj is not None and self.objname is not None else None), + disallowed_stmts=set(), + preserve_raises=True, + preserve_fstrings=True, + preserve_uninlinable_context_managers=True, + preserve_call_expansions=True) + parsed_ast.resolved_arg_annotations = self._resolved_schedule_tree_arg_annotations() + + if update_program_state: + self.closure_arg_mapping = {k: v for k, (_, _, v, _) in closure.closure_arrays.items()} + self.closure_array_keys = set(closure.closure_arrays.keys()) - removed_args + self.closure_constant_keys = 
set(closure.closure_constants.keys()) - removed_args + self.resolver = closure + + constants: Dict[str, Tuple[Data, Any]] = {} + for name, value in closure.closure_constants.items(): + if name in removed_args: + continue + try: + descriptor = create_datadescriptor(value) + except (TypeError, ValueError): + descriptor = None + if descriptor is not None: + constants[name] = (descriptor, value) + + callback_mapping = {name: original_name for name, (original_name, _, _) in closure.callbacks.items()} + arg_names = [name for name in self.argnames if name in argtypes] + + seeded_callable_bindings = dict(callable_bindings or {}) + for name, value in runtime_args.items(): + if name in removed_args or name in seeded_callable_bindings: + continue + if callable(value): + seeded_callable_bindings[name] = value + + seed_bindings = None + if self.methodobj is not None and self.objname is not None: + from dace.data.core import infer_structured_object_members + from dace.data.pydata import PythonClass + + try: + self_descriptor = PythonClass(infer_structured_object_members(self.methodobj), + name=type(self.methodobj).__name__) + except (TypeError, ValueError): + # Keep bound-method self available to the schedule-tree frontend + # even when the object has no currently inferable typed members. 
+ self_descriptor = PythonClass({}, name=type(self.methodobj).__name__) + + if self_descriptor is not None: + seed_bindings = { + self.objname: schedule_tree_frontend._Binding(descriptor=self_descriptor, kind='container') + } + + stree = schedule_tree_frontend.build_schedule_tree(self.name, + parsed_ast, + argtypes, + constants=constants, + callback_mapping=callback_mapping, + arg_names=arg_names, + lambda_bindings=lambda_bindings, + callable_bindings=seeded_callable_bindings, + seed_bindings=seed_bindings) + + for name, (_, descriptor, _, _) in closure.closure_arrays.items(): + if name in removed_args or name in stree.containers: + continue + stree.containers[name] = copy.deepcopy(descriptor) + for free_symbol in descriptor.free_symbols: + stree.symbols.setdefault(free_symbol.name, free_symbol) + + return stree + + def _run_parallel_schedule_tree_lowering_checks(self, args: Tuple[Any], kwargs: Dict[str, Any], sdfg: SDFG) -> None: + stree = self._generate_schedule_tree(args, kwargs, update_program_state=False) + self._check_schedule_tree_parallel_lowering(stree, sdfg) + + def _check_schedule_tree_parallel_lowering(self, stree: 'tn.ScheduleTreeRoot', sdfg: SDFG) -> None: + from dace.data.pydata import PythonClass + from dace.sdfg.analysis.schedule_tree import treenodes as tn + + statement_nodes: List[tn.StatementNode] = [] + refset_nodes: List[tn.RefSetNode] = [] + pythonclass_names: List[str] = [] + + for node in stree.preorder_traversal(): + if isinstance(node, tn.StatementNode): + statement_nodes.append(node) + elif isinstance(node, tn.RefSetNode): + refset_nodes.append(node) + + for name, descriptor in stree.containers.items(): + if isinstance(descriptor, PythonClass): + pythonclass_names.append(name) + + if statement_nodes: + examples = ', '.join(repr(node.code.as_string) for node in statement_nodes[:3]) + raise RuntimeError(f'Schedule-tree parallel lowering failed for {self.name}: ' + f'generated {len(statement_nodes)} StatementNode(s); examples: {examples}') 
+ + if refset_nodes: + if not self._sdfg_contains_reference_descriptors(sdfg): + targets = ', '.join(sorted({node.target for node in refset_nodes})[:5]) + warnings.warn( + 'Schedule-tree parallel lowering failed for ' + f'{self.name}: generated RefSetNode(s) for {targets}, ' + 'but the SDFG contains no reference descriptors', + UserWarning, + stacklevel=4) + + for node in refset_nodes: + source_text = node.source_expr + if source_text is None and node.memlet is not None: + source_text = str(node.memlet) + if source_text is None: + source_text = type(node.src_desc).__name__ + warnings.warn( + 'Schedule-tree parallel lowering warning for ' + f'{self.name}: RefSetNode target "{node.target}" from {source_text}', + UserWarning, + stacklevel=4) + + for name in pythonclass_names: + warnings.warn(f'Schedule-tree parallel lowering warning for {self.name}: PythonClass container "{name}"', + UserWarning, + stacklevel=4) + + def _sdfg_contains_reference_descriptors(self, sdfg: SDFG) -> bool: + return any( + isinstance(descriptor, data.Reference) + for _, _, descriptor in sdfg.arrays_recursive(include_nested_data=True)) + + def _bind_schedule_tree_arguments(self, args: Tuple[Any], kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Return a parameter-to-value map for direct schedule-tree specialization.""" + filtered_kwargs = {k: v for k, v in kwargs.items() if k not in self.symbols} + parameters = [p for p in self.signature.parameters.values() if self.objname is None or p.name != self.objname] + bound = inspect.Signature(parameters).bind_partial(*args, **filtered_kwargs) + return dict(bound.arguments) + def _generate_pdp(self, args: Tuple[Any], kwargs: Dict[str, Any], @@ -950,4 +1283,6 @@ def _generate_pdp(self, sdfg.regenerate_code = self.regenerate_code sdfg._recompile = self.recompile + self._run_parallel_schedule_tree_lowering_checks(args, kwargs, sdfg) + return sdfg, cached diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py index 
5f6a387b6b..2c8dd1a0b7 100644 --- a/dace/frontend/python/preprocessing.py +++ b/dace/frontend/python/preprocessing.py @@ -1,13 +1,16 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. import ast +import asyncio import collections import copy from dataclasses import dataclass +import functools import inspect import numbers import numpy import re import sympy +import threading import warnings from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union @@ -44,6 +47,21 @@ class PreprocessedAST: src: str preprocessed_ast: ast.AST program_globals: Dict[str, Any] + resolved_arg_annotations: Optional[Dict[str, Any]] = None + + +TypeAlias = getattr(ast, 'TypeAlias', type(None)) + + +def __dace_iterator_init(iterable): + return iterable.__iter__() + + +def __dace_iterator_next(iterator): + try: + return (True, iterator.__next__()) + except StopIteration: + return (False, None) class StructTransformer(ast.NodeTransformer): @@ -107,6 +125,94 @@ def visit_Attribute(self, node): return self.generic_visit(node) +class TypeAliasResolver(ast.NodeTransformer): + """Resolve compile-time-only ``type`` aliases inside function bodies.""" + + class _AnnotationRewriter(ast.NodeTransformer): + + def __init__(self, aliases: Dict[str, ast.AST]) -> None: + self.aliases = aliases + + def visit_Name(self, node: ast.Name) -> ast.AST: + if isinstance(node.ctx, ast.Load) and node.id in self.aliases: + return ast.copy_location(astutils.copy_tree(self.aliases[node.id]), node) + return node + + def __init__(self, filename: str) -> None: + super().__init__() + self.filename = filename + self._alias_scopes: List[Dict[str, ast.AST]] = [dict()] + self._visitor = collections.namedtuple('Visitor', 'filename') + self._visitor.filename = filename + + def _current_aliases(self) -> Dict[str, ast.AST]: + return self._alias_scopes[-1] + + def _rewrite_annotation(self, node: Optional[ast.AST]) -> Optional[ast.AST]: + if node is None: + return None + rewritten = 
self._AnnotationRewriter(self._current_aliases()).visit(astutils.copy_tree(node)) + return ast.fix_missing_locations(ast.copy_location(rewritten, node)) + + def _visit_body(self, body: List[ast.AST]) -> List[ast.AST]: + new_body: List[ast.AST] = [] + for stmt in body: + if isinstance(stmt, TypeAlias): + self._bind_type_alias(stmt) + continue + + visited = self.visit(stmt) + if visited is None: + continue + if isinstance(visited, list): + new_body.extend(value for value in visited if value is not None) + else: + new_body.append(visited) + return new_body + + def _bind_type_alias(self, node: TypeAlias) -> None: + if getattr(node, 'type_params', None): + raise DaceSyntaxError(self._visitor, node, + 'Generic type aliases are unsupported in @dace.program preprocessing') + + if not isinstance(getattr(node, 'name', None), ast.Name): + return + + self._current_aliases()[node.name.id] = self._rewrite_annotation(node.value) + + def visit_Module(self, node: ast.Module) -> ast.Module: + node.body = self._visit_body(node.body) + return node + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: + node.decorator_list = [self.visit(decorator) for decorator in node.decorator_list] + node.args = self.visit(node.args) + node.returns = self._rewrite_annotation(node.returns) + if hasattr(node, 'type_params'): + node.type_params = [] + + self._alias_scopes.append(dict(self._current_aliases())) + try: + node.body = self._visit_body(node.body) + finally: + self._alias_scopes.pop() + return node + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef: + return self.visit_FunctionDef(node) + + def visit_arg(self, node: ast.arg) -> ast.arg: + node.annotation = self._rewrite_annotation(node.annotation) + return node + + def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign: + node.annotation = self._rewrite_annotation(node.annotation) + node.target = self.visit(node.target) + if node.value is not None: + node.value = 
self.visit(node.value) + return node + + class RewriteSympyEquality(ast.NodeTransformer): """ Replaces symbolic equality checks by ``sympy.{Eq,Ne}``. @@ -143,9 +249,10 @@ class ConditionalCodeResolver(ast.NodeTransformer): Replaces if conditions by their bodies if can be evaluated at compile time. """ - def __init__(self, globals: Dict[str, Any]): + def __init__(self, globals: Dict[str, Any], preserve_raises: bool = False): super().__init__() self.globals_and_locals = copy.copy(globals) + self.preserve_raises = preserve_raises def visit_Name(self, node: ast.Name): if isinstance(node.ctx, ast.Store): @@ -186,6 +293,11 @@ def visit_If(self, node: ast.If) -> Any: def visit_IfExp(self, node: ast.IfExp) -> Any: return self.visit_If(node) + def visit_Raise(self, node: ast.Raise) -> Any: + if self.preserve_raises: + return self.generic_visit(node) + return node + class _FindBreakContinueStmts(ast.NodeVisitor): """ @@ -221,7 +333,14 @@ def visit_Continue(self, node): class DeadCodeEliminator(ast.NodeTransformer): """ Removes any code within scope after return/break/continue/raise. 
""" + def __init__(self, preserve_raises: bool = False): + super().__init__() + self.preserve_raises = preserve_raises + def generic_visit(self, node: ast.AST): + terminators = (ast.Return, ast.Break, ast.Continue) + if not self.preserve_raises: + terminators = terminators + (ast.Raise, ) for field, old_value in ast.iter_fields(node): if isinstance(old_value, list): # Scope fields @@ -236,7 +355,7 @@ def generic_visit(self, node: ast.AST): elif not isinstance(value, ast.AST): new_values.extend(value) continue - elif (scope_field and isinstance(value, (ast.Return, ast.Break, ast.Continue, ast.Raise))): + elif (scope_field and isinstance(value, terminators)): # Any AST node after this one is unreachable and # not parsed by this transformer new_values.append(value) @@ -383,9 +502,37 @@ def flatten_callback(func: Callable, node: ast.Call, global_vars: Dict[str, Any] # Filter arguments from AST poscount = len(node.args) + def _wrap_async_callback(callback: Callable) -> Callable: + if not inspect.iscoroutinefunction(func): + return callback + + @functools.wraps(callback) + def _wrapped(*all_args): + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(callback(*all_args)) + + holder: Dict[str, Any] = {} + + def _runner() -> None: + try: + holder['result'] = asyncio.run(callback(*all_args)) + except BaseException as ex: + holder['error'] = ex + + worker = threading.Thread(target=_runner) + worker.start() + worker.join() + if 'error' in holder: + raise holder['error'] + return holder.get('result') + + return _wrapped + # Nothing to do, early exit if not node.keywords and not instructions_exist: - return func + return _wrap_async_callback(func) keywords = [kw.arg for kw in node.keywords] @@ -423,21 +570,34 @@ def cb_func(*all_args): return cb_func - return make_cb(keywords, poscount, unflatten_instructions) + return _wrap_async_callback(make_cb(keywords, poscount, unflatten_instructions)) class GlobalResolver(astutils.ExtNodeTransformer, 
astutils.ASTHelperMixin): """ Resolves global constants and lambda expressions if not already defined in the given scope. """ - def __init__(self, globals: Dict[str, Any], resolve_functions: bool = False, default_args: Set[str] = None): + def __init__(self, + globals: Dict[str, Any], + resolve_functions: bool = False, + default_args: Set[str] = None, + preserve_object_attributes: bool = False, + prefer_resolved_object_attributes: Optional[Set[str]] = None, + preserve_raises: bool = False, + preserve_fstrings: bool = False): self._globals = globals self.resolve_functions = resolve_functions self.default_args = default_args or set() + self.preserve_object_attributes = preserve_object_attributes + self.prefer_resolved_object_attributes = prefer_resolved_object_attributes or set() + self.preserve_raises = preserve_raises + self.preserve_fstrings = preserve_fstrings self.current_scope = set() self.toplevel_function = True self.do_not_detect_callables = False self.ignore_node_ctx = False + self._declared_globals: List[Set[str]] = [] + self._declared_nonlocals: List[Set[str]] = [] self.closure = SDFGClosure() @@ -445,6 +605,83 @@ def __init__(self, globals: Dict[str, Any], resolve_functions: bool = False, def def globals(self): return {k: v for k, v in self._globals.items() if k not in self.current_scope} + def _contains_preserved_attribute_access(self, node: ast.AST) -> bool: + return any( + isinstance(child, ast.Attribute) and self._should_preserve_attribute_access(child) + for child in ast.walk(node)) + + def _should_preserve_attribute_access(self, node: ast.Attribute) -> bool: + if not self.preserve_object_attributes: + return False + + try: + base_value = astutils.evalnode(node.value, self.globals) + except Exception: + return False + + if self._is_native_attribute_base(base_value): + return False + + prefer_resolved_base = (isinstance(node.value, ast.Name) + and node.value.id in self.prefer_resolved_object_attributes) + + # User objects should remain attribute 
accesses in the preprocessed AST. + # The schedule-tree frontend can then decide whether to keep direct + # attribute syntax or rewrite it into explicit special-method calls. + preserve_direct_attribute = True + + try: + static_attr = inspect.getattr_static(base_value, node.attr) + except AttributeError: + static_attr = None + + if static_attr is not None and self._is_descriptor(static_attr): + if isinstance(node.ctx, ast.Load) and hasattr(static_attr, '__get__'): + return True + if isinstance(node.ctx, (ast.Store, ast.Del)) and (hasattr(static_attr, '__set__') + or hasattr(static_attr, '__delete__')): + return True + + objtype = type(base_value) + if isinstance(node.ctx, ast.Load): + if '__getattr__' in objtype.__dict__: + return True + getattribute = objtype.__dict__.get('__getattribute__') + if getattribute is not None and getattribute is not object.__getattribute__: + return True + if isinstance(node.ctx, (ast.Store, ast.Del)): + setattr_method = objtype.__dict__.get('__setattr__') + if setattr_method is not None and setattr_method is not object.__setattr__: + return True + + if prefer_resolved_base and isinstance(node.ctx, ast.Load): + try: + attribute_value = getattr(base_value, node.attr) + except Exception: + return True + + try: + data.create_datadescriptor(attribute_value) + except (TypeError, ValueError): + return True + + return False + + return preserve_direct_attribute + + def _is_descriptor(self, value: Any) -> bool: + return any(hasattr(value, attr) for attr in ('__get__', '__set__', '__delete__')) + + def _is_native_attribute_base(self, value: Any) -> bool: + if dtypes.ismodule(value): + return True + if isinstance(value, + (dtypes.typeclass, symbolic.symbol, sympy.Basic, data.Data, SDFG, numpy.ndarray, numpy.generic)): + return True + + module_name = getattr(type(value), '__module__', '') + return module_name.startswith(('numpy', 'dace', 'sympy', 'builtins')) + def generic_visit(self, node: ast.AST): if hasattr(node, 'body') or hasattr(node, 
'orelse'): oldscope = self.current_scope @@ -502,9 +739,10 @@ def global_value_to_node(self, newnode = ast.parse(symbolic.symstr(value)).body[0].value elif isinstance(value, ast.Name): newnode = ast.Name(id=value.id, ctx=ast.Load()) - elif (dtypes.isconstant(value) or isinstance(value, (StringLiteral, SDFG)) or hasattr(value, '__sdfg__')): - # Could be a constant, an SDFG, or SDFG-convertible object - if isinstance(value, SDFG) or hasattr(value, '__sdfg__'): + elif (dtypes.isconstant(value) or isinstance(value, (StringLiteral, SDFG)) or hasattr(value, '__sdfg__') + or hasattr(value, '__schedule_tree__')): + # Could be a constant, an SDFG, or frontend-convertible object + if isinstance(value, SDFG) or hasattr(value, '__sdfg__') or hasattr(value, '__schedule_tree__'): self.closure.closure_sdfgs[id(value)] = (qualname, value) elif isinstance(value, StringLiteral): value = value.value @@ -523,7 +761,8 @@ def global_value_to_node(self, newnode = astutils.create_constant(value) newnode.qualname = qualname - elif detect_callables and hasattr(value, '__call__') and hasattr(value.__call__, '__sdfg__'): + elif detect_callables and hasattr(value, '__call__') and (hasattr(value.__call__, '__sdfg__') + or hasattr(value.__call__, '__schedule_tree__')): return self.global_value_to_node(value.__call__, parent_node, qualname, recurse, detect_callables) elif dtypes.is_array(value): # Arrays need to be stored as a new name and fed as an argument @@ -589,6 +828,9 @@ def global_value_to_node(self, if isinstance(parent_node, ast.Call): newnode.oldnode = parent_node.func + if inspect.iscoroutinefunction(value): + return newnode + # Decorated or functions with missing source code sast, _, _, _ = astutils.function_to_ast(value) if len(sast.body[0].decorator_list) > 0: @@ -627,31 +869,59 @@ def global_value_to_node(self, else: return newnode + def _current_declared_globals(self) -> Set[str]: + if not self._declared_globals: + return set() + return self._declared_globals[-1] + + def 
_current_declared_nonlocals(self) -> Set[str]: + if not self._declared_nonlocals: + return set() + return self._declared_nonlocals[-1] + + def _is_declared_global(self, name: str) -> bool: + return bool(self._declared_globals) and name in self._declared_globals[-1] + + def _is_declared_nonlocal(self, name: str) -> bool: + return bool(self._declared_nonlocals) and name in self._declared_nonlocals[-1] + def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: # Skip the top function definition (handled outside of the resolver) if self.toplevel_function: self.toplevel_function = False node.decorator_list = [] # Skip decorators - return self.generic_visit(node) + self._declared_globals.append(set()) + self._declared_nonlocals.append(set()) + try: + return self.generic_visit(node) + finally: + self._declared_globals.pop() + self._declared_nonlocals.pop() - for arg in ast.walk(node.args): - if isinstance(arg, ast.arg): - # Skip unspecified default arguments - if arg.arg in self.default_args: - continue + self._declared_globals.append(set()) + self._declared_nonlocals.append(set()) + try: + for arg in ast.walk(node.args): + if isinstance(arg, ast.arg): + # Skip unspecified default arguments + if arg.arg in self.default_args: + continue - # Skip ``dace.compiletime``-annotated arguments - is_constant = False - if arg.annotation is not None: - try: - ann = astutils.evalnode(arg.annotation, self.globals) - if ann is dace.compiletime: - is_constant = True - except SyntaxError: - pass - if not is_constant: - self.current_scope.add(arg.arg) - return self.generic_visit(node) + # Skip ``dace.compiletime``-annotated arguments + is_constant = False + if arg.annotation is not None: + try: + ann = astutils.evalnode(arg.annotation, self.globals) + if ann is dace.compiletime: + is_constant = True + except SyntaxError: + pass + if not is_constant: + self.current_scope.add(arg.arg) + return self.generic_visit(node) + finally: + self._declared_globals.pop() + 
self._declared_nonlocals.pop() def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: return self.visit_FunctionDef(node) @@ -659,6 +929,18 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any: def visit_Lambda(self, node: ast.Lambda) -> Any: return self.visit_FunctionDef(node) + def visit_Global(self, node: ast.Global) -> Any: + self._current_declared_globals().update(node.names) + for name in node.names: + self.current_scope.discard(name) + return node + + def visit_Nonlocal(self, node: ast.Nonlocal) -> Any: + self._current_declared_nonlocals().update(node.names) + for name in node.names: + self.current_scope.discard(name) + return node + def visit_AugAssign(self, node: ast.AugAssign): # Node target in augassign is ast.Store, even though it is updating an existing value oldvalue = self.ignore_node_ctx @@ -671,9 +953,10 @@ def visit_AugAssign(self, node: ast.AugAssign): def visit_Name(self, node: ast.Name): if not self.ignore_node_ctx and isinstance(node.ctx, ast.Store): - self.current_scope.add(node.id) + if not self._is_declared_global(node.id) and not self._is_declared_nonlocal(node.id): + self.current_scope.add(node.id) else: - if node.id in self.current_scope: + if not self._is_declared_global(node.id) and node.id in self.current_scope: return node if node.id in self.globals: global_val = self.globals[node.id] @@ -689,6 +972,11 @@ def visit_keyword(self, node: ast.keyword): return self.generic_visit(node) def _visit_potential_constant(self, node: ast.AST, recurse_on_fail: bool) -> Optional[ast.AST]: + if self._contains_preserved_attribute_access(node): + if recurse_on_fail: + return self.generic_visit(node) + return node + # Try to evaluate the expression with only the globals try: global_val = astutils.evalnode(node, self.globals) @@ -842,6 +1130,8 @@ def visit_Assert(self, node: ast.Assert) -> Any: return None def visit_Raise(self, node: ast.Raise) -> Any: + if self.preserve_raises: + return self.generic_visit(node) 
warnings.warn(f'Runtime exception at line {node.lineno} is not supported and will be skipped.') return None @@ -850,6 +1140,8 @@ def visit_JoinedStr(self, node: ast.JoinedStr) -> Any: global_val = astutils.evalnode(node, self.globals) return ast.copy_location(ast.Constant(kind='', value=global_val), node) except SyntaxError: + if self.preserve_fstrings: + return self.generic_visit(node) warnings.warn(f'f-string at line {node.lineno} could not ' 'be fully evaluated in DaCe program, converting to ' 'partially-evaluated string.') @@ -871,7 +1163,12 @@ class ContextManagerInliner(ast.NodeTransformer, astutils.ASTHelperMixin): a return statement, or top-level break/continue statements. """ - def __init__(self, globals: Dict[str, Any], filename: str, closure_resolver: GlobalResolver) -> None: + def __init__(self, + globals: Dict[str, Any], + filename: str, + closure_resolver: GlobalResolver, + *, + preserve_uninlinable_context_managers: bool = False) -> None: super().__init__() self.with_statements: List[ast.With] = [] self.context_managers: Dict[ast.With, List[Tuple[str, Any]]] = {} @@ -879,6 +1176,7 @@ def __init__(self, globals: Dict[str, Any], filename: str, closure_resolver: Glo self.filename = filename self.resolver = closure_resolver self.names: Set[str] = set() + self.preserve_uninlinable_context_managers = preserve_uninlinable_context_managers def _visit_node_with_body(self, node): node = self.generic_visit_filtered(node, {'body'}) @@ -976,7 +1274,13 @@ def visit_With(self, node: ast.With): ifnode = ast.copy_location(ifnode, node) # Make enter calls - entries = self._add_entries(node) + try: + entries = self._add_entries(node) + except ValueError: + if self.preserve_uninlinable_context_managers: + self.with_statements.pop() + return node + raise ifnode.body = entries # Visit body @@ -1221,6 +1525,343 @@ def visit_AsyncFor(self, node) -> Any: return self.visit_For(node) +class IteratorForLoopNormalizer(ast.NodeTransformer): + """ + Rewrites non-range/map 
for-loops into simpler control-flow that the direct + schedule-tree frontend can lower. Array-like iteration, zip, and enumerate + are normalized to index-based loops; remaining iterators fall back to an + explicit iterator protocol while-loop. + """ + + def __init__(self, globals: Dict[str, Any], argtypes: Dict[str, data.Data], closure_resolver: GlobalResolver): + super().__init__() + self.globals = globals + self.argtypes = argtypes + self.resolver = closure_resolver + self._counter = 0 + + def visit_For(self, node: ast.For) -> Any: + node = self.generic_visit(node) + + if self._is_structured_iterator(node.iter): + return node + + rewritten = self._normalize_indexed_iteration(node) + if rewritten is not None: + return rewritten + + rewritten = self._normalize_zip_iteration(node) + if rewritten is not None: + return rewritten + + rewritten = self._normalize_enumerate_iteration(node) + if rewritten is not None: + return rewritten + + return self._normalize_generic_iteration(node) + + def _is_structured_iterator(self, iterator: ast.AST) -> bool: + schedule_target = iterator.left if isinstance(iterator, ast.BinOp) and isinstance(iterator.op, + ast.MatMult) else iterator + if isinstance(schedule_target, ast.Call): + return astutils.rname(schedule_target.func) in {'range', 'prange', 'parrange'} + if isinstance(schedule_target, ast.Subscript): + return astutils.rname(schedule_target.value) == 'dace.map' + return False + + def _normalize_indexed_iteration(self, node: ast.For) -> Optional[ast.For]: + length_expr = self._indexed_iterator_length(node.iter) + if length_expr is None: + return None + + index_name = self._fresh_name('iter_idx') + yielded_value = self._indexed_iterator_value(node.iter, index_name, node) + if yielded_value is None: + return None + replacements = self._target_replacements(node.target, yielded_value) + if replacements is None: + return None + + rewritten = ast.For(target=ast.Name(id=index_name, ctx=ast.Store()), + 
iter=self._make_range_call(length_expr), + body=self._rewrite_body(node.body, replacements), + orelse=[astutils.copy_tree(stmt) for stmt in node.orelse]) + return ast.fix_missing_locations(ast.copy_location(rewritten, node)) + + def _normalize_zip_iteration(self, node: ast.For) -> Optional[Any]: + if not isinstance(node.iter, ast.Call) or astutils.rname(node.iter.func) != 'zip' or len(node.iter.args) == 0: + return None + + return self._normalize_generic_zip_iteration(node) + + def _normalize_enumerate_iteration(self, node: ast.For) -> Optional[Any]: + if not isinstance(node.iter, ast.Call) or astutils.rname(node.iter.func) != 'enumerate' or len( + node.iter.args) == 0: + return None + + iterable = node.iter.args[0] + start = astutils.copy_tree(node.iter.args[1]) if len(node.iter.args) > 1 else astutils.create_constant(0, node) + + return self._normalize_generic_iteration(node, enumerate_start=start) + + def _normalize_generic_zip_iteration(self, node: ast.For) -> Any: + iterator_names = [self._fresh_name('iter') for _ in node.iter.args] + next_names = [self._fresh_name('iter_next') for _ in node.iter.args] + has_next_names = [self._fresh_name('iter_has_next') for _ in node.iter.args] + value_names = [self._fresh_name('iter_value') for _ in node.iter.args] + + init_nodes: List[ast.AST] = [] + for iterator_name, next_name, has_next_name, value_name, arg in zip(iterator_names, next_names, has_next_names, + value_names, node.iter.args): + init_nodes.append(self._assign(iterator_name, self._helper_call('__dace_iterator_init', [arg]), node)) + init_nodes.extend(self._iterator_next_sequence(iterator_name, next_name, has_next_name, value_name, node)) + + yielded_value = ast.Tuple(elts=[ast.Name(id=value_name, ctx=ast.Load()) for value_name in value_names], + ctx=ast.Load()) + replacements = self._target_replacements(node.target, yielded_value) + destructuring_setup = None + if replacements is None: + destructuring_setup = self._destructuring_setup(node.target, 
yielded_value, node) + if destructuring_setup is None: + return node + replacements = {} + + test = ast.BoolOp(op=ast.And(), + values=[ast.Name(id=has_next_name, ctx=ast.Load()) for has_next_name in has_next_names]) + body: List[ast.AST] = [] + if destructuring_setup is not None: + body.append(destructuring_setup) + body.extend(self._rewrite_body(node.body, replacements)) + for iterator_name, next_name, has_next_name, value_name in zip(iterator_names, next_names, has_next_names, + value_names): + body.extend(self._iterator_next_sequence(iterator_name, next_name, has_next_name, value_name, node)) + + loop = ast.While(test=test, body=body, orelse=[astutils.copy_tree(stmt) for stmt in node.orelse]) + return [*init_nodes, ast.fix_missing_locations(ast.copy_location(loop, node))] + + def _normalize_generic_iteration(self, node: ast.For, enumerate_start: Optional[ast.AST] = None) -> Any: + iterator_name = self._fresh_name('iter') + next_name = self._fresh_name('iter_next') + has_next_name = self._fresh_name('iter_has_next') + value_name = self._fresh_name('iter_value') + + init_nodes: List[ast.AST] = [ + self._assign(iterator_name, self._helper_call('__dace_iterator_init', [node.iter]), node) + ] + init_nodes.extend(self._iterator_next_sequence(iterator_name, next_name, has_next_name, value_name, node)) + + counter_name: Optional[str] = None + if enumerate_start is not None: + counter_name = self._fresh_name('iter_index') + init_nodes.append(self._assign(counter_name, enumerate_start, node)) + yielded_value: ast.AST = ast.Tuple( + elts=[ast.Name(id=counter_name, ctx=ast.Load()), + ast.Name(id=value_name, ctx=ast.Load())], + ctx=ast.Load()) + else: + yielded_value = ast.Name(id=value_name, ctx=ast.Load()) + + target_setup: Optional[ast.Assign] = None + if self._requires_explicit_target_binding(node.target, node.body): + replacements = {} + target_setup = self._binding_setup(node.target, + yielded_value, + node, + annotation=self._target_binding_annotation(node.target, 
node.body)) + if target_setup is None: + return node + else: + replacements = self._target_replacements(node.target, yielded_value) + if replacements is None: + target_setup = self._destructuring_setup(node.target, yielded_value, node) + if target_setup is None: + return node + replacements = {} + + body: List[ast.AST] = [] + if target_setup is not None: + body.append(target_setup) + body.extend(self._rewrite_body(node.body, replacements)) + if counter_name is not None: + body.append( + self._assign( + counter_name, + ast.BinOp(left=ast.Name(id=counter_name, ctx=ast.Load()), + op=ast.Add(), + right=astutils.create_constant(1, node)), node)) + body.extend(self._iterator_next_sequence(iterator_name, next_name, has_next_name, value_name, node)) + + test = ast.Name(id=has_next_name, ctx=ast.Load()) + loop = ast.While(test=test, body=body, orelse=[astutils.copy_tree(stmt) for stmt in node.orelse]) + return [*init_nodes, ast.fix_missing_locations(ast.copy_location(loop, node))] + + def _is_indexable_expr(self, node: ast.AST) -> bool: + if isinstance(node, ast.Name) and node.id in self.argtypes: + descriptor = self.argtypes[node.id] + return hasattr(descriptor, 'shape') and not isinstance(descriptor, data.Scalar) + try: + value = astutils.evalnode(node, self.globals) + except SyntaxError: + return False + return dtypes.is_array(value) or (hasattr(value, '__len__') and hasattr(value, '__getitem__')) + + def _indexed_iterator_length(self, iterator: ast.AST) -> Optional[ast.AST]: + if self._is_indexable_expr(iterator): + return self._make_len_call(iterator) + + if isinstance(iterator, ast.Call): + call_name = astutils.rname(iterator.func) + if call_name == 'zip' and iterator.args: + lengths = [self._indexed_iterator_length(arg) for arg in iterator.args] + if any(length is None for length in lengths): + return None + return self._make_min_call(lengths) + if call_name == 'enumerate' and iterator.args: + return self._indexed_iterator_length(iterator.args[0]) + + return None + + 
def _indexed_iterator_value(self, iterator: ast.AST, index_name: str, location: ast.AST) -> Optional[ast.AST]:
    """
    Build the expression a loop over ``iterator`` yields at position ``index_name``.

    Supported shapes:

    * an expression accepted by ``self._is_indexable_expr`` -> ``iterator[index_name]``
    * ``zip(a, b, ...)`` -> a tuple of the per-argument values
    * ``enumerate(a[, start])`` -> ``(start + index_name, <value of a>)``

    :param iterator: The ``iter`` expression of the for-loop being rewritten.
    :param index_name: Name of the synthetic index variable.
    :param location: Node passed to ``astutils.create_constant`` when a
                     default ``start`` of 0 has to be synthesized.
    :return: The yielded-value expression, or None if ``iterator`` (or one of
             its nested arguments) cannot be expressed through indexing.
    """
    if self._is_indexable_expr(iterator):
        return self._make_subscript(iterator, index_name)

    if not isinstance(iterator, ast.Call):
        return None

    call_name = astutils.rname(iterator.func)

    if call_name == 'zip' and iterator.args:
        values = [self._indexed_iterator_value(arg, index_name, location) for arg in iterator.args]
        if any(value is None for value in values):
            return None
        return ast.Tuple(elts=values, ctx=ast.Load())

    if call_name == 'enumerate' and iterator.args:
        inner_value = self._indexed_iterator_value(iterator.args[0], index_name, location)
        if inner_value is None:
            return None

        # ``start`` may be given positionally or by keyword (``enumerate(x, start=1)``).
        # Reading only the positional slot would silently restart the counter at 0.
        start: Optional[ast.AST] = None
        if len(iterator.args) > 1:
            start = astutils.copy_tree(iterator.args[1])
        else:
            for keyword in iterator.keywords:
                if keyword.arg == 'start':
                    start = astutils.copy_tree(keyword.value)
                    break
        if start is None:
            start = astutils.create_constant(0, location)

        counter = ast.BinOp(left=start, op=ast.Add(), right=ast.Name(id=index_name, ctx=ast.Load()))
        return ast.Tuple(elts=[counter, inner_value], ctx=ast.Load())

    return None


def _fresh_name(self, prefix: str) -> str:
    """
    Return a new ``__dace_<prefix>_<n>`` identifier and advance the counter.

    Names are unique per normalizer instance because ``self._counter`` is
    incremented on every call.
    """
    name = f'__dace_{prefix}_{self._counter}'
    self._counter += 1
    return name


def _make_range_call(self, stop: ast.AST) -> ast.Call:
    """ Build the AST for ``range(stop)``. ``stop`` is embedded as-is (not copied). """
    return ast.Call(func=ast.Name(id='range', ctx=ast.Load()), args=[stop], keywords=[])


def _make_len_call(self, value: ast.AST) -> ast.Call:
    """ Build the AST for ``len(value)``, using a copy of ``value`` made by ``astutils.copy_tree``. """
    return ast.Call(func=ast.Name(id='len', ctx=ast.Load()), args=[astutils.copy_tree(value)], keywords=[])


def _make_min_call(self, values: List[ast.AST]) -> ast.AST:
    """
    Build an expression evaluating to the minimum of ``values``.

    :param values: Non-empty list of expressions (e.g., the ``len(...)`` calls
                   produced for each argument of a ``zip``).
    :return: ``values[0]`` itself if there is exactly one value, otherwise the
             AST for ``min(v0, v1, ...)``.
    """
    # ``min(x)`` with one positional argument treats ``x`` as an iterable, so
    # ``min(len(a))`` (from ``zip(a)``) would raise TypeError at run time.
    if len(values) == 1:
        return values[0]
    return ast.Call(func=ast.Name(id='min', ctx=ast.Load()), args=values, keywords=[])


def _make_subscript(self, value: ast.AST, index_name: str) -> ast.Subscript:
    """ Build the AST for ``value[index_name]`` in load context, using a copy of ``value``. """
    return ast.Subscript(value=astutils.copy_tree(value),
                         slice=ast.Name(id=index_name, ctx=ast.Load()),
                         ctx=ast.Load())


def _helper_call(self, helper_name: str, args: List[ast.AST]) -> ast.Call:
    """
    Build the AST for ``helper_name(*args)``.

    Every argument is copied with ``astutils.copy_tree`` so the generated call
    does not share nodes with the original loop.
    """
    return ast.Call(func=ast.Name(id=helper_name, ctx=ast.Load()),
                    args=[astutils.copy_tree(arg) for arg in args],
                    keywords=[])
+ def _iterator_next_sequence(self, iterator_name: str, next_name: str, has_next_name: str, value_name: str, + location: ast.AST) -> List[ast.Assign]: + next_expr = ast.Name(id=next_name, ctx=ast.Load()) + return [ + self._assign(next_name, + self._helper_call('__dace_iterator_next', [ast.Name(id=iterator_name, ctx=ast.Load())]), + location), + self._assign_target( + ast.Name(id=has_next_name, ctx=ast.Store()), + ast.Subscript(value=astutils.copy_tree(next_expr), + slice=astutils.create_constant(0, location), + ctx=ast.Load()), location), + self._assign_target( + ast.Name(id=value_name, ctx=ast.Store()), + ast.Subscript(value=astutils.copy_tree(next_expr), + slice=astutils.create_constant(1, location), + ctx=ast.Load()), location) + ] + + def _assign(self, target_name: str, value: ast.AST, location: ast.AST) -> ast.Assign: + return ast.fix_missing_locations( + ast.copy_location(ast.Assign(targets=[ast.Name(id=target_name, ctx=ast.Store())], value=value), location)) + + def _assign_target(self, target: ast.AST, value: ast.AST, location: ast.AST) -> ast.Assign: + return ast.fix_missing_locations(ast.copy_location(ast.Assign(targets=[target], value=value), location)) + + def _target_replacements(self, target: ast.AST, value: ast.AST) -> Optional[Dict[str, ast.AST]]: + result: Dict[str, ast.AST] = {} + + def _collect(current_target: ast.AST, current_value: ast.AST) -> bool: + if isinstance(current_target, ast.Name): + result[current_target.id] = current_value + return True + + if isinstance(current_target, (ast.Tuple, ast.List)): + if not isinstance(current_value, (ast.Tuple, ast.List)): + return False + if len(current_target.elts) != len(current_value.elts): + return False + return all( + _collect(sub_target, sub_value) + for sub_target, sub_value in zip(current_target.elts, current_value.elts)) + + return False + + if not _collect(target, value): + return None + return result + + def _rewrite_body(self, body: List[ast.AST], replacements: Dict[str, ast.AST]) -> 
List[ast.AST]: + rewritten: List[ast.AST] = [] + for stmt in body: + copied = astutils.copy_tree(stmt) + replace = astutils.ASTFindReplace({name: astutils.copy_tree(value) for name, value in replacements.items()}) + rewritten.append(ast.fix_missing_locations(replace.visit(copied))) + return rewritten + + def _requires_explicit_target_binding(self, target: ast.AST, body: List[ast.stmt]) -> bool: + return (isinstance(target, ast.Name) and any( + isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.target.id == target.id + for stmt in body)) + + def _target_binding_annotation(self, target: ast.AST, body: List[ast.stmt]) -> Optional[ast.AST]: + if not isinstance(target, ast.Name): + return None + for stmt in body: + if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.target.id == target.id: + return astutils.copy_tree(stmt.annotation) + return None + + def _binding_setup(self, + target: ast.AST, + value: ast.AST, + location: ast.AST, + annotation: Optional[ast.AST] = None) -> Optional[ast.stmt]: + if isinstance(target, ast.Name) and annotation is not None: + return ast.fix_missing_locations( + ast.copy_location( + ast.AnnAssign(target=astutils.copy_tree(target), + annotation=astutils.copy_tree(annotation), + value=astutils.copy_tree(value), + simple=1), location)) + if isinstance(target, ast.Name): + return ast.fix_missing_locations( + ast.copy_location(ast.Assign(targets=[astutils.copy_tree(target)], value=astutils.copy_tree(value)), + location)) + return self._destructuring_setup(target, value, location) + + def _destructuring_setup(self, target: ast.AST, value: ast.AST, location: ast.AST) -> Optional[ast.Assign]: + if not isinstance(target, (ast.Tuple, ast.List)): + return None + setup = ast.Assign(targets=[astutils.copy_tree(target)], value=astutils.copy_tree(value)) + return ast.fix_missing_locations(ast.copy_location(setup, location)) + + class ExpressionInliner(ast.NodeTransformer): """ Replaces 
dace.inline() expressions by their bodies if they can be @@ -1312,7 +1953,7 @@ def _eval_args(self, node: ast.Call) -> Dict[str, Any]: return res - def _get_given_args(self, node: ast.Call, function: 'DaceProgram') -> Set[str]: + def _get_given_args(self, node: ast.Call, function) -> Set[str]: """ Returns a set of names of the given arguments from the positional and keyword arguments """ from dace.frontend.python.parser import DaceProgram # Avoid import loop @@ -1443,10 +2084,11 @@ class DisallowedAssignmentChecker(ast.NodeVisitor): ``DaceSyntaxError`` exception if one is found. """ - def __init__(self, filename: str) -> None: + def __init__(self, filename: str, preserve_call_expansions: bool = False) -> None: super().__init__() self.visitor = collections.namedtuple('Visitor', 'filename') self.visitor.filename = filename + self.preserve_call_expansions = preserve_call_expansions def _check_assignment_target(self, node: ast.expr, parent_node: ast.AST): if hasattr(node, 'qualname'): @@ -1472,10 +2114,304 @@ def visit_NamedExpr(self, node): self.generic_visit(node) def visit_Call(self, node: ast.Call): - if any(k.arg is None for k in node.keywords): + if any(k.arg is None for k in node.keywords) and not self.preserve_call_expansions: raise DaceSyntaxError( self.visitor, node, 'Double-starred (dictionary unpacking, e.g., `**a`) arguments are ' 'currently unsupported.') + self.generic_visit(node) + + +class NamedExprDesugarer(ast.NodeTransformer): + """Lifts walrus operator (NamedExpr / :=) assignments out of expressions. 
+ + ``if (x := f()): body`` becomes:: + + x = f() + if x: body + + ``while (x := f()): body`` becomes:: + + x = f() + while x: + body + x = f() + """ + + def _extract_named_exprs(self, node: ast.AST): + """Find NamedExpr nodes in an expression and return (assignments, rewritten_expr).""" + assignments = [] + + class _Replacer(ast.NodeTransformer): + + def visit_NamedExpr(self, ne: ast.NamedExpr) -> ast.AST: + # Recurse into the value first + ne.value = self.visit(ne.value) + assign = ast.Assign(targets=[astutils.copy_tree(ne.target)], value=ne.value) + ast.copy_location(assign, ne) + assignments.append(assign) + replacement = ast.Name(id=ne.target.id, ctx=ast.Load()) + return ast.copy_location(replacement, ne) + + rewritten = _Replacer().visit(astutils.copy_tree(node)) + return assignments, rewritten + + def _has_named_expr(self, node: ast.AST) -> bool: + for child in ast.walk(node): + if isinstance(child, ast.NamedExpr): + return True + return False + + def visit_If(self, node: ast.If) -> ast.AST: + self.generic_visit(node) + if not self._has_named_expr(node.test): + return node + assignments, new_test = self._extract_named_exprs(node.test) + node.test = new_test + ast.fix_missing_locations(node) + return assignments + [node] + + def visit_While(self, node: ast.While) -> ast.AST: + self.generic_visit(node) + if not self._has_named_expr(node.test): + return node + assignments, new_test = self._extract_named_exprs(node.test) + node.test = new_test + # Add re-evaluation at end of loop body + for assign in assignments: + node.body.append(astutils.copy_tree(assign)) + ast.fix_missing_locations(node) + return assignments + [node] + + def visit_Assign(self, node: ast.Assign) -> ast.AST: + self.generic_visit(node) + if not self._has_named_expr(node.value): + return node + assignments, new_value = self._extract_named_exprs(node.value) + node.value = new_value + ast.fix_missing_locations(node) + return assignments + [node] + + def visit_Expr(self, node: ast.Expr) -> 
ast.AST: + self.generic_visit(node) + if not self._has_named_expr(node.value): + return node + assignments, new_value = self._extract_named_exprs(node.value) + node.value = new_value + ast.fix_missing_locations(node) + return assignments + [node] + + def visit_Return(self, node: ast.Return) -> ast.AST: + self.generic_visit(node) + if node.value is None or not self._has_named_expr(node.value): + return node + assignments, new_value = self._extract_named_exprs(node.value) + node.value = new_value + ast.fix_missing_locations(node) + return assignments + [node] + + +class ComprehensionDesugarer(ast.NodeTransformer): + """Desugars all comprehensions and generator expressions to explicit loops. + + ``[expr for x in iter if cond]`` becomes:: + + __comp_tmp_N = [] + for x in iter: + if cond: + __comp_tmp_N.append(expr) + + Set and dict comprehensions are handled similarly. + Generator expressions consumed by a call (e.g. ``sum(x for x in ...)``) + are desugared to list comprehensions then wrapped in the call. 
+ """ + + def __init__(self): + self._counter = 0 + + def _fresh_name(self) -> str: + self._counter += 1 + return f'__comp_tmp_{self._counter}' + + def _build_loop_nest(self, generators, body_stmt, target_node) -> list: + """Build nested for/if statements from comprehension generators.""" + stmts = body_stmt + # Build inside-out + for gen in reversed(generators): + # Wrap with if-filters + for if_clause in reversed(gen.ifs): + if_node = ast.If(test=if_clause, body=stmts if isinstance(stmts, list) else [stmts], orelse=[]) + ast.copy_location(if_node, target_node) + stmts = [if_node] + # Wrap with for-loop + for_node = ast.For(target=gen.target, + iter=gen.iter, + body=stmts if isinstance(stmts, list) else [stmts], + orelse=[]) + ast.copy_location(for_node, target_node) + stmts = [for_node] + return stmts if isinstance(stmts, list) else [stmts] + + def _desugar_listcomp(self, node: ast.ListComp, target_node: ast.AST) -> Tuple[str, list]: + name = self._fresh_name() + # __comp_tmp = [] + init = ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=ast.List(elts=[], ctx=ast.Load())) + ast.copy_location(init, target_node) + # __comp_tmp.append(elt) + append_call = ast.Expr( + value=ast.Call(func=ast.Attribute(value=ast.Name(id=name, ctx=ast.Load()), attr='append', ctx=ast.Load()), + args=[node.elt], + keywords=[])) + ast.copy_location(append_call, target_node) + loops = self._build_loop_nest(node.generators, [append_call], target_node) + return name, [init] + loops + + def _desugar_setcomp(self, node: ast.SetComp, target_node: ast.AST) -> Tuple[str, list]: + name = self._fresh_name() + # __comp_tmp = set() + init = ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], + value=ast.Call(func=ast.Name(id='set', ctx=ast.Load()), args=[], keywords=[])) + ast.copy_location(init, target_node) + # __comp_tmp.add(elt) + add_call = ast.Expr( + value=ast.Call(func=ast.Attribute(value=ast.Name(id=name, ctx=ast.Load()), attr='add', ctx=ast.Load()), + args=[node.elt], + 
keywords=[])) + ast.copy_location(add_call, target_node) + loops = self._build_loop_nest(node.generators, [add_call], target_node) + return name, [init] + loops + + def _desugar_dictcomp(self, node: ast.DictComp, target_node: ast.AST) -> Tuple[str, list]: + name = self._fresh_name() + # __comp_tmp = {} + init = ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=ast.Dict(keys=[], values=[])) + ast.copy_location(init, target_node) + # __comp_tmp[key] = value + assign_stmt = ast.Assign( + targets=[ast.Subscript(value=ast.Name(id=name, ctx=ast.Load()), slice=node.key, ctx=ast.Store())], + value=node.value) + ast.copy_location(assign_stmt, target_node) + loops = self._build_loop_nest(node.generators, [assign_stmt], target_node) + return name, [init] + loops + + def _desugar_generatorexp(self, node: ast.GeneratorExp, target_node: ast.AST) -> Tuple[str, list]: + # Desugar generator expressions as list comprehensions + listcomp = ast.ListComp(elt=node.elt, generators=node.generators) + ast.copy_location(listcomp, target_node) + return self._desugar_listcomp(listcomp, target_node) + + def _find_and_desugar(self, node: ast.AST) -> Tuple[list, ast.AST]: + """Walk an expression, desugar any comprehensions found, return (prefix_stmts, rewritten_expr).""" + prefix_stmts = [] + + outer_self = self + + class _Replacer(ast.NodeTransformer): + + def visit_ListComp(self, lc: ast.ListComp) -> ast.AST: + # Recurse into sub-expressions first + lc = self.generic_visit(lc) + name, stmts = outer_self._desugar_listcomp(lc, lc) + prefix_stmts.extend(stmts) + return ast.Name(id=name, ctx=ast.Load()) + + def visit_SetComp(self, sc: ast.SetComp) -> ast.AST: + sc = self.generic_visit(sc) + name, stmts = outer_self._desugar_setcomp(sc, sc) + prefix_stmts.extend(stmts) + return ast.Name(id=name, ctx=ast.Load()) + + def visit_DictComp(self, dc: ast.DictComp) -> ast.AST: + dc = self.generic_visit(dc) + name, stmts = outer_self._desugar_dictcomp(dc, dc) + prefix_stmts.extend(stmts) + 
return ast.Name(id=name, ctx=ast.Load()) + + def visit_GeneratorExp(self, ge: ast.GeneratorExp) -> ast.AST: + ge = self.generic_visit(ge) + name, stmts = outer_self._desugar_generatorexp(ge, ge) + prefix_stmts.extend(stmts) + return ast.Name(id=name, ctx=ast.Load()) + + rewritten = _Replacer().visit(astutils.copy_tree(node)) + return prefix_stmts, rewritten + + def _has_comprehension(self, node: ast.AST) -> bool: + for child in ast.walk(node): + if isinstance(child, (ast.ListComp, ast.SetComp, ast.DictComp, ast.GeneratorExp)): + return True + return False + + def visit_Assign(self, node: ast.Assign) -> ast.AST: + self.generic_visit(node) + if not self._has_comprehension(node.value): + return node + prefix, new_value = self._find_and_desugar(node.value) + node.value = new_value + ast.fix_missing_locations(node) + result = prefix + [node] + for stmt in result: + ast.fix_missing_locations(stmt) + return result + + def visit_Expr(self, node: ast.Expr) -> ast.AST: + self.generic_visit(node) + if not self._has_comprehension(node.value): + return node + prefix, new_value = self._find_and_desugar(node.value) + node.value = new_value + ast.fix_missing_locations(node) + result = prefix + [node] + for stmt in result: + ast.fix_missing_locations(stmt) + return result + + def visit_Return(self, node: ast.Return) -> ast.AST: + self.generic_visit(node) + if node.value is None or not self._has_comprehension(node.value): + return node + prefix, new_value = self._find_and_desugar(node.value) + node.value = new_value + ast.fix_missing_locations(node) + result = prefix + [node] + for stmt in result: + ast.fix_missing_locations(stmt) + return result + + def visit_If(self, node: ast.If) -> ast.AST: + self.generic_visit(node) + if not self._has_comprehension(node.test): + return node + prefix, new_test = self._find_and_desugar(node.test) + node.test = new_test + ast.fix_missing_locations(node) + result = prefix + [node] + for stmt in result: + ast.fix_missing_locations(stmt) + return 
result + + def visit_For(self, node: ast.For) -> ast.AST: + self.generic_visit(node) + if not self._has_comprehension(node.iter): + return node + prefix, new_iter = self._find_and_desugar(node.iter) + node.iter = new_iter + ast.fix_missing_locations(node) + result = prefix + [node] + for stmt in result: + ast.fix_missing_locations(stmt) + return result + + def visit_AugAssign(self, node: ast.AugAssign) -> ast.AST: + self.generic_visit(node) + if not self._has_comprehension(node.value): + return node + prefix, new_value = self._find_and_desugar(node.value) + node.value = new_value + ast.fix_missing_locations(node) + result = prefix + [node] + for stmt in result: + ast.fix_missing_locations(stmt) + return result class AugAssignExpander(ast.NodeTransformer): @@ -1483,12 +2419,14 @@ class AugAssignExpander(ast.NodeTransformer): def visit_AugAssign(self, node: ast.AugAssign) -> ast.Assign: target = self.generic_visit(node.target) value = self.generic_visit(node.value) - newvalue = ast.copy_location(ast.BinOp(left=copy.deepcopy(target), op=node.op, right=value), value) + newvalue = ast.copy_location(ast.BinOp(left=astutils.copy_tree(target), op=node.op, right=value), value) return ast.copy_location(ast.Assign(targets=[target], value=newvalue), node) -def find_disallowed_statements(node: ast.AST): - from dace.frontend.python.newast import DISALLOWED_STMTS # Avoid import loop +def find_disallowed_statements(node: ast.AST, stmts=None): + if stmts is None: + from dace.frontend.python.newast import DISALLOWED_STMTS # Avoid import loop + stmts = DISALLOWED_STMTS # Skip everything until the function contents (in case there are disallowed statements in a decorator) if isinstance(node, ast.Module) and isinstance(node.body[0], ast.FunctionDef): nodes = node.body[0].body @@ -1498,7 +2436,7 @@ def find_disallowed_statements(node: ast.AST): for topnode in nodes: for subnode in ast.walk(topnode): # Found disallowed statement - if type(subnode).__name__ in DISALLOWED_STMTS: + if 
type(subnode).__name__ in stmts: return type(subnode).__name__ # Calls with double-starred arguments (**args) @@ -1572,7 +2510,15 @@ def preprocess_dace_program(f: Callable[..., Any], modules: Dict[str, Any], resolve_functions: bool = False, parent_closure: Optional[SDFGClosure] = None, - default_args: Optional[Set[str]] = None) -> Tuple[PreprocessedAST, SDFGClosure]: + default_args: Optional[Set[str]] = None, + normalize_generic_for_loops: bool = False, + preserve_object_attributes: bool = False, + prefer_resolved_object_attributes: Optional[Set[str]] = None, + disallowed_stmts: Optional[Set[str]] = None, + preserve_raises: bool = False, + preserve_fstrings: bool = False, + preserve_uninlinable_context_managers: bool = False, + preserve_call_expansions: bool = False) -> Tuple[PreprocessedAST, SDFGClosure]: """ Preprocesses a ``@dace.program`` and all its nested functions, returning a preprocessed AST object and the closure of the resulting SDFG. @@ -1592,12 +2538,32 @@ def preprocess_dace_program(f: Callable[..., Any], :param parent_closure: If not None, represents the closure of the parent of the currently processed function. :param default_args: If not None, defines a list of unspecified default arguments. + :param prefer_resolved_object_attributes: If given, attribute loads on + these base names prefer closure/global resolution + when the field value can be represented as a DaCe + descriptor, instead of being preserved as plain + object syntax. + :param preserve_raises: If True, keep ``raise`` statements in the + preprocessed AST for downstream frontends to + handle explicitly. + :param preserve_fstrings: If True, keep non-constant f-string AST nodes + in the preprocessed AST for downstream + frontends to handle explicitly. + :param preserve_uninlinable_context_managers: If True, leave ``with`` / + ``async with`` statements in the AST when the + context manager cannot be created at compile + time, so downstream frontends can decide how to + handle them. 
+ :param preserve_call_expansions: If True, leave calls that use ``**`` + argument expansion in the AST so downstream + frontends can represent them explicitly. :return: A 2-tuple of the AST and its reduced (used) closure. """ src_ast, src_file, src_line, src = astutils.function_to_ast(f) # Resolve data structures src_ast = StructTransformer(global_vars).visit(src_ast) + src_ast = TypeAliasResolver(src_file).visit(src_ast) src_ast = ModuleResolver(modules).visit(src_ast) # Convert modules after resolution @@ -1610,14 +2576,24 @@ def preprocess_dace_program(f: Callable[..., Any], try: src_ast = MPIResolver(global_vars).visit(src_ast) - except (ImportError, ModuleNotFoundError): + except (ImportError, ModuleNotFoundError, RuntimeError): pass src_ast = ModuloConverter().visit(src_ast) + if normalize_generic_for_loops: + global_vars['__dace_iterator_init'] = __dace_iterator_init + global_vars['__dace_iterator_next'] = __dace_iterator_next + # Resolve constants to their values (if they are not already defined in this scope) # and symbols to their names resolved = {k: v for k, v in global_vars.items() if k not in (argtypes.keys() - default_args) and k != '_'} - closure_resolver = GlobalResolver(resolved, resolve_functions, default_args=default_args) + closure_resolver = GlobalResolver(resolved, + resolve_functions, + default_args=default_args, + preserve_object_attributes=preserve_object_attributes, + prefer_resolved_object_attributes=prefer_resolved_object_attributes, + preserve_raises=preserve_raises, + preserve_fstrings=preserve_fstrings) # Append element to call stack and handle max recursion depth if parent_closure is not None: @@ -1633,7 +2609,12 @@ def preprocess_dace_program(f: Callable[..., Any], closure_resolver.closure.callstack = parent_closure.callstack + [fid] # Find disallowed AST nodes - disallowed = find_disallowed_statements(src_ast) + if disallowed_stmts is None: + disallowed = find_disallowed_statements(src_ast) + elif disallowed_stmts: + disallowed = 
find_disallowed_statements(src_ast, disallowed_stmts) + else: + disallowed = None # Empty set means nothing is disallowed if disallowed: raise TypeError(f'Converting function "{f.__name__}" ({src_file}:{src_line}) to callback due to disallowed ' f'keyword: {disallowed}') @@ -1660,12 +2641,22 @@ def check_code(src_ast): try: closure_resolver.toplevel_function = True src_ast = closure_resolver.visit(src_ast) - DisallowedAssignmentChecker(src_file).visit(src_ast) + DisallowedAssignmentChecker(src_file, preserve_call_expansions=preserve_call_expansions).visit(src_ast) + if normalize_generic_for_loops: + src_ast = ComprehensionDesugarer().visit(src_ast) src_ast = LoopUnroller(resolved, src_file, closure_resolver).visit(src_ast) + if normalize_generic_for_loops: + src_ast = IteratorForLoopNormalizer(resolved, argtypes, closure_resolver).visit(src_ast) src_ast = ExpressionInliner(resolved, src_file, closure_resolver).visit(src_ast) - src_ast = ContextManagerInliner(resolved, src_file, closure_resolver).visit(src_ast) - src_ast = ConditionalCodeResolver(resolved).visit(src_ast) - src_ast = DeadCodeEliminator().visit(src_ast) + src_ast = ContextManagerInliner( + resolved, + src_file, + closure_resolver, + preserve_uninlinable_context_managers=preserve_uninlinable_context_managers).visit(src_ast) + src_ast = ConditionalCodeResolver(resolved, preserve_raises=preserve_raises).visit(src_ast) + if normalize_generic_for_loops: + src_ast = NamedExprDesugarer().visit(src_ast) + src_ast = DeadCodeEliminator(preserve_raises=preserve_raises).visit(src_ast) except Exception: if Config.get_bool('frontend', 'verbose_errors'): print(f'VERBOSE: Failed to preprocess (pass #{pass_num}) the following program:') diff --git a/dace/frontend/python/replacements/__init__.py b/dace/frontend/python/replacements/__init__.py index 11a4254bd6..cbc184c047 100644 --- a/dace/frontend/python/replacements/__init__.py +++ b/dace/frontend/python/replacements/__init__.py @@ -17,3 +17,7 @@ from .pymath import * 
from .reduction import * from .ufunc import * + +# Lightweight descriptor-inference registrations for the schedule-tree frontend. +# Imported for side effects (populates Replacements._dtype_rep). +from . import type_inference as _type_inference # noqa: F401 diff --git a/dace/frontend/python/replacements/array_creation.py b/dace/frontend/python/replacements/array_creation.py index 66bcd57163..15dfc0d65d 100644 --- a/dace/frontend/python/replacements/array_creation.py +++ b/dace/frontend/python/replacements/array_creation.py @@ -11,12 +11,45 @@ import copy from numbers import Number, Integral +import re from typing import Any, List, Optional, Sequence, Union import numpy as np import sympy as sp +def arange_promoted_symbol_name(name: str) -> str: + sanitized = re.sub(r'\W|^(?=\d)', '_', name) + return f'__sym_{sanitized}' + + +def normalize_arange_argument(value: Any, input_descs: Optional[dict[str, data.Data]] = None) -> Any: + if isinstance(value, Number) or symbolic.issymbolic(value): + return value + if not isinstance(value, str): + return value + + desc = (input_descs or {}).get(value) + if isinstance(desc, data.Scalar): + return symbolic.pystr_to_symbolic(arange_promoted_symbol_name(value)) + + try: + normalized = symbolic.pystr_to_symbolic(value) + except Exception: + return value + return normalized if symbolic.issymbolic(normalized) else value + + +def infer_arange_shape(start: Any, stop: Any, step: Any) -> Optional[Sequence[Any]]: + if all(isinstance(v, Number) for v in (start, stop, step)): + return (int(np.ceil((stop - start) / step)), ) + if any(not isinstance(v, Number) and not symbolic.issymbolic(v) for v in (start, stop, step)): + return None + if step == 1: + return (stop - start, ) + return (symbolic.int_ceil(stop - start, step), ) + + @oprepo.replaces('numpy.copy') def _numpy_copy(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str): """ Creates a copy of array a. 
@@ -233,6 +266,24 @@ def _arange(pv: ProgramVisitor, like: Optional[str] = None): """ Implementes numpy.arange """ + def _promote_scalar_argument(value: Any) -> Any: + if not isinstance(value, str): + return value + if value not in sdfg.arrays: + return value + + desc = sdfg.arrays[value] + if not isinstance(desc, data.Scalar): + return value + + promoted_name = sdfg.add_symbol(arange_promoted_symbol_name(value), desc.dtype, find_new_name=True) + promoted = symbolic.symbol(promoted_name, desc.dtype) + + symassign_state = pv.cfg_target.add_state_before(state, label=f'promote_{value}_to_{promoted_name}') + isedge = pv.cfg_target.edges_between(symassign_state, state)[0] + isedge.data.assignments[promoted_name] = value + return promoted + start = 0 stop = None step = 1 @@ -245,6 +296,10 @@ def _arange(pv: ProgramVisitor, else: start, stop, step = args + start = _promote_scalar_argument(start) + stop = _promote_scalar_argument(stop) + step = _promote_scalar_argument(step) + if isinstance(start, str): raise TypeError(f'Cannot compile numpy.arange with a scalar start value "{start}" (only constants and symbolic ' 'expressions are supported). 
Please use numpy.linspace instead.') @@ -259,13 +314,9 @@ def _arange(pv: ProgramVisitor, if isinstance(start, Number) and isinstance(stop, Number): actual_step = type(start + step)(start + step) - start - if any(not isinstance(s, Number) for s in [start, stop, step]): - if step == 1: # Common case where ceiling is not necessary - shape = (stop - start, ) - else: - shape = (symbolic.int_ceil(stop - start, step), ) - else: - shape = (np.int64(np.ceil((stop - start) / step)), ) + shape = infer_arange_shape(start, stop, step) + if shape is None: + raise TypeError('Cannot compile numpy.arange with non-scalar or unsupported dynamic arguments.') # Infer dtype from input arguments if dtype is None: @@ -419,3 +470,237 @@ def _linspace(pv: ProgramVisitor, external_edges=True) return outname, stepname + + +# -------------------------------------------------------------------- # +# Descriptor inference for array creation (schedule-tree frontend) # +# -------------------------------------------------------------------- # + +from dace.frontend.common.op_repository import (infers_descriptor, infers_method_descriptor) +from dace.frontend.python.replacements.type_inference import _get_desc, _to_int + + +@infers_descriptor('numpy.full') +def _infer_full(input_descs, shape, fill_value=None, dtype=None, **_kw): + if isinstance(shape, (Number, np.integer)) or symbolic.issymbolic(shape): + shape = [shape] + if not isinstance(shape, (tuple, list)): + return None + out_shape = [] + for s in shape: + v = _to_int(s) + if v is not None: + out_shape.append(v) + elif symbolic.issymbolic(s): + out_shape.append(s) + else: + return None + if dtype is None: + if isinstance(fill_value, (Number, np.bool_)): + dtype = dtypes.dtype_to_typeclass(type(fill_value)) + else: + desc = _get_desc(input_descs, fill_value) + if desc is not None: + dtype = desc.dtype + else: + dtype = dtypes.float64 + if not isinstance(dtype, dtypes.typeclass): + try: + dtype = dtypes.dtype_to_typeclass(dtype) + except 
(TypeError, ValueError): + return None + return data.Array(dtype, out_shape, transient=True) + + +@infers_descriptor('numpy.empty') +@infers_descriptor('numpy.zeros') +@infers_descriptor('numpy.ones') +def _infer_empty_zeros_ones(input_descs, shape, dtype=None, **_kw): + if isinstance(shape, (Number, np.integer)) or symbolic.issymbolic(shape): + shape = [shape] + if not isinstance(shape, (tuple, list)): + return None + out_shape = [] + for s in shape: + v = _to_int(s) + if v is not None: + out_shape.append(v) + elif symbolic.issymbolic(s): + out_shape.append(s) + else: + return None + if dtype is None: + dtype = dtypes.float64 + if not isinstance(dtype, dtypes.typeclass): + try: + dtype = dtypes.dtype_to_typeclass(dtype) + except (TypeError, ValueError): + return None + return data.Array(dtype, out_shape, transient=True) + + +@infers_descriptor('numpy.empty_like') +@infers_descriptor('numpy.zeros_like') +@infers_descriptor('numpy.ones_like') +def _infer_empty_zeros_ones_like(input_descs, prototype, dtype=None, shape=None, **_kw): + descriptor = _get_desc(input_descs, prototype) + if descriptor is None: + return None + + result = copy.deepcopy(descriptor) + result.transient = True + + if shape is not None: + if isinstance(shape, (Number, np.integer)) or symbolic.issymbolic(shape): + shape = [shape] + if not isinstance(shape, (tuple, list)): + return None + out_shape = [] + for size in shape: + value = _to_int(size) + if value is not None: + out_shape.append(value) + elif symbolic.issymbolic(size): + out_shape.append(size) + else: + return None + if hasattr(result, 'set_shape'): + result.set_shape(out_shape) + + if dtype is not None: + if not isinstance(dtype, dtypes.typeclass): + try: + dtype = dtypes.dtype_to_typeclass(dtype) + except (TypeError, ValueError): + return None + result.dtype = dtype + + return result + + +@infers_descriptor('numpy.full_like') +def _infer_full_like(input_descs, prototype, fill_value=None, dtype=None, shape=None, **_kw): + return 
_infer_empty_zeros_ones_like(input_descs, prototype, dtype=dtype, shape=shape) + + +@infers_descriptor('numpy.eye') +def _infer_eye(input_descs, N, M=None, k=0, dtype=None, **_kw): + n = _to_int(N) + if n is None and not symbolic.issymbolic(N): + return None + if M is None: + M = N + m = _to_int(M) + if m is None and not symbolic.issymbolic(M): + return None + if dtype is None: + dtype = dtypes.float64 + if not isinstance(dtype, dtypes.typeclass): + try: + dtype = dtypes.dtype_to_typeclass(dtype) + except (TypeError, ValueError): + return None + return data.Array(dtype, [N, M], transient=True) + + +@infers_descriptor('numpy.identity') +def _infer_identity(input_descs, n, dtype=None, **_kw): + return _infer_eye(input_descs, n, M=n, dtype=dtype) + + +@infers_descriptor('numpy.arange') +@infers_descriptor('dace.arange') +def _infer_arange(input_descs, *args, dtype=None, **_kw): + if len(args) == 1: + start, stop, step = 0, args[0], 1 + elif len(args) == 2: + start, stop, step = args[0], args[1], 1 + elif len(args) >= 3: + start, stop, step = args[0], args[1], args[2] + else: + return None + + start = normalize_arange_argument(start, input_descs) + stop = normalize_arange_argument(stop, input_descs) + step = normalize_arange_argument(step, input_descs) + shape = infer_arange_shape(start, stop, step) + if shape is None: + return None + + if dtype is None: + if all(isinstance(v, Number) for v in args): + from dace.frontend.python.replacements.operators import result_type as rt + dtype, _ = rt(args) + else: + dtype = dtypes.float64 + if not isinstance(dtype, dtypes.typeclass): + try: + dtype = dtypes.dtype_to_typeclass(dtype) + except (TypeError, ValueError): + return None + return data.Array(dtype, list(shape), transient=True) + + +@infers_descriptor('numpy.linspace') +def _infer_linspace(input_descs, + start=None, + stop=None, + num=50, + endpoint=True, + retstep=False, + dtype=None, + axis=0, + **_kw): + n = _to_int(num) + if n is None and not symbolic.issymbolic(num): 
+ return None + if dtype is None: + dtype = dtypes.float64 + if not isinstance(dtype, dtypes.typeclass): + try: + dtype = dtypes.dtype_to_typeclass(dtype) + except (TypeError, ValueError): + return None + + start_desc = _get_desc(input_descs, start) + stop_desc = _get_desc(input_descs, stop) + start_shape = list(start_desc.shape) if start_desc is not None and hasattr(start_desc, 'shape') else [] + stop_shape = list(stop_desc.shape) if stop_desc is not None and hasattr(stop_desc, 'shape') else [] + + try: + shape, _ranges, _outind, _ind1, _ind2 = broadcast_together(start_shape, stop_shape) + shape_with_axis = _add_axis_to_shape(shape, axis, num if n is None else n) + except (TypeError, ValueError): + return None + + array_result = data.Array(dtype, list(shape_with_axis), transient=True) + if not retstep: + return array_result + + if shape: + step_result = data.Array(dtype, list(shape), transient=True) + else: + step_result = data.Scalar(dtype, transient=True) + return (array_result, step_result) + + +@infers_descriptor('numpy.copy') +def _infer_copy(input_descs, a, **_kw): + desc = _get_desc(input_descs, a) + if desc is None: + return None + if isinstance(desc, data.Scalar): + return data.Scalar(desc.dtype) + return data.Array(desc.dtype, list(desc.shape), transient=True) + + +# Method: .copy() +def _infer_method_copy(self_desc, **_kw): + if isinstance(self_desc, data.Scalar): + return data.Scalar(self_desc.dtype) + return data.Array(self_desc.dtype, list(self_desc.shape), transient=True) + + +for _cls in ('Array', 'Scalar', 'View'): + infers_method_descriptor(_cls, 'copy')(_infer_method_copy) + infers_method_descriptor(_cls, 'fill')(lambda self_desc, value=None, **_kw: ()) diff --git a/dace/frontend/python/replacements/array_creation_cupy.py b/dace/frontend/python/replacements/array_creation_cupy.py index 5b83c15237..5127bcaaed 100644 --- a/dace/frontend/python/replacements/array_creation_cupy.py +++ b/dace/frontend/python/replacements/array_creation_cupy.py @@ 
-4,14 +4,35 @@ """ from dace.frontend.common import op_repository as oprepo import dace.frontend.python.memlet_parser as mem_parser +from dace.frontend.python.replacements.array_creation_dace import _normalize_allocator_shape +from dace.frontend.python.replacements.type_inference import _get_desc from dace.frontend.python.replacements.utils import ProgramVisitor, Shape, sym_type -from dace import dtypes, symbolic, Memlet, SDFG, SDFGState +from dace import data, dtypes, symbolic, Memlet, SDFG, SDFGState from numbers import Number import numpy as np +def _normalize_cupy_dtype(dtype: dtypes.typeclass): + if dtype is None: + return None + if isinstance(dtype, dtypes.typeclass): + return dtype + try: + return dtypes.dtype_to_typeclass(dtype) + except (TypeError, ValueError): + return None + + +def _cupy_array_descriptor(shape: Shape, dtype: dtypes.typeclass): + out_shape = _normalize_allocator_shape(shape) + out_dtype = _normalize_cupy_dtype(dtype) + if out_shape is None or out_dtype is None: + return None + return data.Array(out_dtype, out_shape, storage=dtypes.StorageType.GPU_Global, transient=True) + + @oprepo.replaces("cupy._core.core.ndarray") @oprepo.replaces("cupy.ndarray") def _define_cupy_local(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, shape: Shape, dtype: dtypes.typeclass): @@ -22,6 +43,12 @@ def _define_cupy_local(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, shape: return name +@oprepo.infers_descriptor("cupy._core.core.ndarray") +@oprepo.infers_descriptor("cupy.ndarray") +def _infer_cupy_local(input_descs, shape: Shape, dtype: dtypes.typeclass, **_kw): + return _cupy_array_descriptor(shape, dtype) + + @oprepo.replaces('cupy.full') def _cupy_full(pv: ProgramVisitor, sdfg: SDFG, @@ -52,6 +79,22 @@ def _cupy_full(pv: ProgramVisitor, return name +@oprepo.infers_descriptor('cupy.full') +def _infer_cupy_full(input_descs, + shape: Shape, + fill_value: symbolic.SymbolicType, + dtype: dtypes.typeclass = None, + **_kw): + if dtype is None: + if 
isinstance(fill_value, (Number, np.bool_)): + dtype = dtypes.dtype_to_typeclass(type(fill_value)) + elif symbolic.issymbolic(fill_value): + dtype = sym_type(fill_value) + else: + return None + return _cupy_array_descriptor(shape, dtype) + + @oprepo.replaces('cupy.zeros') def _cupy_zeros(pv: ProgramVisitor, sdfg: SDFG, @@ -63,6 +106,11 @@ def _cupy_zeros(pv: ProgramVisitor, return _cupy_full(pv, sdfg, state, shape, 0.0, dtype) +@oprepo.infers_descriptor('cupy.zeros') +def _infer_cupy_zeros(input_descs, shape: Shape, dtype: dtypes.typeclass = dtypes.float64, **_kw): + return _cupy_array_descriptor(shape, dtype) + + @oprepo.replaces('cupy.empty_like') def _cupy_empty_like(pv: ProgramVisitor, sdfg: SDFG, @@ -81,8 +129,35 @@ def _cupy_empty_like(pv: ProgramVisitor, return name +@oprepo.infers_descriptor('cupy.empty_like') +def _infer_cupy_empty_like(input_descs, prototype: str, dtype: dtypes.typeclass = None, shape: Shape = None, **_kw): + desc = _get_desc(input_descs, prototype) + if not isinstance(desc, data.Data): + return None + result = desc.clone() + if dtype is not None: + out_dtype = _normalize_cupy_dtype(dtype) + if out_dtype is None: + return None + result.dtype = out_dtype + if shape is not None: + out_shape = _normalize_allocator_shape(shape) + if out_shape is None: + return None + result.shape = out_shape + result.storage = dtypes.StorageType.GPU_Global + result.transient = True + return result + + @oprepo.replaces('cupy.empty') @oprepo.replaces('cupy_empty') def _cupy_empty(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, shape: Shape, dtype: dtypes.typeclass): """ Creates an unitialized array of the specificied shape and dtype. 
""" return _define_cupy_local(pv, sdfg, state, shape, dtype) + + +@oprepo.infers_descriptor('cupy.empty') +@oprepo.infers_descriptor('cupy_empty') +def _infer_cupy_empty(input_descs, shape: Shape, dtype: dtypes.typeclass, **_kw): + return _cupy_array_descriptor(shape, dtype) diff --git a/dace/frontend/python/replacements/array_creation_dace.py b/dace/frontend/python/replacements/array_creation_dace.py index b9aae4a888..6314736fe8 100644 --- a/dace/frontend/python/replacements/array_creation_dace.py +++ b/dace/frontend/python/replacements/array_creation_dace.py @@ -5,8 +5,8 @@ """ from dace.frontend.common import op_repository as oprepo from dace.frontend.python.common import DaceSyntaxError, StringLiteral -from dace.frontend.python.replacements.utils import ProgramVisitor, Shape, Size -from dace import data, dtypes, Memlet, SDFG, SDFGState +from dace.frontend.python.replacements.utils import ProgramVisitor, Shape, Size, sym_type +from dace import data, dtypes, symbolic, Memlet, SDFG, SDFGState from copy import deepcopy as dcpy from numbers import Integral @@ -15,6 +15,303 @@ import numpy as np +def _normalize_allocator_shape(shape: Shape): + if isinstance(shape, Integral) or symbolic.issymbolic(shape): + return [shape] + if not isinstance(shape, (list, tuple)): + return None + return list(shape) + + +def infer_array_creation_descriptor(obj: Any, + *, + dtype: dtypes.typeclass = None, + copy: bool = True, + order: StringLiteral = StringLiteral('K'), + subok: bool = False, + ndmin: int = 0, + like: Any = None) -> Optional[data.Data]: + if like is not None: + return None + + if dtype is not None and not isinstance(dtype, dtypes.typeclass): + try: + dtype = dtypes.typeclass(dtype) + except TypeError: + return None + + if isinstance(obj, data.Data): + descriptor = dcpy(obj) + if dtype is not None: + descriptor.dtype = dtype + + shape = list(getattr(descriptor, 'shape', ())) + if isinstance(descriptor, data.Scalar): + if ndmin <= 0: + descriptor.transient = True + return 
descriptor + return data.Array(descriptor.dtype, [1] * ndmin, transient=True) + + if len(shape) < ndmin and hasattr(descriptor, 'set_shape'): + descriptor.set_shape([1] * (ndmin - len(shape)) + shape) + descriptor.transient = True + return descriptor + + try: + if dtype is None: + arr = np.array(obj, copy=copy, order=str(order), subok=subok, ndmin=ndmin) + else: + arr = np.array(obj, dtype.as_numpy_dtype(), copy=copy, order=str(order), subok=subok, ndmin=ndmin) + except Exception: + return None + + try: + descriptor = data.create_datadescriptor(arr) + except TypeError: + scalar_dtype = dtypes.typeclass(np.asarray(arr).dtype.type) + if getattr(arr, 'shape', tuple()) == tuple(): + descriptor = data.Scalar(scalar_dtype, transient=True) + else: + return None + descriptor.transient = True + return descriptor + + +def infer_dynamic_literal_descriptor(obj: Any, + sdfg: SDFG, + *, + dtype: dtypes.typeclass = None, + ndmin: int = 0) -> Optional[data.Array]: + shape_dtype = _infer_dynamic_literal_shape_dtype(obj, sdfg) + if shape_dtype is None: + return None + + shape, inferred_dtype = shape_dtype + result_dtype = dtype or inferred_dtype + if result_dtype is None: + return None + + out_shape = list(shape) + if len(out_shape) < ndmin: + out_shape = [1] * (ndmin - len(out_shape)) + out_shape + return data.Array(result_dtype, out_shape, transient=True) + + +def populate_dynamic_literal_array(state: SDFGState, sdfg: SDFG, outname: str, obj: Any) -> None: + outdesc = sdfg.arrays[outname] + constant_array = _entire_constant_literal_array(obj, outdesc) + if constant_array is not None: + const_name = sdfg.find_new_constant(f'{outname}_literal') + sdfg.add_constant(const_name, constant_array) + sdfg.arrays[const_name] = sdfg.constants_prop[const_name][0] + sdfg.arrays[const_name].transient = True + read = state.add_read(const_name) + write = state.add_write(outname) + subset = ', '.join(f'0:{dim}' for dim in constant_array.shape) + state.add_edge(read, None, write, None, 
Memlet.simple(const_name, subset, other_subset_str=subset)) + return + + write = state.add_write(outname) + counter = 0 + + def emit(value: Any, index: tuple[int, ...]) -> None: + nonlocal counter + if isinstance(value, (list, tuple)): + for child_idx, child in enumerate(value): + emit(child, index + (child_idx, )) + return + + tasklet_name = f'{outname}_literal_{counter}' + counter += 1 + subset = ', '.join(str(i) for i in index) + if isinstance(value, str) and value in sdfg.arrays: + desc = sdfg.arrays[value] + read = state.add_read(value) + tasklet = state.add_tasklet(tasklet_name, {'__inp'}, {'__out'}, '__out = __inp') + state.add_edge(read, None, tasklet, '__inp', Memlet.from_array(value, desc)) + state.add_edge(tasklet, '__out', write, None, Memlet.simple(outname, subset)) + return + + tasklet = state.add_tasklet(tasklet_name, set(), {'__out'}, f'__out = {_literal_code(value)}') + state.add_edge(tasklet, '__out', write, None, Memlet.simple(outname, subset)) + + emit(obj, tuple()) + + +def _entire_constant_literal_array(obj: Any, outdesc: data.Array) -> Optional[np.ndarray]: + if not _is_entire_literal_constant(obj): + return None + npdtype = outdesc.dtype.as_numpy_dtype() + result = np.array(obj, dtype=npdtype) + if tuple(result.shape) != tuple(outdesc.shape): + try: + result = result.reshape(tuple(outdesc.shape)) + except ValueError: + return None + return result + + +def _is_entire_literal_constant(obj: Any) -> bool: + if isinstance(obj, (list, tuple)): + return all(_is_entire_literal_constant(v) for v in obj) + return isinstance(obj, (np.generic, bool, int, float, complex)) + + +def _infer_dynamic_literal_shape_dtype(obj: Any, sdfg: SDFG) -> Optional[tuple[tuple[int, ...], dtypes.typeclass]]: + if isinstance(obj, (list, tuple)): + child_shapes: list[tuple[int, ...]] = [] + child_dtype: Optional[dtypes.typeclass] = None + for element in obj: + shape_dtype = _infer_dynamic_literal_shape_dtype(element, sdfg) + if shape_dtype is None: + return None + 
element_shape, element_dtype = shape_dtype + child_shapes.append(element_shape) + child_dtype = element_dtype if child_dtype is None else dtypes.result_type_of(child_dtype, element_dtype) + + if not child_shapes: + return ((0, ), dtypes.float64) + + first_shape = child_shapes[0] + if any(shape != first_shape for shape in child_shapes[1:]): + return None + return ((len(obj), ) + first_shape, child_dtype) + + dtype = _dynamic_literal_scalar_dtype(obj, sdfg) + if dtype is None: + return None + return (tuple(), dtype) + + +def _dynamic_literal_scalar_dtype(obj: Any, sdfg: SDFG) -> Optional[dtypes.typeclass]: + if isinstance(obj, np.generic): + return dtypes.typeclass(obj.dtype.type) + if isinstance(obj, bool): + return dtypes.bool + if isinstance(obj, (int, float, complex)): + return dtypes.typeclass(type(obj)) + if symbolic.issymbolic(obj): + return sym_type(obj) + if isinstance(obj, str): + if obj in sdfg.arrays: + desc = sdfg.arrays[obj] + if isinstance(desc, data.Scalar): + return desc.dtype + if isinstance(desc, data.Array) and tuple(desc.shape) == (1, ): + return desc.dtype + return None + if obj in sdfg.symbols: + return sdfg.symbols[obj] + try: + parsed = symbolic.pystr_to_symbolic(obj) + except Exception: + return None + if symbolic.issymbolic(parsed): + return sym_type(parsed) + return None + + +def _literal_code(value: Any) -> str: + if isinstance(value, np.generic): + return repr(value.item()) + if isinstance(value, str): + return value + if symbolic.issymbolic(value): + return symbolic.symstr(value) + return repr(value) + + +@oprepo.infers_descriptor('dace.define_local') +@oprepo.infers_descriptor('dace.ndarray') +@oprepo.infers_descriptor('numpy.ndarray') +@oprepo.infers_descriptor('numpy.empty') +def _infer_local_array_descriptor(input_descs, shape: Shape, dtype: dtypes.typeclass, **_kw): + del input_descs + out_shape = _normalize_allocator_shape(shape) + if out_shape is None or dtype is None: + return None + if not isinstance(dtype, dtypes.typeclass): + 
try: + dtype = dtypes.dtype_to_typeclass(dtype) + except (TypeError, ValueError): + return None + return data.Array(dtype, out_shape, transient=True) + + +@oprepo.infers_descriptor('dace.define_local_scalar') +def _infer_local_scalar_descriptor(input_descs, dtype: dtypes.typeclass, **_kw): + del input_descs + if dtype is None: + return None + if not isinstance(dtype, dtypes.typeclass): + try: + dtype = dtypes.dtype_to_typeclass(dtype) + except (TypeError, ValueError): + return None + return data.Scalar(dtype, transient=True) + + +@oprepo.infers_descriptor('dace.define_local_structure') +def _infer_local_structure_descriptor(input_descs, dtype: data.Structure, **_kw): + del input_descs + if dtype is None: + return None + descriptor = dcpy(dtype) + descriptor.transient = True + return descriptor + + +def _normalize_inferred_dtype(dtype: dtypes.typeclass) -> Optional[dtypes.typeclass]: + if dtype is None: + return None + if isinstance(dtype, dtypes.typeclass): + return dtype + try: + return dtypes.dtype_to_typeclass(dtype) + except (TypeError, ValueError): + return None + + +@oprepo.infers_descriptor('dace.define_stream') +def _infer_stream_descriptor(input_descs, dtype: dtypes.typeclass, buffer_size: Size = 1, **_kw): + out_dtype = _normalize_inferred_dtype(dtype) + if out_dtype is None: + return None + return data.Stream(out_dtype, buffer_size, transient=True) + + +@oprepo.infers_descriptor('dace.define_streamarray') +@oprepo.infers_descriptor('dace.stream') +def _infer_streamarray_descriptor(input_descs, shape: Shape, dtype: dtypes.typeclass, buffer_size: Size = 1, **_kw): + out_shape = _normalize_allocator_shape(shape) + out_dtype = _normalize_inferred_dtype(dtype) + if out_shape is None or out_dtype is None: + return None + return data.Stream(out_dtype, buffer_size, shape=out_shape, transient=True) + + +@oprepo.infers_descriptor('numpy.asarray') +@oprepo.infers_descriptor('numpy.array') +@oprepo.infers_descriptor('dace.array') +def 
_infer_literal_array_descriptor(input_descs, + obj: Any, + dtype: dtypes.typeclass = None, + copy: bool = True, + order: StringLiteral = StringLiteral('K'), + subok: bool = False, + ndmin: int = 0, + like: Any = None, + **_kw): + del input_descs + return infer_array_creation_descriptor(obj, + dtype=dtype, + copy=copy, + order=order, + subok=subok, + ndmin=ndmin, + like=like) + + @oprepo.replaces('dace.define_local') @oprepo.replaces('dace.ndarray') def _define_local_ex(pv: ProgramVisitor, @@ -103,6 +400,7 @@ def _define_streamarray(pv: ProgramVisitor, return name +@oprepo.replaces('numpy.asarray') @oprepo.replaces('numpy.array') @oprepo.replaces('dace.array') def _define_literal_ex(pv: ProgramVisitor, @@ -130,13 +428,26 @@ def _define_literal_ex(pv: ProgramVisitor, desc = dcpy(sdfg.arrays[obj]) if dtype is not None: desc.dtype = dtype + dynamic_literal = False else: # From literal / constant - if dtype is None: - arr = np.array(obj, copy=copy, order=str(order), subok=subok, ndmin=ndmin) + desc = infer_array_creation_descriptor(obj, + dtype=dtype, + copy=copy, + order=order, + subok=subok, + ndmin=ndmin, + like=like) + dynamic_literal = desc is None + if dynamic_literal: + desc = infer_dynamic_literal_descriptor(obj, sdfg, dtype=dtype, ndmin=ndmin) + if desc is None: + raise DaceSyntaxError(pv, None, 'Could not infer numpy.array descriptor from literal input') else: - npdtype = dtype.as_numpy_dtype() - arr = np.array(obj, npdtype, copy=copy, order=str(order), subok=subok, ndmin=ndmin) - desc = data.create_datadescriptor(arr) + if dtype is None: + arr = np.array(obj, copy=copy, order=str(order), subok=subok, ndmin=ndmin) + else: + npdtype = dtype.as_numpy_dtype() + arr = np.array(obj, npdtype, copy=copy, order=str(order), subok=subok, ndmin=ndmin) # Set extra properties desc.transient = True @@ -153,6 +464,8 @@ def _define_literal_ex(pv: ProgramVisitor, rnode = state.add_read(obj) wnode = state.add_write(name) state.add_nedge(rnode, wnode, Memlet.from_array(name, 
desc)) + elif dynamic_literal: + populate_dynamic_literal_array(state, sdfg, name, obj) else: # Make constant sdfg.add_constant(name, arr, desc) diff --git a/dace/frontend/python/replacements/array_manipulation.py b/dace/frontend/python/replacements/array_manipulation.py index f414d239f3..6246e156dc 100644 --- a/dace/frontend/python/replacements/array_manipulation.py +++ b/dace/frontend/python/replacements/array_manipulation.py @@ -426,13 +426,16 @@ def _ndarray_T(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str) -> st ############################################################################### +def _resolve_converter_dtype(typeclass: str) -> dtypes.typeclass: + if typeclass == 'bool': + return dtypes.bool + if typeclass in {'int', 'float', 'complex'}: + return dtypes.dtype_to_typeclass(eval(typeclass)) + return dtypes.dtype_to_typeclass(getattr(np, typeclass)) + + def _make_datatype_converter(typeclass: str): - if typeclass == "bool": - dtype = dtypes.bool - elif typeclass in {"int", "float", "complex"}: - dtype = dtypes.dtype_to_typeclass(eval(typeclass)) - else: - dtype = dtypes.dtype_to_typeclass(getattr(np, typeclass)) + dtype = _resolve_converter_dtype(typeclass) @oprepo.replaces(typeclass) @oprepo.replaces("dace.{}".format(typeclass)) @@ -840,3 +843,465 @@ def _hsplit(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, ary: str, def _vsplit(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, ary: str, indices_or_sections: Union[symbolic.SymbolicType, List[symbolic.SymbolicType], str]): return _split_core(visitor, sdfg, state, ary, indices_or_sections, axis=0, allow_uneven=False) + + +# -------------------------------------------------------------------- # +# Descriptor inference for array manipulation (schedule-tree frontend) # +# -------------------------------------------------------------------- # + +from dace.frontend.common.op_repository import (infers_descriptor, infers_method_descriptor, + infers_attribute_descriptor) +from 
dace.frontend.python.replacements.type_inference import _get_desc, _to_int +from dace.frontend.python.replacements.utils import normalize_axes + +# -- Free functions ---------------------------------------------------- # + + +@infers_descriptor('numpy.reshape') +def _infer_reshape(input_descs, arr, newshape, **_kw): + desc = _get_desc(input_descs, arr) + if desc is None: + return None + if not isinstance(newshape, (tuple, list)): + return None + shape = [] + for s in newshape: + v = _to_int(s) + if v is not None: + shape.append(v) + elif symbolic.issymbolic(s): + shape.append(s) + else: + return None + if not shape: + return None + return data.Array(desc.dtype, shape, transient=True) + + +@infers_descriptor('transpose') +@infers_descriptor('dace.transpose') +@infers_descriptor('numpy.transpose') +def _infer_transpose(input_descs, arr, axes=None, **_kw): + desc = _get_desc(input_descs, arr) + if desc is None: + return None + shape = list(desc.shape) + if axes is None: + shape = list(reversed(shape)) + else: + if not isinstance(axes, (tuple, list)): + return None + shape = [shape[i] for i in axes] + if len(shape) == 0: + return data.Scalar(desc.dtype) + return data.Array(desc.dtype, shape, transient=True) + + +@infers_descriptor('numpy.flip') +def _infer_flip(input_descs, arr, axis=None, **_kw): + """flip preserves shape and dtype.""" + desc = _get_desc(input_descs, arr) + if desc is None: + return None + if isinstance(desc, data.Scalar): + return data.Scalar(desc.dtype) + return data.Array(desc.dtype, list(desc.shape), transient=True) + + +@infers_descriptor('numpy.rot90') +def _infer_rot90(input_descs, arr, k=1, axes=(0, 1), **_kw): + desc = _get_desc(input_descs, arr) + if desc is None or not isinstance(desc, (data.Array, data.View)): + return None + + ndim = len(desc.shape) + if not isinstance(axes, (tuple, list)) or len(axes) != 2: + return None + try: + axis0 = int(axes[0]) + axis1 = int(axes[1]) + k = int(k) % 4 + except Exception: + return None + + if axis0 < 
0: + axis0 += ndim + if axis1 < 0: + axis1 += ndim + if axis0 == axis1 or abs(axis0 - axis1) == ndim: + return None + if axis0 >= ndim or axis0 < 0 or axis1 >= ndim or axis1 < 0: + return None + + shape = list(desc.shape) + if k % 2 == 1: + shape[axis0], shape[axis1] = shape[axis1], shape[axis0] + return data.Array(desc.dtype, shape, transient=True) + + +@infers_descriptor('numpy.squeeze') +def _infer_squeeze(input_descs, arr, axis=None, **_kw): + desc = _get_desc(input_descs, arr) + if desc is None: + return None + shape = list(desc.shape) + if axis is None: + shape = [s for s in shape if s != 1] + else: + if not isinstance(axis, (tuple, list)): + axis = (axis, ) + axis = tuple(_to_int(a) for a in axis) + if any(a is None for a in axis): + return None + axis = tuple(normalize_axes(axis, len(shape))) + shape = [s for i, s in enumerate(shape) if i not in axis] + if not shape: + return data.Scalar(desc.dtype) + return data.Array(desc.dtype, shape, transient=True) + + +@infers_descriptor('numpy.expand_dims') +def _infer_expand_dims(input_descs, arr, axis, **_kw): + desc = _get_desc(input_descs, arr) + if desc is None: + return None + shape = list(desc.shape) + if not isinstance(axis, (tuple, list)): + axis = (axis, ) + axis = tuple(_to_int(a) for a in axis) + if any(a is None for a in axis): + return None + ndim_out = len(shape) + len(axis) + axis = tuple(a if a >= 0 else a + ndim_out for a in axis) + out_shape = [None] * ndim_out + for a in sorted(axis): + out_shape[a] = 1 + si = 0 + for i in range(ndim_out): + if out_shape[i] is None: + out_shape[i] = shape[si] + si += 1 + return data.Array(desc.dtype, out_shape, transient=True) + + +@infers_descriptor('numpy.concatenate') +def _infer_concatenate(input_descs, arrays, axis=0, **_kw): + if not isinstance(arrays, (tuple, list)) or len(arrays) == 0: + return None + descs = [_get_desc(input_descs, a) for a in arrays] + if any(d is None for d in descs): + return None + shape = list(descs[0].shape) + if axis is None: + # 
Flatten all, then concatenate + total = sum(data._prod(d.shape) for d in descs) + return data.Array(descs[0].dtype, [total], transient=True) + ax = _to_int(axis) + if ax is None: + return None + if ax < 0: + ax += len(shape) + shape[ax] = sum(d.shape[ax] for d in descs) + return data.Array(descs[0].dtype, shape, transient=True) + + +@infers_descriptor('numpy.stack') +def _infer_stack(input_descs, arrays, axis=0, **_kw): + if not isinstance(arrays, (tuple, list)) or len(arrays) == 0: + return None + descs = [_get_desc(input_descs, a) for a in arrays] + if any(d is None for d in descs): + return None + shape = list(descs[0].shape) + ax = _to_int(axis) + if ax is None: + return None + if ax < 0: + ax += len(shape) + 1 + shape.insert(ax, len(arrays)) + return data.Array(descs[0].dtype, shape, transient=True) + + +@infers_descriptor('numpy.vstack') +@infers_descriptor('numpy.row_stack') +def _infer_vstack(input_descs, tup, **kwargs): + if not isinstance(tup, (tuple, list)) or len(tup) == 0: + return None + first = _get_desc(input_descs, tup[0]) + if first is None: + return None + if len(first.shape) == 1: + return _infer_stack(input_descs, tup, axis=0, **kwargs) + return _infer_concatenate(input_descs, tup, axis=0, **kwargs) + + +@infers_descriptor('numpy.hstack') +@infers_descriptor('numpy.column_stack') +def _infer_hstack(input_descs, tup, **kwargs): + if not isinstance(tup, (tuple, list)) or len(tup) == 0: + return None + first = _get_desc(input_descs, tup[0]) + if first is None: + return None + axis = 0 if len(first.shape) == 1 else 1 + return _infer_concatenate(input_descs, tup, axis=axis, **kwargs) + + +@infers_descriptor('numpy.dstack') +def _infer_dstack(input_descs, tup, **kwargs): + if not isinstance(tup, (tuple, list)) or len(tup) == 0: + return None + first = _get_desc(input_descs, tup[0]) + if first is None or len(first.shape) < 3: + return None + return _infer_concatenate(input_descs, tup, axis=2, **kwargs) + + +def _split_descriptors(desc: data.Data, 
axis: int, sections: Sequence[symbolic.SymbolicType]): + result = [] + offset = 0 + for section in sections: + shape = list(desc.shape) + shape[axis] = section - offset + result.append(data.Array(desc.dtype, shape, transient=True)) + offset = section + + shape = list(desc.shape) + shape[axis] = desc.shape[axis] - offset + result.append(data.Array(desc.dtype, shape, transient=True)) + return result + + +def _infer_split_core(input_descs, ary, indices_or_sections, axis, allow_uneven): + desc = _get_desc(input_descs, ary) + if desc is None: + return None + + ax = _to_int(axis) + if ax is None: + return None + if ax < 0: + ax += len(desc.shape) + if ax < 0 or ax >= len(desc.shape): + return None + + dim_size = desc.shape[ax] + if isinstance(indices_or_sections, (list, tuple)): + sections = [] + for section in indices_or_sections: + value = _to_int(section) + if value is not None: + sections.append(value) + elif symbolic.issymbolic(section): + sections.append(section) + else: + return None + return _split_descriptors(desc, ax, sections) + + nsections = _to_int(indices_or_sections) + if nsections is None or nsections <= 0 or symbolic.issymbolic(dim_size): + return None + + section_size = dim_size // nsections + remainder = dim_size % nsections + if not allow_uneven and remainder != 0: + return None + + result = [] + for index in range(nsections): + shape = list(desc.shape) + size = section_size + if allow_uneven and index < remainder: + size += 1 + shape[ax] = size + result.append(data.Array(desc.dtype, shape, transient=True)) + return result + + +@infers_descriptor('numpy.split') +def _infer_split(input_descs, ary, indices_or_sections, axis=0, **_kw): + return _infer_split_core(input_descs, ary, indices_or_sections, axis, allow_uneven=False) + + +@infers_descriptor('numpy.array_split') +def _infer_array_split(input_descs, ary, indices_or_sections, axis=0, **_kw): + return _infer_split_core(input_descs, ary, indices_or_sections, axis, allow_uneven=True) + + 
+@infers_descriptor('numpy.dsplit') +def _infer_dsplit(input_descs, ary, indices_or_sections, **_kw): + desc = _get_desc(input_descs, ary) + if desc is None or len(desc.shape) < 3: + return None + return _infer_split_core(input_descs, ary, indices_or_sections, axis=2, allow_uneven=False) + + +@infers_descriptor('numpy.hsplit') +def _infer_hsplit(input_descs, ary, indices_or_sections, **_kw): + desc = _get_desc(input_descs, ary) + if desc is None: + return None + axis = 0 if len(desc.shape) <= 1 else 1 + return _infer_split_core(input_descs, ary, indices_or_sections, axis=axis, allow_uneven=False) + + +@infers_descriptor('numpy.vsplit') +def _infer_vsplit(input_descs, ary, indices_or_sections, **_kw): + return _infer_split_core(input_descs, ary, indices_or_sections, axis=0, allow_uneven=False) + + +# -- Method inference -------------------------------------------------- # + + +def _infer_method_reshape(self_desc, *newshape, **_kw): + if len(newshape) == 1 and isinstance(newshape[0], (tuple, list)): + newshape = newshape[0] + shape = [] + for s in newshape: + v = _to_int(s) + if v is not None: + shape.append(v) + elif symbolic.issymbolic(s): + shape.append(s) + else: + return None + if not shape: + return None + return data.Array(self_desc.dtype, shape, transient=True) + + +for _cls in ('Array', 'View'): + infers_method_descriptor(_cls, 'reshape')(_infer_method_reshape) + + +def _infer_method_flatten(self_desc, **_kw): + total = data._prod(self_desc.shape) + return data.Array(self_desc.dtype, [total], transient=True) + + +for _cls in ('Array', 'Scalar', 'View'): + infers_method_descriptor(_cls, 'flatten')(_infer_method_flatten) + infers_method_descriptor(_cls, 'ravel')(_infer_method_flatten) + + +def _infer_method_transpose(self_desc, *axes, **_kw): + shape = list(self_desc.shape) + if len(axes) == 0 or axes[0] is None: + shape = list(reversed(shape)) + else: + if len(axes) == 1 and isinstance(axes[0], (tuple, list)): + axes = axes[0] + shape = [shape[i] for i in 
axes] + if len(shape) == 0: + return data.Scalar(self_desc.dtype) + return data.Array(self_desc.dtype, shape, transient=True) + + +for _cls in ('Array', 'View'): + infers_method_descriptor(_cls, 'transpose')(_infer_method_transpose) + + +def _normalize_view_dtype(dtype) -> Optional[dtypes.typeclass]: + if dtype is None: + return None + if isinstance(dtype, dtypes.typeclass): + return dtype + try: + return dtypes.dtype_to_typeclass(np.dtype(dtype).type) + except (TypeError, ValueError): + return None + + +def _infer_method_view(self_desc, dtype, type=None, **_kw): + if type is not None: + return None + + dtype = _normalize_view_dtype(dtype) + if dtype is None: + return None + + result = data.View.view(self_desc) + result.dtype = dtype + + if isinstance(self_desc, data.Scalar): + return result + + orig_bytes = self_desc.dtype.bytes + view_bytes = dtype.bytes + if view_bytes < orig_bytes and orig_bytes % view_bytes != 0: + return None + + contigdim = next((i for i, stride in enumerate(self_desc.strides) if stride == 1), None) + if contigdim is None: + return None + + if (not symbolic.issymbolic(self_desc.shape[contigdim]) and orig_bytes < view_bytes + and self_desc.shape[contigdim] * orig_bytes % view_bytes != 0): + return None + + newshape = list(self_desc.shape) + newstrides = [(stride * orig_bytes) // view_bytes if i != contigdim else stride + for i, stride in enumerate(self_desc.strides)] + newshape[contigdim] = (newshape[contigdim] * orig_bytes) // view_bytes + + result.shape = newshape + result.strides = newstrides + result.total_size = (self_desc.total_size * orig_bytes) // view_bytes + return result + + +for _cls in ('Array', 'Scalar', 'View'): + infers_method_descriptor(_cls, 'view')(_infer_method_view) + + +def _infer_method_astype(self_desc, dtype, **_kw): + if dtype is None: + return None + if isinstance(dtype, type) and dtype in dtypes._CONSTANT_TYPES[:-1]: + dtype = dtypes.typeclass(dtype) + if not isinstance(dtype, dtypes.typeclass): + return None + if 
isinstance(self_desc, data.Scalar): + return data.Scalar(dtype) + return data.Array(dtype, list(self_desc.shape), transient=True) + + +for _cls in ('Array', 'Scalar', 'View'): + infers_method_descriptor(_cls, 'astype')(_infer_method_astype) + + +def _make_datatype_converter_inference(typeclass: str) -> None: + dtype = _resolve_converter_dtype(typeclass) + + @infers_descriptor(typeclass) + @infers_descriptor(f'dace.{typeclass}') + @infers_descriptor(f'numpy.{typeclass}') + def _infer(input_descs, arg, **_kw): + desc = _get_desc(input_descs, arg) + if desc is None: + return None + return _infer_method_astype(desc, dtype) + + +for _typeclass in dtypes.TYPECLASS_STRINGS: + _make_datatype_converter_inference(_typeclass) + +# -- Attribute inference ----------------------------------------------- # + + +def _infer_attr_T(self_desc): + shape = list(reversed(self_desc.shape)) + return data.Array(self_desc.dtype, shape, transient=True) + + +for _cls in ('Array', 'View'): + infers_attribute_descriptor(_cls, 'T')(_infer_attr_T) + + +def _infer_attr_flat(self_desc): + total = data._prod(self_desc.shape) + return data.Array(self_desc.dtype, [total], transient=True) + + +for _cls in ('Array', 'Scalar', 'View'): + infers_attribute_descriptor(_cls, 'flat')(_infer_attr_flat) diff --git a/dace/frontend/python/replacements/array_metadata.py b/dace/frontend/python/replacements/array_metadata.py index 387c1c9195..f2c57fb4a3 100644 --- a/dace/frontend/python/replacements/array_metadata.py +++ b/dace/frontend/python/replacements/array_metadata.py @@ -4,8 +4,9 @@ """ import dace # noqa from dace.frontend.common import op_repository as oprepo +from dace.frontend.common.op_repository import infers_attribute_descriptor, infers_descriptor from dace.frontend.python.replacements.utils import ProgramVisitor, Size -from dace import data, SDFG, SDFGState +from dace import data, dtypes, SDFG, SDFGState @oprepo.replaces('len') @@ -20,6 +21,13 @@ def _len_array(pv: 'ProgramVisitor', sdfg: SDFG, state: 
SDFGState, a: str): return len(a) +@infers_descriptor('len') +def _infer_len(input_descs, a, **_kw): + if not isinstance(a, str) or a not in input_descs: + return None + return data.Scalar(dtypes.int64, transient=True) + + @oprepo.replaces_attribute('Array', 'size') @oprepo.replaces_attribute('Scalar', 'size') @oprepo.replaces_attribute('View', 'size') @@ -27,3 +35,11 @@ def size(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str) -> Size: desc = sdfg.arrays[arr] totalsize = data._prod(desc.shape) return totalsize + + +def _infer_size(self_desc, **_kw): + return data.Scalar(dtypes.int64, transient=True) + + +for _cls in ('Array', 'Scalar', 'View'): + infers_attribute_descriptor(_cls, 'size')(_infer_size) diff --git a/dace/frontend/python/replacements/fft.py b/dace/frontend/python/replacements/fft.py index b32f6e122e..aab93e0a90 100644 --- a/dace/frontend/python/replacements/fft.py +++ b/dace/frontend/python/replacements/fft.py @@ -89,3 +89,37 @@ def _ifft(pv: 'ProgramVisitor', axis=-1, norm: StringLiteral = StringLiteral('backward')): return _fft_core(pv, sdfg, state, a, n, axis, norm, True) + + +# -------------------------------------------------------------------- # +# Descriptor inference for FFT (schedule-tree frontend) # +# -------------------------------------------------------------------- # + +from dace import data +from dace.frontend.common.op_repository import infers_descriptor +from dace.frontend.python.replacements.type_inference import _get_desc + + +def _infer_fft_descriptor(input_descs, a, n=None, axis=-1, is_inverse=False, **_kw): + desc = _get_desc(input_descs, a) + if desc is None or not isinstance(desc, data.Data): + return None + if axis not in (0, -1) or n is not None: + return None + if is_inverse and desc.dtype not in (dtypes.complex64, dtypes.complex128): + return None + + out_dtype = _real_to_complex(desc.dtype) + if isinstance(desc, data.Scalar): + return data.Scalar(out_dtype, transient=True) + return data.Array(out_dtype, 
list(desc.shape), transient=True) + + +@infers_descriptor('numpy.fft.fft') +def _infer_fft(input_descs, a, n=None, axis=-1, norm=StringLiteral('backward'), **_kw): + return _infer_fft_descriptor(input_descs, a, n=n, axis=axis, is_inverse=False) + + +@infers_descriptor('numpy.fft.ifft') +def _infer_ifft(input_descs, a, n=None, axis=-1, norm=StringLiteral('backward'), **_kw): + return _infer_fft_descriptor(input_descs, a, n=n, axis=axis, is_inverse=True) diff --git a/dace/frontend/python/replacements/filtering.py b/dace/frontend/python/replacements/filtering.py index d1acd38686..83a39276ae 100644 --- a/dace/frontend/python/replacements/filtering.py +++ b/dace/frontend/python/replacements/filtering.py @@ -176,3 +176,73 @@ def _array_array_select(visitor: ProgramVisitor, right_operand_node = nd return out_operand + + +# -------------------------------------------------------------------- # +# Descriptor inference for filtering (schedule-tree frontend) # +# -------------------------------------------------------------------- # + +from dace.frontend.common.op_repository import infers_descriptor +from dace.frontend.python.replacements.type_inference import _get_desc +from dace.frontend.python.replacements.operators import result_type + + +def _where_operand_value(input_descs, operand): + desc = _get_desc(input_descs, operand) + if desc is not None: + return desc + return operand + + +def _where_operand_shape(operand) -> List[int]: + if isinstance(operand, data.Data): + return list(operand.shape) + return [1] + + +@infers_descriptor('numpy.where') +def _infer_where(input_descs, cond_operand, left_operand=None, right_operand=None, **_kw): + if left_operand is None or right_operand is None: + return None + + cond_desc = _get_desc(input_descs, cond_operand) + if cond_desc is None: + return None + + left_value = _where_operand_value(input_descs, left_operand) + right_value = _where_operand_value(input_descs, right_operand) + + if not isinstance(left_value, data.Data) and not 
isinstance(right_value, data.Data): + return None + + try: + result_dtype, _casting = result_type([left_value, right_value]) + except Exception: + return None + + if not isinstance(result_dtype, dtypes.typeclass): + return None + + try: + out_shape, _all_idx, _out_idx, _left_idx, _right_idx = broadcast_together(_where_operand_shape(left_value), + _where_operand_shape(right_value)) + broadcast_together(list(cond_desc.shape), out_shape) + except Exception: + return None + + return data.Array(result_dtype, list(out_shape), transient=True) + + +@infers_descriptor('numpy.select') +def _infer_select(input_descs, cond_list, choice_list, default=None, **_kw): + if not isinstance(cond_list, (tuple, list)) or not isinstance(choice_list, (tuple, list)): + return None + if len(cond_list) != len(choice_list) or len(cond_list) == 0: + return None + + current = 0 if default is None else default + for cond_operand, left_operand in reversed(list(zip(cond_list, choice_list))): + current = _infer_where(input_descs, cond_operand, left_operand, current) + if current is None: + return None + return current diff --git a/dace/frontend/python/replacements/linalg.py b/dace/frontend/python/replacements/linalg.py index f5b054be25..1d1245e0e7 100644 --- a/dace/frontend/python/replacements/linalg.py +++ b/dace/frontend/python/replacements/linalg.py @@ -329,3 +329,170 @@ def _einsum(pv: ProgramVisitor, output_name=pv.get_target_name(), alpha=alpha, beta=beta) + + +# -------------------------------------------------------------------- # +# Descriptor inference for operators (schedule-tree frontend) # +# -------------------------------------------------------------------- # + +from dace.frontend.common.op_repository import infers_descriptor, infers_operator_descriptor +from dace.frontend.common.einsum import EinsumParser +from dace.frontend.python.replacements.type_inference import _get_desc +from dace.frontend.python.schedule_tree.expression_support import (_matmul_output_shape) + + 
+@infers_operator_descriptor('MatMult') +def _infer_matmult(left_desc, right_desc): + """Infer result descriptor for the ``@`` (MatMult) operator.""" + left_shape = tuple(left_desc.shape) + right_shape = tuple(right_desc.shape) + out_shape = _matmul_output_shape(left_shape, right_shape) + if out_shape is None: + return None + + type1 = left_desc.dtype.type + type2 = right_desc.dtype.type + restype = dtypes.dtype_to_typeclass(np.result_type(type1, type2).type) + + if len(out_shape) == 0: + return data.Scalar(restype) + return data.Array(restype, list(out_shape), transient=True) + + +@infers_descriptor('dace.linalg.inv') +@infers_descriptor('numpy.linalg.inv') +def _infer_inv(input_descs, inp_op, **_kw): + desc = input_descs.get(inp_op) if isinstance(inp_op, str) else inp_op + if not isinstance(desc, data.Data): + return None + if isinstance(desc, data.Scalar): + return data.Scalar(desc.dtype, transient=True) + return data.Array(desc.dtype, list(desc.shape), transient=True) + + +@infers_descriptor('dace.linalg.solve') +@infers_descriptor('numpy.linalg.solve') +def _infer_solve(input_descs, op_a, op_b, **_kw): + desc = input_descs.get(op_b) if isinstance(op_b, str) else op_b + if not isinstance(desc, data.Data): + return None + if isinstance(desc, data.Scalar): + return data.Scalar(desc.dtype, transient=True) + return data.Array(desc.dtype, list(desc.shape), transient=True) + + +@infers_descriptor('dace.linalg.cholesky') +@infers_descriptor('numpy.linalg.cholesky') +def _infer_cholesky(input_descs, inp_op, **_kw): + return _infer_inv(input_descs, inp_op, **_kw) + + +@infers_descriptor('dace.dot') +@infers_descriptor('numpy.dot') +def _infer_dot(input_descs, op_a, op_b, op_out=None, **_kw): + from dace.frontend.python.replacements.ufunc import _infer_ufunc_descriptor + from dace.frontend.python.replacements.operators import result_type + + desc_a = input_descs.get(op_a) if isinstance(op_a, str) else op_a + desc_b = input_descs.get(op_b) if isinstance(op_b, str) else 
op_b + if not isinstance(desc_a, data.Data) or not isinstance(desc_b, data.Data): + return None + + if len(desc_a.shape) == 2 and len(desc_b.shape) == 2: + return _infer_matmult(desc_a, desc_b) + + if (isinstance(desc_a, data.Scalar) or list(desc_a.shape) == [1] or isinstance(desc_b, data.Scalar) + or list(desc_b.shape) == [1]): + return _infer_ufunc_descriptor(input_descs, 'multiply', op_a, op_b, *(() if op_out is None else (op_out, ))) + + if len(desc_a.shape) > 2 or len(desc_b.shape) > 2: + return None + if desc_a.shape[0] != desc_b.shape[0]: + return None + if op_out is not None: + if isinstance(op_out, str) and op_out in input_descs: + out_desc = input_descs[op_out] + if isinstance(out_desc, data.Data): + return copy.deepcopy(out_desc) + return None + + restype, _ = result_type([desc_a, desc_b], 'Mul') + if not isinstance(restype, dtypes.typeclass): + return None + return data.Scalar(restype, transient=True) + + +@infers_descriptor('dace.tensordot') +@infers_descriptor('numpy.tensordot') +def _infer_tensordot(input_descs, op_a, op_b, axes=2, out_axes=None, **_kw): + desc_a = input_descs.get(op_a) if isinstance(op_a, str) else op_a + desc_b = input_descs.get(op_b) if isinstance(op_b, str) else op_b + if not isinstance(desc_a, data.Data) or not isinstance(desc_b, data.Data): + return None + + if isinstance(axes, Integral): + left_axes = list(range(len(desc_a.shape) - axes, len(desc_a.shape))) + right_axes = list(range(0, axes)) + else: + if not isinstance(axes, (tuple, list)) or len(axes) != 2: + return None + left_axes = list(axes[0]) + right_axes = list(axes[1]) + + if any(a >= len(desc_a.shape) or a < 0 for a in left_axes): + return None + if any(a >= len(desc_b.shape) or a < 0 for a in right_axes): + return None + if len(left_axes) != len(right_axes): + return None + if any(desc_a.shape[l] != desc_b.shape[r] for l, r in zip(left_axes, right_axes)): + return None + + dot_shape = [s for i, s in enumerate(desc_a.shape) if i not in left_axes] + 
dot_shape.extend([s for i, s in enumerate(desc_b.shape) if i not in right_axes]) + + if out_axes is not None: + if not isinstance(out_axes, (tuple, list)): + return None + if list(sorted(out_axes)) != list(range(len(dot_shape))): + return None + dot_shape = [dot_shape[i] for i in out_axes] + + if len(dot_shape) == 0: + return data.Scalar(desc_a.dtype, transient=True) + return data.Array(desc_a.dtype, list(dot_shape), transient=True) + + +@infers_descriptor('numpy.einsum') +def _infer_einsum(input_descs, einsum_string, *arrays, dtype=None, optimize=False, output=None, **_kw): + if output is not None: + explicit = _get_desc(input_descs, output) + if explicit is None: + return None + return copy.deepcopy(explicit) + + try: + parser = EinsumParser(str(einsum_string)) + except Exception: + return None + + if len(parser.inputs) != len(arrays): + return None + + chardict = {} + descs = [] + for inp, arr in zip(parser.inputs, arrays): + desc = _get_desc(input_descs, arr) + if not isinstance(desc, data.Data): + return None + if len(inp) != len(desc.shape): + return None + descs.append(desc) + for char, shp in zip(inp, desc.shape): + if char in chardict and shp != chardict[char]: + return None + chardict[char] = shp + + out_dtype = dtype or descs[0].dtype + output_shape = [chardict[k] for k in parser.output] or [1] + return data.Array(out_dtype, output_shape, transient=True) diff --git a/dace/frontend/python/replacements/misc.py b/dace/frontend/python/replacements/misc.py index 7bfc15e8a4..5e7bc059fe 100644 --- a/dace/frontend/python/replacements/misc.py +++ b/dace/frontend/python/replacements/misc.py @@ -7,8 +7,9 @@ from dace.frontend.common import op_repository as oprepo from dace.frontend.python import astutils from dace.frontend.python.common import StringLiteral +from dace.frontend.python.replacements.type_inference import _get_desc from dace.frontend.python.replacements.utils import ProgramVisitor -from dace import Memlet, SDFG, SDFGState, dtypes +from dace import 
Memlet, SDFG, SDFGState, data, dtypes import ast import functools @@ -20,6 +21,11 @@ def _slice(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args, **kwargs): return (slice(*args, **kwargs), ) +@oprepo.infers_descriptor('slice') +def _infer_slice(input_descs, *args, **kwargs): + return (data.Scalar(dtypes.pyobject(), transient=True), ) + + @oprepo.replaces_operator('Array', 'MatMult', otherclass='StorageType') def _cast_storage(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, stype: dtypes.StorageType) -> str: desc = sdfg.arrays[arr] @@ -27,6 +33,16 @@ def _cast_storage(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: st return arr +@oprepo.infers_operator_descriptor('MatMult', 'Array', 'StorageType') +def _infer_cast_storage(arr_desc, stype: dtypes.StorageType): + if not isinstance(stype, dtypes.StorageType): + return None + result = arr_desc.clone() + result.storage = stype + result.transient = True + return result + + @oprepo.replaces('dace.elementwise') def elementwise(pv: ProgramVisitor, sdfg: SDFG, @@ -79,3 +95,16 @@ def elementwise(pv: ProgramVisitor, external_edges=True) return out_array + + +@oprepo.infers_descriptor('dace.elementwise') +def _infer_elementwise(input_descs, func: Union[StringLiteral, str], in_array: str, out_array=None, **_kw): + desc = _get_desc(input_descs, out_array) if out_array is not None else None + if desc is None: + desc = _get_desc(input_descs, in_array) + if not isinstance(desc, data.Data): + return None + result = desc.clone() + if out_array is None: + result.transient = True + return result diff --git a/dace/frontend/python/replacements/mpi.py b/dace/frontend/python/replacements/mpi.py index a3653ae2ea..74cbf066b7 100644 --- a/dace/frontend/python/replacements/mpi.py +++ b/dace/frontend/python/replacements/mpi.py @@ -3,7 +3,7 @@ import itertools import sympy as sp -from dace import dtypes, symbolic +from dace import data, dtypes, symbolic from dace.frontend.common import op_repository as oprepo 
from dace.frontend.python.replacements.utils import ProgramVisitor from dace.memlet import Memlet @@ -1444,3 +1444,138 @@ def _distr_matmult(pv: ProgramVisitor, state.add_edge(tasklet, '_c', cnode, None, Memlet.from_array(*out)) return out[0] + + +# -------------------------------------------------------------------- # +# Descriptor inference for MPI/distributed replacements # +# -------------------------------------------------------------------- # + +from dace.frontend.common.op_repository import infers_descriptor, infers_method_descriptor, infers_operator_descriptor +from dace.frontend.python.replacements.type_inference import _get_desc + + +def _pyobject_scalar_descriptor(): + return data.Scalar(dtypes.pyobject(), transient=True) + + +def _request_descriptor(): + return data.Array(dtypes.opaque('MPI_Request'), [1], transient=True) + + +def _int_vector_descriptor(length: int): + return data.Array(dtypes.int32, [length], transient=True) + + +def _zero_output(*_args, **_kwargs): + return () + + +@infers_descriptor('mpi4py.MPI.COMM_WORLD.Create_cart') +@infers_descriptor('dace.comm.Cart_create') +def _infer_cart_create(input_descs, dims, **_kw): + return _pyobject_scalar_descriptor() + + +@infers_method_descriptor('Intracomm', 'Create_cart') +def _infer_intracomm_create(self_desc, dims, **_kw): + return _pyobject_scalar_descriptor() + + +@infers_descriptor('dace.comm.Cart_sub') +def _infer_cart_sub(input_descs, parent_grid, color, exact_grid=None, **_kw): + return _pyobject_scalar_descriptor() + + +@infers_method_descriptor('ProcessGrid', 'Sub') +def _infer_pgrid_sub(self_desc, color, **_kw): + return _pyobject_scalar_descriptor() + + +for _name in ('mpi4py.MPI.COMM_WORLD.Bcast', 'dace.comm.Bcast', 'mpi4py.MPI.COMM_WORLD.Reduce', 'dace.comm.Reduce', + 'mpi4py.MPI.COMM_WORLD.Alltoall', 'dace.comm.Alltoall', 'mpi4py.MPI.COMM_WORLD.Allreduce', + 'dace.comm.Allreduce', 'mpi4py.MPI.COMM_WORLD.Scatter', 'dace.comm.Scatter', + 'mpi4py.MPI.COMM_WORLD.Gather', 
'dace.comm.Gather', 'mpi4py.MPI.COMM_WORLD.Send', 'dace.comm.Send', + 'mpi4py.MPI.COMM_WORLD.Recv', 'dace.comm.Recv', 'dace.comm.Wait', 'mpi4py.MPI.Request.Waitall', + 'dace.comm.Waitall', 'dace.comm.BCGather'): + infers_descriptor(_name)(_zero_output) + +for _cls, _method in (('Cartcomm', 'Bcast'), ('Intracomm', 'Bcast'), ('ProcessGrid', 'Bcast'), + ('Intracomm', 'Alltoall'), ('ProcessGrid', + 'Alltoall'), ('Intracomm', 'Allreduce'), ('ProcessGrid', 'Allreduce'), + ('Intracomm', 'Send'), ('ProcessGrid', 'Send'), ('Intracomm', 'Recv'), ('ProcessGrid', 'Recv')): + infers_method_descriptor(_cls, _method)(_zero_output) + + +@infers_descriptor('mpi4py.MPI.COMM_WORLD.Isend') +@infers_descriptor('dace.comm.Isend') +def _infer_isend(input_descs, buffer, dst, tag, request=None, grid=None, **_kw): + if request is None: + return _request_descriptor() + return () + + +@infers_descriptor('mpi4py.MPI.COMM_WORLD.Irecv') +@infers_descriptor('dace.comm.Irecv') +def _infer_irecv(input_descs, buffer, src, tag, request=None, grid=None, **_kw): + if request is None: + return _request_descriptor() + return () + + +for _cls, _method in (('Intracomm', 'Isend'), ('ProcessGrid', 'Isend'), ('Intracomm', 'Irecv'), ('ProcessGrid', + 'Irecv')): + infers_method_descriptor(_cls, _method)(lambda self_desc, *args, **_kw: _request_descriptor()) + + +def _comm_bool_result(*_args, **_kwargs): + return data.Scalar(dtypes.bool_, transient=True) + + +for _left_cls, _right_cls in itertools.product(['Comm', 'Cartcomm', 'Intracomm'], repeat=2): + for _op in ('Eq', 'NotEq', 'Is', 'IsNot'): + infers_operator_descriptor(_op, _left_cls, _right_cls)(_comm_bool_result) + +for _cls_a, _cls_b, _op in itertools.product(['ProcessGrid'], ['Comm', 'Cartcomm', 'Intracomm'], + ['Eq', 'NotEq', 'Is', 'IsNot']): + infers_operator_descriptor(_op, _cls_a, _cls_b)(_comm_bool_result) + infers_operator_descriptor(_op, _cls_b, _cls_a)(_comm_bool_result) + +for _name in ('dace.comm.Subarray', 'dace.comm.BlockScatter', 
'dace.comm.BlockGather', 'dace.comm.Redistribute'): + infers_descriptor(_name)(lambda input_descs, *args, **_kw: _pyobject_scalar_descriptor()) + + +@infers_descriptor('dace.comm.BCScatter') +def _infer_bcscatter(input_descs, in_buffer, out_buffer, block_sizes, **_kw): + return (_int_vector_descriptor(9), _int_vector_descriptor(9)) + + +@infers_descriptor('dace.distr.MatMult') +@infers_descriptor('distr.MatMult') +def _infer_distr_matmult(input_descs, + opa, + opb, + shape, + a_block_sizes=None, + b_block_sizes=None, + c_block_sizes=None, + **_kw): + desc_a = _get_desc(input_descs, opa) + desc_b = _get_desc(input_descs, opb) + if not isinstance(desc_a, data.Data) or not isinstance(desc_b, data.Data): + return None + + if len(desc_a.shape) == 2 and len(desc_b.shape) == 2: + return data.Array(desc_a.dtype, [desc_a.shape[0], desc_b.shape[-1]], transient=True) + if len(desc_a.shape) == 2 and len(desc_b.shape) == 1: + if isinstance(c_block_sizes, (tuple, list)) and c_block_sizes: + out_dim = c_block_sizes[0] + else: + out_dim = desc_a.shape[0] + return data.Array(desc_a.dtype, [out_dim], transient=True) + if len(desc_a.shape) == 1 and len(desc_b.shape) == 2: + if isinstance(c_block_sizes, (tuple, list)) and c_block_sizes: + out_dim = c_block_sizes[0] + else: + out_dim = desc_b.shape[1] + return data.Array(desc_b.dtype, [out_dim], transient=True) + return None diff --git a/dace/frontend/python/replacements/operators.py b/dace/frontend/python/replacements/operators.py index 9fce562074..9d07e595c6 100644 --- a/dace/frontend/python/replacements/operators.py +++ b/dace/frontend/python/replacements/operators.py @@ -4,22 +4,48 @@ """ from dace.frontend.common import op_repository as oprepo from dace.frontend.python import astutils -from dace.frontend.python.common import StringLiteral +from dace.frontend.python.common import ListLiteral, StringLiteral, TupleLiteral from dace.frontend.python.replacements.utils import (ProgramVisitor, broadcast_together, cast_str, 
np_result_type, representative_num, sym_type) from dace import data, dtypes, subsets, symbolic, Memlet, SDFG, SDFGState from numbers import Number -from typing import List, Sequence, Tuple, Union +from typing import Any, List, Sequence, Tuple, Union import warnings import numpy as np import sympy as sp import dace # noqa: F401 (used during evaluation of data types, e.g. casting in replaced op) +from dace.frontend.common.op_repository import infers_operator_descriptor + numpy_version = np.lib.NumpyVersion(np.__version__) +def _materialize_sequence_literal(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, + literal: Union[ListLiteral, TupleLiteral]) -> str: + from dace.frontend.python.replacements.array_creation_dace import (infer_dynamic_literal_descriptor, + populate_dynamic_literal_array) + + value = literal.value + desc = infer_dynamic_literal_descriptor(value, sdfg) + if desc is None: + raise SyntaxError('Operand cannot be materialized as an array literal') + + name = sdfg.temp_data_name() + name = sdfg.add_datadesc(name, desc, find_new_name=True) + init_states = getattr(visitor, '_literal_init_states', None) + if init_states is None: + init_states = {} + setattr(visitor, '_literal_init_states', init_states) + init_state = init_states.get(state) + if init_state is None: + init_state = visitor.cfg_target.add_state_before(state) + init_states[state] = init_state + populate_dynamic_literal_array(init_state, sdfg, name, value) + return name + + def _unop(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, opcode: str, opname: str): """ Implements a general element-wise array unary operator. 
""" arr1 = sdfg.arrays[op1] @@ -775,6 +801,16 @@ def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op2: st def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str): return _array_sym_binop(visitor, sdfg, state, op1, op2, op, opcode) + @oprepo.replaces_operator('Array', op, otherclass='ListLiteral') + def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op2: ListLiteral): + op2_arr = _materialize_sequence_literal(visitor, sdfg, state, op2) + return _array_array_binop(visitor, sdfg, state, op1, op2_arr, op, opcode) + + @oprepo.replaces_operator('Array', op, otherclass='TupleLiteral') + def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op2: TupleLiteral): + op2_arr = _materialize_sequence_literal(visitor, sdfg, state, op2) + return _array_array_binop(visitor, sdfg, state, op1, op2_arr, op, opcode) + @oprepo.replaces_operator('View', op, otherclass='View') def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str): return _array_array_binop(visitor, sdfg, state, op1, op2, op, opcode) @@ -799,6 +835,16 @@ def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op2: st def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str): return _array_sym_binop(visitor, sdfg, state, op1, op2, op, opcode) + @oprepo.replaces_operator('View', op, otherclass='ListLiteral') + def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op2: ListLiteral): + op2_arr = _materialize_sequence_literal(visitor, sdfg, state, op2) + return _array_array_binop(visitor, sdfg, state, op1, op2_arr, op, opcode) + + @oprepo.replaces_operator('View', op, otherclass='TupleLiteral') + def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op2: TupleLiteral): + op2_arr = _materialize_sequence_literal(visitor, sdfg, state, op2) + return _array_array_binop(visitor, sdfg, state, op1, op2_arr, op, opcode) + 
@oprepo.replaces_operator('Scalar', op, otherclass='Array') def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str): return _array_array_binop(visitor, sdfg, state, op1, op2, op, opcode) @@ -847,6 +893,26 @@ def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op2: st def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str): return _const_const_binop(visitor, sdfg, state, op1, op2, op, opcode) + @oprepo.replaces_operator('ListLiteral', op, otherclass='Array') + def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: ListLiteral, op2: str): + op1_arr = _materialize_sequence_literal(visitor, sdfg, state, op1) + return _array_array_binop(visitor, sdfg, state, op1_arr, op2, op, opcode) + + @oprepo.replaces_operator('ListLiteral', op, otherclass='View') + def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: ListLiteral, op2: str): + op1_arr = _materialize_sequence_literal(visitor, sdfg, state, op1) + return _array_array_binop(visitor, sdfg, state, op1_arr, op2, op, opcode) + + @oprepo.replaces_operator('TupleLiteral', op, otherclass='Array') + def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: TupleLiteral, op2: str): + op1_arr = _materialize_sequence_literal(visitor, sdfg, state, op1) + return _array_array_binop(visitor, sdfg, state, op1_arr, op2, op, opcode) + + @oprepo.replaces_operator('TupleLiteral', op, otherclass='View') + def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: TupleLiteral, op2: str): + op1_arr = _materialize_sequence_literal(visitor, sdfg, state, op1) + return _array_array_binop(visitor, sdfg, state, op1_arr, op2, op, opcode) + @oprepo.replaces_operator('BoolConstant', op, otherclass='Array') def _op(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str): return _array_const_binop(visitor, sdfg, state, op1, op2, op, opcode) @@ -926,3 +992,69 @@ def _op(visitor: 'ProgramVisitor', sdfg: SDFG, state: 
SDFGState, op1: StringLite } for op, method in _boolop_to_method.items(): _makeboolop(op, method) + + +def _operand_shape(operand) -> List[Any]: + if isinstance(operand, data.Scalar): + return [] + if isinstance(operand, data.Data): + return list(operand.shape) + return [] + + +def _elementwise_binary_descriptor(left_desc, right_desc, operator: str) -> Union[data.Data, None]: + try: + result_dtype, _casting = result_type([left_desc, right_desc], operator) + except Exception: + return None + + if not isinstance(result_dtype, dtypes.typeclass): + return None + + left_shape = _operand_shape(left_desc) + right_shape = _operand_shape(right_desc) + + try: + out_shape, _ranges, _out_idx, _left_idx, _right_idx = broadcast_together(left_shape, right_shape) + except Exception: + return None + + if len(out_shape) == 0: + return data.Scalar(result_dtype, transient=True) + return data.Array(result_dtype, list(out_shape), transient=True) + + +def _elementwise_unary_descriptor(operand_desc, operator: str) -> Union[data.Data, None]: + try: + result_dtype, _casting = result_type([operand_desc], operator) + except Exception: + return None + + if not isinstance(result_dtype, dtypes.typeclass): + return None + + if not isinstance(operand_desc, data.Data) or isinstance(operand_desc, data.Scalar): + return data.Scalar(result_dtype, transient=True) + return data.Array(result_dtype, list(operand_desc.shape), transient=True) + + +def _register_unary_operator_descriptor(opname: str) -> None: + + @infers_operator_descriptor(opname) + def _infer(operand_desc): + return _elementwise_unary_descriptor(operand_desc, opname) + + +def _register_binary_operator_descriptor(opname: str) -> None: + + @infers_operator_descriptor(opname) + def _infer(left_desc, right_desc): + return _elementwise_binary_descriptor(left_desc, right_desc, opname) + + +for _descriptor_op in ('Add', 'Sub', 'Mult', 'Div', 'FloorDiv', 'Mod', 'Pow', 'LShift', 'RShift', 'BitOr', 'BitXor', + 'BitAnd', 'And', 'Or', 'Eq', 'NotEq', 
'Lt', 'LtE', 'Gt', 'GtE', 'Is', 'IsNot'): + _register_binary_operator_descriptor(_descriptor_op) + +for _descriptor_unary_op in ('UAdd', 'USub', 'Not', 'Invert'): + _register_unary_operator_descriptor(_descriptor_unary_op) diff --git a/dace/frontend/python/replacements/pymath.py b/dace/frontend/python/replacements/pymath.py index ad39909933..1169586cb7 100644 --- a/dace/frontend/python/replacements/pymath.py +++ b/dace/frontend/python/replacements/pymath.py @@ -4,7 +4,7 @@ """ from dace.frontend.common import op_repository as oprepo from dace.frontend.python.replacements.utils import ProgramVisitor, complex_to_scalar, simple_call -from dace import dtypes, symbolic, SDFG, SDFGState +from dace import data, dtypes, symbolic, SDFG, SDFGState from numbers import Number from typing import Union @@ -120,3 +120,84 @@ def _abs(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, input: Union[str, Num @oprepo.replaces('round') def _round(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, input: Union[str, Number, symbolic.symbol]): return simple_call(pv, sdfg, state, input, 'round', dtypes.typeclass(int)) + + +# -------------------------------------------------------------------- # +# Descriptor inference for math attributes (schedule-tree frontend) # +# -------------------------------------------------------------------- # + +from dace.frontend.common.op_repository import infers_attribute_descriptor, infers_descriptor, infers_method_descriptor +from dace.frontend.python.replacements.utils import complex_to_scalar as _complex_to_scalar +from dace.frontend.python.replacements.type_inference import _get_desc + + +def _clone_shape_preserving_descriptor(desc: data.Data, dtype=None): + out_dtype = desc.dtype if dtype is None else dtype + retval = desc.clone() + retval.dtype = out_dtype + retval.transient = True + return retval + + +def _infer_shape_preserving_math_descriptor(input_descs, input, dtype=None, **_kw): + desc = _get_desc(input_descs, input) + if desc is None: + return 
None + return _clone_shape_preserving_descriptor(desc, dtype=dtype) + + +def _infer_attr_real(self_desc): + out_dtype = _complex_to_scalar(self_desc.dtype) + if isinstance(self_desc, data.Scalar): + return data.Scalar(out_dtype) + return data.Array(out_dtype, list(self_desc.shape), transient=True) + + +def _infer_attr_imag(self_desc): + out_dtype = _complex_to_scalar(self_desc.dtype) + if isinstance(self_desc, data.Scalar): + return data.Scalar(out_dtype) + return data.Array(out_dtype, list(self_desc.shape), transient=True) + + +for _name in ('exp', 'dace.exp', 'numpy.exp', 'math.exp', 'sin', 'dace.sin', 'numpy.sin', 'math.sin', 'cos', 'dace.cos', + 'numpy.cos', 'math.cos', 'sqrt', 'dace.sqrt', 'numpy.sqrt', 'math.sqrt', 'log', 'dace.log', 'numpy.log', + 'math.log', 'log10', 'dace.log10', 'math.log10', 'abs'): + infers_descriptor(_name)(_infer_shape_preserving_math_descriptor) + +for _name in ('math.floor', 'math.ceil', 'round'): + infers_descriptor(_name)(lambda input_descs, input, **_kw: _infer_shape_preserving_math_descriptor( + input_descs, input, dtype=dtypes.typeclass(int))) + + +@infers_descriptor('conj') +@infers_descriptor('dace.conj') +@infers_descriptor('numpy.conj') +def _infer_conj(input_descs, input, **_kw): + return _infer_shape_preserving_math_descriptor(input_descs, input) + + +@infers_descriptor('real') +@infers_descriptor('dace.real') +@infers_descriptor('numpy.real') +def _infer_real(input_descs, input, **_kw): + desc = _get_desc(input_descs, input) + if desc is None: + return None + return _infer_attr_real(desc) + + +@infers_descriptor('imag') +@infers_descriptor('dace.imag') +@infers_descriptor('numpy.imag') +def _infer_imag(input_descs, input, **_kw): + desc = _get_desc(input_descs, input) + if desc is None: + return None + return _infer_attr_imag(desc) + + +for _cls in ('Array', 'Scalar', 'View'): + infers_attribute_descriptor(_cls, 'real')(_infer_attr_real) + infers_attribute_descriptor(_cls, 'imag')(_infer_attr_imag) + 
infers_method_descriptor(_cls, 'conj')(lambda self_desc, **_kw: _clone_shape_preserving_descriptor(self_desc)) diff --git a/dace/frontend/python/replacements/reduction.py b/dace/frontend/python/replacements/reduction.py index cd7938ece4..73e929482e 100644 --- a/dace/frontend/python/replacements/reduction.py +++ b/dace/frontend/python/replacements/reduction.py @@ -7,13 +7,15 @@ from dace.frontend.common import op_repository as oprepo from dace.frontend.python.nested_call import NestedCall from dace.frontend.python.replacements.utils import ProgramVisitor, normalize_axes -from dace import dtypes, nodes, subsets, symbolic, Memlet, SDFG, SDFGState +from dace import data, dtypes, nodes, subsets, symbolic, Memlet, SDFG, SDFGState import copy import functools from numbers import Integral, Number from typing import Any, Dict, Callable, Optional, Union +import numpy as np + @oprepo.replaces('dace.reduce') def reduce(pv: ProgramVisitor, @@ -89,6 +91,13 @@ def reduce(pv: ProgramVisitor, return [] +@oprepo.infers_descriptor('dace.reduce') +def _infer_reduce(input_descs, redfunction, in_array, out_array=None, axis=None, identity=None, **_kw): + if out_array is not None: + return () + return _reduction_descriptor(input_descs, in_array, axis) + + @oprepo.replaces('numpy.sum') def _sum(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None): return reduce(pv, sdfg, state, "lambda x, y: x + y", a, axis=axis, identity=0) @@ -391,3 +400,93 @@ def _ndarray_argmin(pv: ProgramVisitor, state.add_nedge(r, w, Memlet.from_array(newarr, sdfg.arrays[newarr])) newarr = out return newarr + + +# -------------------------------------------------------------------- # +# Descriptor inference for reductions (schedule-tree frontend) # +# -------------------------------------------------------------------- # + +from dace.frontend.common.op_repository import infers_descriptor, infers_method_descriptor +from dace.frontend.python.replacements.type_inference import (_reduction_descriptor, 
_method_reduction_descriptor, + _get_desc) + + +def _infer_basic_reduction(input_descs, arr, axis=None, **_kw): + return _reduction_descriptor(input_descs, arr, axis) + + +for _name in ('numpy.sum', 'numpy.prod', 'numpy.max', 'numpy.amax', 'numpy.min', 'numpy.amin', 'numpy.any', + 'numpy.all'): + infers_descriptor(_name)(_infer_basic_reduction) + + +@infers_descriptor('sum') +def _infer_builtin_sum(input_descs, arr, **_kw): + return _reduction_descriptor(input_descs, arr, axis=0) + + +def _infer_builtin_minmax(input_descs, first_arg, *args, **_kw): + from dace.frontend.python.replacements.operators import result_type + + operands = [] + for arg in (first_arg, ) + args: + desc = _get_desc(input_descs, arg) + if isinstance(desc, data.Data) and not isinstance(desc, data.Scalar): + return None + if desc is not None: + operands.append(desc) + continue + if isinstance(arg, (Number, symbolic.symbol)): + operands.append(arg) + continue + return None + + try: + out_dtype, _casts = result_type(operands) + except Exception: + return None + + if not isinstance(out_dtype, dtypes.typeclass): + return None + return data.Scalar(out_dtype, transient=True) + + +for _name in ('max', 'min'): + infers_descriptor(_name)(_infer_builtin_minmax) + + +@infers_descriptor('numpy.mean') +def _infer_mean(input_descs, arr, axis=None, **_kw): + desc = _get_desc(input_descs, arr) + if desc is None: + return None + out_dtype = desc.dtype + if out_dtype.type in (int, np.int32, np.int64, np.int16, np.int8, np.uint8, np.uint16, np.uint32, np.uint64, bool, + np.bool_): + out_dtype = dtypes.float64 + return _reduction_descriptor(input_descs, arr, axis, dtype_override=out_dtype) + + +@infers_descriptor('numpy.argmax') +@infers_descriptor('numpy.argmin') +def _infer_argminmax(input_descs, arr, axis=None, **_kw): + return _reduction_descriptor(input_descs, arr, axis, dtype_override=dtypes.int64) + + +# Method inference for .max(), .min(), .argmax(), .argmin() +def _infer_method_basic_reduction(self_desc, 
axis=None, **_kw): + return _method_reduction_descriptor(self_desc, axis) + + +for _cls in ('Array', 'View', 'Scalar'): + for _method in ('max', 'min'): + infers_method_descriptor(_cls, _method)(_infer_method_basic_reduction) + + +def _infer_method_argminmax(self_desc, axis=None, **_kw): + return _method_reduction_descriptor(self_desc, axis, dtype_override=dtypes.int64) + + +for _cls in ('Array', 'View', 'Scalar'): + for _method in ('argmax', 'argmin'): + infers_method_descriptor(_cls, _method)(_infer_method_argminmax) diff --git a/dace/frontend/python/replacements/torch_autodiff.py b/dace/frontend/python/replacements/torch_autodiff.py index 2433ae567b..f09c172d8e 100644 --- a/dace/frontend/python/replacements/torch_autodiff.py +++ b/dace/frontend/python/replacements/torch_autodiff.py @@ -3,6 +3,7 @@ Integration with the dace python frontend """ +import copy from typing import Optional, Union, Sequence import itertools @@ -112,6 +113,11 @@ def backward(pv: newast.ProgramVisitor, state.add_edge(bwd_node, conn_name, write_an, None, sdfg.make_array_memlet(grad_name)) +@op_repository.infers_descriptor('torch.autograd.backward') +def _infer_backward(input_descs, tensors: TensorOrTensors, grads: Optional[TensorOrTensors] = None, **_kw): + return () + + @op_repository.replaces_attribute('ParameterArray', 'grad') def grad(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str) -> str: """ @@ -132,6 +138,14 @@ def grad(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, arr: str) -> str: return desc.gradient +@op_repository.infers_attribute_descriptor('ParameterArray', 'grad') +def _infer_grad(self_desc: ParameterArray): + result = copy.deepcopy(self_desc) + result.__class__ = data.Array + result.transient = True + return result + + @op_repository.replaces_method('Array', 'requires_grad_') @op_repository.replaces_method('Scalar', 'requires_grad_') def requires_grad_(pv: newast.ProgramVisitor, sdfg: SDFG, state: SDFGState, self: str): @@ -145,6 +159,21 @@ def 
requires_grad_(pv: newast.ProgramVisitor, sdfg: SDFG, state: SDFGState, self ParameterArray.make_parameter(sdfg, self) +@op_repository.infers_method_descriptor('Array', 'requires_grad_') +@op_repository.infers_method_descriptor('Scalar', 'requires_grad_') +def _infer_requires_grad(input_desc, **_kw): + return () + + +@op_repository.infers_method_self_descriptor('Array', 'requires_grad_') +def _infer_requires_grad_self(self_desc, **_kw): + result = copy.deepcopy(self_desc) + result.__class__ = ParameterArray + result.gradient = None + result.transient = True + return result + + @op_repository.replaces_method('Array', 'backward') @op_repository.replaces_method('Scalar', 'backward') def backward_method(pv: newast.ProgramVisitor, sdfg: SDFG, state: SDFGState, self: str, grad: Optional[str] = None): @@ -154,4 +183,10 @@ def backward_method(pv: newast.ProgramVisitor, sdfg: SDFG, state: SDFGState, sel backward(pv, sdfg, state, self, grad) +@op_repository.infers_method_descriptor('Array', 'backward') +@op_repository.infers_method_descriptor('Scalar', 'backward') +def _infer_backward_method(self_desc, grad: Optional[str] = None, **_kw): + return () + + dace.hooks.register_sdfg_call_hook(before_hook=lambda sdfg: expand_nodes(sdfg, lambda n: isinstance(n, BackwardPass))) diff --git a/dace/frontend/python/replacements/type_inference.py b/dace/frontend/python/replacements/type_inference.py new file mode 100644 index 0000000000..c7d9ee22d1 --- /dev/null +++ b/dace/frontend/python/replacements/type_inference.py @@ -0,0 +1,77 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +""" +Shared helpers for lightweight descriptor-inference functions used by the +schedule-tree frontend. The actual ``@infers_descriptor`` (and related) +registrations live next to their SDFG-level replacements in the respective +``replacements/*.py`` modules. 
+""" + +from numbers import Number +from typing import Dict, Optional + +import numpy as np + +from dace import data, dtypes +from dace.frontend.python.replacements.utils import normalize_axes + +# -------------------------------------------------------------------- # +# Helpers # +# -------------------------------------------------------------------- # + + +def _get_desc(input_descs: Dict[str, data.Data], arg) -> Optional[data.Data]: + """Resolve *arg* to a descriptor if it names an input array.""" + if isinstance(arg, str) and arg in input_descs: + return input_descs[arg] + if isinstance(arg, data.Data): + return arg + return None + + +def _to_int(v) -> Optional[int]: + """Try to convert *v* to a plain Python int (for axis, shape elements).""" + if isinstance(v, (int, np.integer)): + return int(v) + if isinstance(v, Number): + iv = int(v) + if iv == v: + return iv + return None + + +def _reduction_descriptor(input_descs: Dict[str, data.Data], + arr, + axis=None, + dtype_override: Optional[dtypes.typeclass] = None) -> Optional[data.Data]: + """Shared logic for reduction-style operations (sum, max, prod, ...).""" + desc = _get_desc(input_descs, arr) + if desc is None: + return None + + out_dtype = dtype_override or desc.dtype + shape = list(desc.shape) + + if axis is None: + return data.Scalar(out_dtype) + + if not isinstance(axis, (tuple, list)): + axis = (axis, ) + axis = tuple(_to_int(a) for a in axis) + if any(a is None for a in axis): + return None + axis = tuple(normalize_axes(axis, len(shape))) + + if len(axis) == len(shape): + return data.Scalar(out_dtype) + + out_shape = [s for i, s in enumerate(shape) if i not in axis] + if not out_shape: + return data.Scalar(out_dtype) + return data.Array(out_dtype, out_shape, transient=True) + + +def _method_reduction_descriptor(self_desc: data.Data, + axis=None, + dtype_override: Optional[dtypes.typeclass] = None) -> Optional[data.Data]: + """Shared logic for method-style reductions (a.sum(), a.max(), ...).""" + return 
_reduction_descriptor({}, self_desc, axis, dtype_override) diff --git a/dace/frontend/python/replacements/ufunc.py b/dace/frontend/python/replacements/ufunc.py index f924007571..9e4fa1d803 100644 --- a/dace/frontend/python/replacements/ufunc.py +++ b/dace/frontend/python/replacements/ufunc.py @@ -6,8 +6,8 @@ from dace.frontend.common import op_repository as oprepo from dace.frontend.python import astutils from dace.frontend.python.nested_call import NestedCall -from dace.frontend.python.replacements.utils import (ProgramVisitor, Shape, UfuncInput, UfuncOutput, normalize_axes, - sym_type) +from dace.frontend.python.replacements.utils import (ProgramVisitor, Shape, UfuncInput, UfuncOutput, broadcast_together, + normalize_axes, representative_num, sym_type) import dace.frontend.python.memlet_parser as mem_parser from dace import InterstateEdge, Memlet, SDFG, SDFGState from dace import dtypes, data, symbolic, nodes @@ -1873,6 +1873,288 @@ def _ndarray_any(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, arr: str, kwa return implement_ufunc_reduce(pv, None, sdfg, state, 'logical_or', [arr], kwargs)[0] +# -------------------------------------------------------------------- # +# Descriptor inference for method reductions (schedule-tree frontend) # +# -------------------------------------------------------------------- # + +from dace.frontend.common.op_repository import infers_descriptor, infers_method_descriptor, infers_ufunc_descriptor +from dace.frontend.python.replacements.type_inference import _method_reduction_descriptor + + +def _clone_inferred_output(output): + if isinstance(output, data.Data): + result = copy.deepcopy(output) + result.transient = True + return result + if isinstance(output, tuple): + return tuple(_clone_inferred_output(element) for element in output) + if isinstance(output, list): + return [_clone_inferred_output(element) for element in output] + return None + + +def _resolve_inference_output(input_descs: Dict[str, data.Data], output): + if 
isinstance(output, str) and output in input_descs: + return _clone_inferred_output(input_descs[output]) + return _clone_inferred_output(output) + + +def _resolve_inference_operand(input_descs: Dict[str, data.Data], arg): + if isinstance(arg, str) and arg in input_descs: + return input_descs[arg] + if isinstance(arg, (list, tuple, np.ndarray)) or dtypes.is_array(arg) or hasattr(arg, '__array__'): + try: + descriptor = data.create_datadescriptor(arg) + except Exception: + descriptor = None + if descriptor is not None: + return descriptor + return arg + + +def _descriptor_from_dtype_and_shape(dtype: dtypes.typeclass, shape: Sequence[Any]) -> data.Data: + if len(shape) == 0: + return data.Scalar(dtype, transient=True) + return data.Array(dtype, list(shape), transient=True) + + +def _descriptor_from_sample(sample, shape: Sequence[Any], dtype_override: Optional[dtypes.typeclass] = None): + out_dtype = dtype_override + if out_dtype is None: + try: + out_dtype = dtypes.dtype_to_typeclass(np.asarray(sample).dtype.type) + except Exception: + return None + return _descriptor_from_dtype_and_shape(out_dtype, shape) + + +def _resolve_dtype_override(dtype) -> Optional[dtypes.typeclass]: + if dtype is None: + return None + if isinstance(dtype, dtypes.typeclass): + return dtype + try: + return dtypes.dtype_to_typeclass(np.dtype(dtype).type) + except Exception: + return None + + +def _sample_operand_value(operand): + if isinstance(operand, data.Data): + return representative_num(operand.dtype) + if isinstance(operand, dtypes.typeclass): + return representative_num(operand) + if symbolic.issymbolic(operand): + return representative_num(sym_type(operand)) + if isinstance(operand, np.generic): + return operand.item() + if isinstance(operand, (Number, bool)): + return operand + return None + + +def _broadcast_shape_from_operands(operands: Sequence[Any]) -> Optional[List[Any]]: + array_shapes = [ + list(operand.shape) for operand in operands + if isinstance(operand, data.Data) and not 
isinstance(operand, data.Scalar) + ] + if not array_shapes: + return [] + + out_shape = array_shapes[0] + for shape in array_shapes[1:]: + try: + out_shape, _, _, _, _ = broadcast_together(out_shape, shape) + except Exception: + return None + return list(out_shape) + + +def _resolve_explicit_output(input_descs: Dict[str, data.Data], outputs): + if not outputs: + return None + if len(outputs) == 1: + return _resolve_inference_output(input_descs, outputs[0]) + return tuple(_resolve_inference_output(input_descs, output) for output in outputs) + + +def _infer_reduce_shape(desc: data.Data, axis, keepdims: bool) -> Optional[List[Any]]: + shape = list(getattr(desc, 'shape', [])) + if axis is None: + return [1] * len(shape) if keepdims and shape else [] + + if not isinstance(axis, (tuple, list)): + axis = (axis, ) + try: + axis = tuple(normalize_axes(tuple(int(a) for a in axis), len(shape))) + except Exception: + return None + + if keepdims: + return [1 if i in axis else dim for i, dim in enumerate(shape)] + return [dim for i, dim in enumerate(shape) if i not in axis] + + +def _infer_reduce_dtype(ufunc_name: str, desc: data.Data, method_name: str) -> Optional[dtypes.typeclass]: + try: + sample_array = np.array([representative_num(desc.dtype)], dtype=desc.dtype.as_numpy_dtype()) + ufunc = getattr(np, ufunc_name) + if method_name == 'reduce': + result = ufunc.reduce(sample_array) + elif method_name == 'accumulate': + result = ufunc.accumulate(sample_array) + else: + result = ufunc.outer(sample_array, sample_array) + return dtypes.dtype_to_typeclass(np.asarray(result).dtype.type) + except Exception: + return None + + +@infers_ufunc_descriptor('ufunc') +def _infer_ufunc_descriptor(input_descs: Dict[str, data.Data], ufunc_name: str, *args, **kwargs): + impl = ufuncs.get(ufunc_name) + if impl is None: + return None + + num_inputs = len(impl['inputs']) + explicit_outputs = list(args[num_inputs:]) + kw_out = kwargs.get('out') + if kw_out is not None: + explicit_outputs = 
list(kw_out if isinstance(kw_out, tuple) else (kw_out, )) + explicit_result = _resolve_explicit_output(input_descs, explicit_outputs) + if explicit_result is not None: + return explicit_result + + operands = [_resolve_inference_operand(input_descs, arg) for arg in args[:num_inputs]] + if any(operand is None for operand in operands): + return None + + sample_args = [_sample_operand_value(operand) for operand in operands] + if any(sample is None for sample in sample_args): + return None + + out_shape = _broadcast_shape_from_operands(operands) + if out_shape is None: + return None + + try: + sample_result = getattr(np, ufunc_name)(*sample_args) + except Exception: + return None + + dtype_override = _resolve_dtype_override(kwargs.get('dtype')) + if isinstance(sample_result, tuple): + results = tuple(_descriptor_from_sample(sample, out_shape) for sample in sample_result) + return None if any(result is None for result in results) else results + return _descriptor_from_sample(sample_result, out_shape, dtype_override) + + +@infers_descriptor('numpy.clip') +def _infer_clip(input_descs: Dict[str, data.Data], a, a_min=None, a_max=None, **kwargs): + if a_min is None and a_max is None: + return None + if a_min is None: + return _infer_ufunc_descriptor(input_descs, 'minimum', a, a_max, **kwargs) + if a_max is None: + return _infer_ufunc_descriptor(input_descs, 'maximum', a, a_min, **kwargs) + return _infer_ufunc_descriptor(input_descs, 'clip', a, a_min, a_max, **kwargs) + + +@infers_ufunc_descriptor('reduce') +def _infer_ufunc_reduce_descriptor(input_descs: Dict[str, data.Data], + ufunc_name: str, + arr, + axis=0, + dtype=None, + out=None, + keepdims=False, + **_kwargs): + desc = _resolve_inference_operand(input_descs, arr) + if not isinstance(desc, data.Data): + return None + + explicit_result = _resolve_explicit_output(input_descs, [out] if out is not None else []) + if explicit_result is not None: + return explicit_result + + out_dtype = _resolve_dtype_override(dtype) or 
_infer_reduce_dtype(ufunc_name, desc, 'reduce') or desc.dtype + out_shape = _infer_reduce_shape(desc, axis, keepdims) + if out_shape is None: + return None + return _descriptor_from_dtype_and_shape(out_dtype, out_shape) + + +@infers_ufunc_descriptor('accumulate') +def _infer_ufunc_accumulate_descriptor(input_descs: Dict[str, data.Data], + ufunc_name: str, + arr, + axis=0, + dtype=None, + out=None, + **_kwargs): + desc = _resolve_inference_operand(input_descs, arr) + if not isinstance(desc, data.Data): + return None + + explicit_result = _resolve_explicit_output(input_descs, [out] if out is not None else []) + if explicit_result is not None: + return explicit_result + + del axis + out_dtype = _resolve_dtype_override(dtype) or _infer_reduce_dtype(ufunc_name, desc, 'accumulate') or desc.dtype + shape = list(getattr(desc, 'shape', [])) + return _descriptor_from_dtype_and_shape(out_dtype, shape) + + +@infers_ufunc_descriptor('outer') +def _infer_ufunc_outer_descriptor(input_descs: Dict[str, data.Data], ufunc_name: str, left, right, out=None, **kwargs): + explicit_result = _resolve_explicit_output(input_descs, [out] if out is not None else []) + if explicit_result is not None: + return explicit_result + + left_operand = _resolve_inference_operand(input_descs, left) + right_operand = _resolve_inference_operand(input_descs, right) + left_sample = _sample_operand_value(left_operand) + right_sample = _sample_operand_value(right_operand) + if left_sample is None or right_sample is None: + return None + + try: + sample_result = getattr(np, ufunc_name)(left_sample, right_sample) + except Exception: + return None + + left_shape = [] if not isinstance(left_operand, data.Data) or isinstance(left_operand, data.Scalar) else list( + left_operand.shape) + right_shape = [] if not isinstance(right_operand, data.Data) or isinstance(right_operand, data.Scalar) else list( + right_operand.shape) + out_shape = left_shape + right_shape + dtype_override = 
_resolve_dtype_override(kwargs.get('dtype')) + return _descriptor_from_sample(sample_result, out_shape, dtype_override) + + +def _infer_method_reduction(self_desc, axis=None, **_kw): + return _method_reduction_descriptor(self_desc, axis) + + +for _cls in ('Array', 'View', 'Scalar'): + for _method in ('sum', 'prod', 'all', 'any'): + infers_method_descriptor(_cls, _method)(_infer_method_reduction) + + +def _infer_method_mean(self_desc, axis=None, **_kw): + import numpy as _np + out_dtype = self_desc.dtype + if out_dtype.type in (int, _np.int32, _np.int64, _np.int16, _np.int8, _np.uint8, _np.uint16, _np.uint32, _np.uint64, + bool, _np.bool_): + out_dtype = dtypes.float64 + return _method_reduction_descriptor(self_desc, axis, dtype_override=out_dtype) + + +for _cls in ('Array', 'View', 'Scalar'): + infers_method_descriptor(_cls, 'mean')(_infer_method_mean) + + @oprepo.replaces('numpy.clip') def _clip(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a, a_min=None, a_max=None, **kwargs): if a_min is None and a_max is None: diff --git a/dace/frontend/python/schedule_tree/__init__.py b/dace/frontend/python/schedule_tree/__init__.py new file mode 100644 index 0000000000..b901051553 --- /dev/null +++ b/dace/frontend/python/schedule_tree/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. 
+"""Support modules for the direct Python schedule-tree frontend.""" + +from .attribute_rewriter import AttributeRewriter +from .callback_support import CallbackHandler, CallbackOutliner +from .callable_support import CallableArgumentSpecializer, CallableResolver +from .desugaring import callback_reason, desugar_schedule_tree_expansions +from .dynamic_scope_copy import promote_dynamic_scope_copies +from .expression_support import ExpressionPlanningContext, GenericExpressionSupportLibrary +from .function_inlining import resolve_function_calls +from .numpy_support import NumpyLoweringContext, NumpySupportLibrary +from .tuple_assignment import is_container_initialization, is_tuple_element_assignment +from .type_inference import ScheduleTreeTypeInference, _Binding + +__all__ = [ + 'ScheduleTreeTypeInference', + '_Binding', + 'AttributeRewriter', + 'CallbackHandler', + 'CallbackOutliner', + 'CallableArgumentSpecializer', + 'CallableResolver', + 'desugar_schedule_tree_expansions', + 'callback_reason', + 'promote_dynamic_scope_copies', + 'ExpressionPlanningContext', + 'GenericExpressionSupportLibrary', + 'NumpyLoweringContext', + 'NumpySupportLibrary', + 'resolve_function_calls', + 'is_container_initialization', + 'is_tuple_element_assignment', +] diff --git a/dace/frontend/python/schedule_tree/array_literal_support.py b/dace/frontend/python/schedule_tree/array_literal_support.py new file mode 100644 index 0000000000..5f9c27fc53 --- /dev/null +++ b/dace/frontend/python/schedule_tree/array_literal_support.py @@ -0,0 +1,273 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. 
+"""Array-literal inference and lowering helpers for the direct frontend.""" + +import ast +import copy +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Tuple + +from dace import data, dtypes +from dace.frontend.python import astutils +from dace.frontend.python.replacements.array_creation_dace import infer_array_creation_descriptor +from dace.frontend.python.schedule_tree.static_evaluation import UNRESOLVED, try_resolve_static_value +from dace.memlet import Memlet +from dace.properties import CodeBlock +from dace.sdfg.analysis.schedule_tree import treenodes as tn + +DescriptorInferer = Callable[[ast.AST], Optional[data.Data]] +ScalarDescriptorInferer = Callable[[ast.AST, Optional[data.Data]], Optional[data.Data]] +EvaluationContextFactory = Callable[[], Dict[str, Any]] +OutputTargetResolver = Callable[[ast.AST, ast.AST, Optional[data.Data]], Optional[Tuple[str, Memlet, data.Data]]] +DataAccessResolver = Callable[[ast.AST], Optional[Tuple[str, Memlet, data.Data, Optional[data.Data]]]] +CallableNameResolver = Callable[[ast.AST], str] +TaskletNameFactory = Callable[[ast.AST], str] +ArrayConstructorNameFactory = Callable[[], str] + + +@dataclass(frozen=True) +class ArrayLiteralContext: + infer_descriptor: DescriptorInferer + infer_scalar_descriptor: ScalarDescriptorInferer + evaluation_context: EvaluationContextFactory + resolve_output_target: OutputTargetResolver + resolve_data_access: DataAccessResolver + resolve_callable_name: CallableNameResolver + tasklet_name: TaskletNameFactory + array_constructor_name: ArrayConstructorNameFactory + + +class ArrayLiteralSupportLibrary: + """Descriptor inference and lowering for array-valued literals.""" + + def infer_expression_descriptor(self, context: ArrayLiteralContext, node: ast.AST) -> Optional[data.Data]: + return infer_array_literal_descriptor(node, + context.infer_descriptor, + context.infer_scalar_descriptor, + context.evaluation_context, + 
def infer_array_literal_descriptor(
        node: ast.AST,
        infer_descriptor: DescriptorInferer,
        infer_scalar_descriptor: ScalarDescriptorInferer,
        evaluation_context: EvaluationContextFactory,
        *,
        callable_name_resolver: Optional[CallableNameResolver] = None) -> Optional[data.Data]:
    """Infer the descriptor of an array-valued literal expression.

    Two shapes of expression are recognised: a bare list/tuple literal, and a
    call to ``numpy.array`` / ``numpy.asarray``.

    :param node: Expression to inspect.
    :param infer_descriptor: Hook returning the descriptor of a sub-expression.
    :param infer_scalar_descriptor: Fallback hook for scalar sub-expressions.
    :param evaluation_context: Factory for the static-evaluation namespace.
    :param callable_name_resolver: Optional hook that names the called function;
                                   when omitted the name is derived from the AST.
    :return: The inferred descriptor, or ``None`` if ``node`` is not an array
             literal or cannot be analysed statically.
    """
    if isinstance(node, (ast.List, ast.Tuple)):
        return _infer_sequence_descriptor(node, infer_descriptor, infer_scalar_descriptor, evaluation_context)

    if not (isinstance(node, ast.Call)
            and _is_array_constructor_call(node, evaluation_context, callable_name_resolver)):
        return None

    dtype = _parse_dtype_argument(node, evaluation_context)
    ndmin = _parse_ndmin_argument(node, evaluation_context)
    if ndmin is None:
        # ``ndmin`` was given but is not a statically known integer.
        return None

    # NumPy names the first parameter ``object`` (numpy.array) or ``a``
    # (numpy.asarray). ``obj`` is kept only for backward compatibility with the
    # previous lookup, which never matched a valid keyword call.
    obj = None
    for keyword in ('object', 'a', 'obj'):
        obj = _call_argument(node, 0, keyword)
        if obj is not None:
            break
    if obj is None:
        return None

    return _infer_array_call_object_descriptor(obj,
                                               infer_descriptor,
                                               infer_scalar_descriptor,
                                               evaluation_context,
                                               dtype=dtype,
                                               ndmin=ndmin)
def _infer_sequence_descriptor(node: ast.AST, infer_descriptor: DescriptorInferer,
                               infer_scalar_descriptor: ScalarDescriptorInferer,
                               evaluation_context: EvaluationContextFactory) -> Optional[data.Data]:
    """Infer a transient descriptor for a list/tuple literal.

    A fully static literal is described from its evaluated value; otherwise the
    shape and element type are derived structurally from the AST.
    """
    resolved = try_resolve_static_value(node, evaluation_context())
    static_descriptor = None if resolved is UNRESOLVED else infer_array_creation_descriptor(resolved)
    if static_descriptor is not None:
        static_descriptor.transient = True
        return static_descriptor

    # Static evaluation was not possible (or produced nothing usable).
    dims, element_type = _infer_sequence_shape_dtype(node, infer_descriptor, infer_scalar_descriptor)
    if dims is not None and element_type is not None:
        return data.Array(element_type, list(dims), transient=True)
    return None
(tuple(), descriptor.dtype) + + child_shapes: list[Tuple[int, ...]] = [] + child_dtype: Optional[dtypes.typeclass] = None + for element in node.elts: + element_shape, element_dtype = _infer_sequence_shape_dtype(element, infer_descriptor, infer_scalar_descriptor) + if element_shape is None or element_dtype is None: + return (None, None) + child_shapes.append(element_shape) + child_dtype = element_dtype if child_dtype is None else dtypes.result_type_of(child_dtype, element_dtype) + + if not child_shapes: + return ((0, ), dtypes.float64) + + first_shape = child_shapes[0] + if any(shape != first_shape for shape in child_shapes[1:]): + return (None, None) + + return ((len(node.elts), ) + first_shape, child_dtype) + + +def _apply_array_coercions(descriptor: data.Data, *, dtype: Optional[dtypes.typeclass], + ndmin: int) -> Optional[data.Data]: + result = copy.deepcopy(descriptor) + if dtype is not None: + result.dtype = dtype + + shape = list(getattr(result, 'shape', ())) + if isinstance(result, data.Scalar): + if ndmin <= 0: + return data.Scalar(result.dtype, transient=True) + shape = [1] * ndmin + elif len(shape) < ndmin: + shape = [1] * (ndmin - len(shape)) + shape + + if isinstance(result, data.Scalar): + return data.Array(result.dtype, shape, transient=True) + + if hasattr(result, 'set_shape'): + result.set_shape(shape) + result.transient = True + return result + + +def _call_argument(node: ast.Call, position: int, keyword: str) -> Optional[ast.AST]: + if len(node.args) > position: + return node.args[position] + for kw in node.keywords: + if kw.arg == keyword: + return kw.value + return None + + +def _parse_dtype_argument(node: ast.Call, evaluation_context: EvaluationContextFactory) -> Optional[dtypes.typeclass]: + dtype_node = _call_argument(node, 1, 'dtype') + if dtype_node is None: + return None + dtype_value = try_resolve_static_value(dtype_node, evaluation_context()) + if dtype_value is UNRESOLVED: + return None + try: + return dtype_value if 
isinstance(dtype_value, dtypes.typeclass) else dtypes.typeclass(dtype_value) + except TypeError: + return None + + +def _parse_ndmin_argument(node: ast.Call, evaluation_context: EvaluationContextFactory) -> Optional[int]: + ndmin_node = _call_argument(node, 4, 'ndmin') + if ndmin_node is None: + return 0 + ndmin_value = try_resolve_static_value(ndmin_node, evaluation_context()) + if not isinstance(ndmin_value, int): + return None + return ndmin_value + + +def _is_array_constructor_call(node: ast.Call, evaluation_context: EvaluationContextFactory, + callable_name_resolver: Optional[CallableNameResolver]) -> bool: + if callable_name_resolver is not None: + call_name = callable_name_resolver(node.func) + else: + call_name = astutils.rname(node.func) + + if call_name in {'numpy.array', 'numpy.asarray'}: + return True + + resolved = try_resolve_static_value(node.func, evaluation_context()) + module_name = getattr(resolved, '__module__', None) if resolved is not UNRESOLVED else None + callable_name = getattr(resolved, '__name__', None) if resolved is not UNRESOLVED else None + return callable_name in {'array', 'asarray'} and module_name == 'numpy' + + +def _rewrite_with_connectors(node: ast.AST, + resolve_data_access: DataAccessResolver) -> Tuple[ast.AST, Dict[str, Memlet]]: + rewritten = astutils.copy_tree(node) + input_memlets: Dict[str, Memlet] = {} + connector_names: Dict[Tuple[str, str, str], str] = {} + + class _AccessRewriter(ast.NodeTransformer): + + def visit(self, current: ast.AST) -> ast.AST: + access = resolve_data_access(current) + if access is not None: + name, memlet, _descriptor, _view_descriptor = access + key = (name, str(memlet.subset), str(memlet.other_subset) if memlet.other_subset is not None else '') + connector = connector_names.get(key) + if connector is None: + connector = f'in{len(connector_names)}' + connector_names[key] = connector + input_memlets[connector] = copy.deepcopy(memlet) + return ast.copy_location(ast.Name(id=connector, 
ctx=ast.Load()), current) + return super().visit(current) + + return (_AccessRewriter().visit(rewritten), input_memlets) + + +def _lowered_array_expression(node: ast.AST, original: ast.AST, context: ArrayLiteralContext) -> ast.AST: + if isinstance(original, ast.Call): + return node + + constructor = _dotted_name_ast(context.array_constructor_name()) + return ast.Call(func=constructor, args=[node], keywords=[]) + + +def _dotted_name_ast(name: str) -> ast.AST: + parts = name.split('.') + expr: ast.AST = ast.Name(id=parts[0], ctx=ast.Load()) + for part in parts[1:]: + expr = ast.Attribute(value=expr, attr=part, ctx=ast.Load()) + return expr diff --git a/dace/frontend/python/schedule_tree/attribute_rewriter.py b/dace/frontend/python/schedule_tree/attribute_rewriter.py new file mode 100644 index 0000000000..928675878f --- /dev/null +++ b/dace/frontend/python/schedule_tree/attribute_rewriter.py @@ -0,0 +1,163 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +"""Helpers for rewriting attribute access on user-defined Python objects. + +The direct schedule-tree frontend keeps ordinary attribute syntax such as +``obj.value`` unchanged for plain objects, but it makes descriptor behavior and +custom attribute hooks explicit when lowering user-defined objects. + +Example: + A descriptor-backed assignment such as ``holder.arr = A`` is rewritten to + ``type(holder).__dict__['arr'].__set__(holder, A)`` so later lowering sees + the same runtime behavior directly in the AST. +""" + +from __future__ import annotations + +import ast +import inspect +from typing import Any, Callable, Dict, Optional + +from dace import data, dtypes, symbolic +from dace.frontend.python import astutils +from dace.frontend.python.schedule_tree.static_evaluation import UNRESOLVED, try_resolve_static_value + + +class AttributeRewriter: + """Rewrite selected attribute reads and writes into explicit method calls. 
def rewrite_expression(self, node: ast.AST) -> ast.AST:
    """Return *node* (copied when possible) with special attribute loads expanded."""
    owner = self

    class _LoadExpander(ast.NodeTransformer):
        """Bottom-up transformer that delegates each attribute to the rewriter."""

        def visit_Attribute(self, attribute: ast.Attribute) -> ast.AST:
            # Inner attribute chains are rewritten before the outer access.
            attribute.value = self.visit(attribute.value)
            replacement = owner._rewrite_load(attribute)
            return attribute if replacement is None else ast.copy_location(replacement, attribute)

    try:
        tree = astutils.copy_tree(node)
    except Exception:
        # Best effort: fall back to rewriting the given tree directly.
        tree = node
    return ast.fix_missing_locations(_LoadExpander().visit(tree))
def _rewrite_load(self, node: ast.Attribute) -> Optional[ast.AST]:
    """Return an explicit call replacing the attribute read *node*, if needed.

    The base object is resolved statically; when that fails, or the base is a
    builtin-like value, ``None`` is returned and the attribute stays as written.
    Otherwise the first matching rule wins: descriptor ``__get__``, then a
    class-level ``__getattribute__`` override, then ``__getattr__`` for a
    missing attribute.
    """
    if not isinstance(node.ctx, ast.Load):
        return None

    base_value = try_resolve_static_value(node.value, self._evaluation_context())
    if base_value is UNRESOLVED or self._is_builtin_like_base(base_value):
        return None

    # ``type(<base>)`` expression and a fresh copy of the base expression; both
    # are reused as call arguments below.
    owner_expr = self._type_expr(astutils.copy_tree(node.value))
    obj_expr = astutils.copy_tree(node.value)
    objtype = type(base_value)

    # Static lookup avoids triggering descriptors or ``__getattr__`` hooks.
    try:
        static_attr = inspect.getattr_static(base_value, node.attr)
    except AttributeError:
        static_attr = None

    # Descriptor read: ``type(obj).__dict__[attr].__get__(obj, type(obj))``.
    # Plain functions/methods are excluded so ordinary method access is kept.
    # NOTE(review): the emitted expression indexes ``type(obj).__dict__`` only,
    # while ``getattr_static`` can also find attributes on base classes or the
    # instance -- confirm inherited descriptors are handled as intended.
    if (static_attr is not None and self._is_descriptor(static_attr)
            and not self._is_plain_method_descriptor(static_attr) and hasattr(static_attr, '__get__')):
        descriptor_expr = self._descriptor_expr(astutils.copy_tree(node.value), node.attr)
        return ast.Call(func=ast.Attribute(value=descriptor_expr, attr='__get__', ctx=ast.Load()),
                        args=[obj_expr, astutils.copy_tree(owner_expr)],
                        keywords=[])

    # ``__getattribute__`` defined directly on the object's class (base classes
    # are not consulted here).
    getattribute = objtype.__dict__.get('__getattribute__')
    if getattribute is not None and getattribute is not object.__getattribute__:
        return ast.Call(func=ast.Attribute(value=astutils.copy_tree(owner_expr),
                                           attr='__getattribute__',
                                           ctx=ast.Load()),
                        args=[obj_expr, ast.Constant(node.attr)],
                        keywords=[])

    # Attribute not found statically: route through the class's ``__getattr__``.
    if static_attr is None and '__getattr__' in objtype.__dict__:
        return ast.Call(func=ast.Attribute(value=astutils.copy_tree(owner_expr), attr='__getattr__',
                                           ctx=ast.Load()),
                        args=[obj_expr, ast.Constant(node.attr)],
                        keywords=[])

    return None
+ +Example: + If ``f`` is known to be ``lambda a, b: a + b``, then specializing the call + ``inner(A, f)`` marks ``f`` as callback-typed and records the recovered + lambda AST so the nested schedule-tree build can inline it later. +""" + +import ast +import copy +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple + +from dace import data + +from dace.frontend.python import astutils, preprocessing +from dace.frontend.python.schedule_tree.lambda_support import LambdaResolver +from dace.frontend.python.schedule_tree.static_evaluation import UNRESOLVED, try_resolve_static_value +from dace.frontend.python.schedule_tree.type_inference import _Binding + + +def _binding_to_descriptor(value: Any) -> data.Data: + descriptor = data.create_datadescriptor(value) + if isinstance(descriptor, data.View): + descriptor = descriptor.as_array() + descriptor.transient = False + return descriptor + + +def _callable_module_name(value: Any) -> str: + function = value.__func__ if inspect.ismethod(value) else value + module_name = getattr(function, '__module__', None) + if isinstance(module_name, str): + return module_name + return '' + + +def _is_user_parseable_callable(value: Any) -> bool: + module_name = _callable_module_name(value) + return not module_name.startswith(('dace.frontend.python', 'sympy', 'numpy')) + + +def _unwrap_inline_callable(value: Any) -> Any: + if isinstance(value, _ASTInlineCallable): + return value._callee + return value + + +class _ASTInlineCallable: + """AST-backed inline wrapper for parseable Python callables.""" + + _schedule_tree_inline_callable = True + + def __init__(self, callee: Any) -> None: + self._callee = callee + source_callable = callee.__func__ if inspect.ismethod(callee) else callee + src_ast, src_file, src_line, src = astutils.function_to_ast(source_callable) + if not src_ast.body or not isinstance(src_ast.body[0], ast.FunctionDef): + raise TypeError('Expected a FunctionDef when wrapping a Python callable for 
def _generate_schedule_tree(self,
                            args: Tuple[Any, ...],
                            kwargs: Dict[str, Any],
                            *,
                            lambda_bindings: Optional[Dict[str, ast.Lambda]] = None,
                            callable_bindings: Optional[Dict[str, Any]] = None):
    """Build a schedule tree for the wrapped callable with the given arguments.

    :param args: Positional specialization values for the call.
    :param kwargs: Keyword specialization values for the call.
    :param lambda_bindings: Forwarded unchanged to ``build_schedule_tree``.
    :param callable_bindings: Forwarded unchanged to ``build_schedule_tree``.
    :return: Whatever ``schedule_tree_frontend.build_schedule_tree`` returns.
    """
    # Imported lazily inside the function rather than at module level.
    from dace.frontend.python import schedule_tree_frontend
    from dace.data.core import infer_structured_object_members
    from dace.data.pydata import PythonClass

    # Partial binding: parameters without a supplied value are simply absent.
    bound_args = self.signature.bind_partial(*args, **kwargs)
    argtypes = {name: _binding_to_descriptor(value) for name, value in bound_args.arguments.items()}

    seed_bindings: Dict[str, _Binding] = {}
    # Shallow copies so the injected ``self`` entry does not leak into the
    # wrapper's stored globals.
    program_globals = copy.copy(self.program_globals)
    external_globals = copy.copy(self.external_globals)

    if self._bound_self is not None and self._self_parameter is not None:
        # Bound method: expose the instance under its ``self`` parameter name.
        program_globals[self._self_parameter] = self._bound_self
        external_globals[self._self_parameter] = self._bound_self
        try:
            self_descriptor = PythonClass(infer_structured_object_members(self._bound_self),
                                          name=type(self._bound_self).__name__)
        except (TypeError, ValueError):
            # Member inference failed; fall back to a member-less descriptor.
            self_descriptor = PythonClass({}, name=type(self._bound_self).__name__)
        seed_bindings[self._self_parameter] = _Binding(descriptor=self_descriptor, kind='container')

    # The stored function AST is copied so each build works on a fresh tree.
    parsed_ast = preprocessing.PreprocessedAST(self.filename, self.src_line, self.src,
                                               astutils.copy_tree(self.function_ast), program_globals)
    return schedule_tree_frontend.build_schedule_tree(
        self.name,
        parsed_ast,
        argtypes,
        constants=self.constants,
        callback_mapping=self.callback_mapping,
        # Keep declaration order, restricted to parameters that were bound.
        arg_names=[name for name in self.argnames if name in argtypes],
        lambda_bindings=lambda_bindings,
        callable_bindings=callable_bindings,
        seed_bindings=seed_bindings,
        external_globals=external_globals)
def resolve_known_callable(self, node: ast.AST) -> Optional[Any]:
    """Return the plain Python callable behind *node*, or ``None``.

    Inline-callable wrappers are returned as-is. Schedule-tree programs,
    non-callables and SDFG-convertible objects (other than ``SDFG`` instances
    themselves) yield ``None``.
    """
    candidate = self.resolve_callable_value(node)
    if candidate is UNRESOLVED:
        return None
    if getattr(candidate, '_schedule_tree_inline_callable', False):
        return candidate
    if hasattr(candidate, '__schedule_tree__') or not callable(candidate):
        return None

    from dace import SDFG
    sdfg_convertible = hasattr(candidate, '__sdfg__') and not isinstance(candidate, SDFG)
    return None if sdfg_convertible else candidate
def extract_argument_mapping(self, call_node: ast.Call, format_runtime_expression: Callable[[ast.AST],
                                                                                             str]) -> Dict[str, str]:
    """Map callee parameter names to the formatted argument expressions.

    :param call_node: The call whose arguments are mapped.
    :param format_runtime_expression: Renders an argument AST as a string.
    :return: ``{parameter name: formatted expression}``. Positional arguments
             beyond the callee's parameters and ``**mapping`` unpackings are
             omitted because they have no statically known parameter name.
    """
    callee = self.resolve_callable_value(call_node.func)
    sig = self.callable_signature(callee)
    params = [param for param in sig.parameters.values() if param.name != 'self']

    mapping: Dict[str, str] = {}
    # ``zip`` stops at the shorter sequence, dropping surplus positionals.
    for param, arg in zip(params, call_node.args):
        mapping[param.name] = format_runtime_expression(arg)
    for kw in call_node.keywords:
        # ``kw.arg`` is ``None`` for ``**kwargs``; skipping keeps keys as
        # strings, matching ``call_parameter_nodes``.
        if kw.arg is None:
            continue
        mapping[kw.arg] = format_runtime_expression(kw.value)
    return mapping
def _wrap_parseable_callable(self, value: Any) -> Any:
    """Wrap *value* in ``_ASTInlineCallable`` when it is a user Python callable.

    Every guard below returns *value* unchanged. Only plain functions, bound
    methods, and objects with a user-defined ``__call__`` that pass all guards
    are wrapped.
    """
    if value is UNRESOLVED:
        return value
    # Already wrapped.
    if getattr(value, '_schedule_tree_inline_callable', False):
        return value
    # Objects exposing these hooks are returned as-is.
    if hasattr(value, '__schedule_tree__') or hasattr(value, '__sdfg__'):
        return value
    if not callable(value):
        return value
    from dace import SDFG

    if isinstance(value, (SDFG, data.Data)) or inspect.isclass(value):
        return value
    # Methods bound to a data descriptor are left alone.
    if inspect.ismethod(value) and isinstance(getattr(value, '__self__', None), data.Data):
        return value
    if inspect.isbuiltin(value) or inspect.ismethoddescriptor(value):
        return value
    try:
        if inspect.ismethod(value) or inspect.isfunction(value):
            source_callable = value.__func__ if inspect.ismethod(value) else value
            # NOTE(review): comparing ``__name__`` with the empty string looks
            # like a lambda check whose '<lambda>' literal was lost -- confirm
            # the intended value against the upstream source.
            if getattr(source_callable, '__name__', None) == '':
                return value
            if inspect.isgeneratorfunction(source_callable):
                return value
            if not _is_user_parseable_callable(value):
                return value
            return _ASTInlineCallable(value)

        # Callable object: inspect its bound ``__call__`` instead.
        bound_call = getattr(value, '__call__', None)
        if bound_call is None or inspect.isbuiltin(bound_call) or inspect.ismethoddescriptor(bound_call):
            return value
        if not _is_user_parseable_callable(bound_call):
            return value

        call_impl = getattr(type(value), '__call__', None)
        if call_impl in {None, object.__call__}:
            return value

        return _ASTInlineCallable(bound_call)
    except (TypeError, OSError):
        # Constructing the wrapper failed with one of these errors; keep the
        # original value.
        return value
def is_callback_expression(self, node: ast.AST) -> bool:
    """Tell whether *node* denotes a callback-typed value.

    Checks, in order: a known lambda, a known callable, a name bound to a
    callback descriptor, and finally a data access whose (view) descriptor is
    callback-typed.
    """
    known_lambda = self.lambda_resolver.resolve_known_lambda_node(node)
    if known_lambda is not None or self.callable_resolver.resolve_known_callable(node) is not None:
        return True

    bound = self.bindings.get(node.id) if isinstance(node, ast.Name) else None
    if bound is not None and self.is_callback_descriptor(bound.descriptor):
        return True

    resolved = self.resolve_data_access(node)
    if resolved is None:
        return False
    _name, _memlet, base_descriptor, view_descriptor = resolved
    # A view descriptor, when present, takes precedence over the base one.
    return self.is_callback_descriptor(view_descriptor or base_descriptor)
def extract_call_specialization(
        self, call_node: ast.Call,
        unparse: Callable[[ast.AST],
                          str]) -> Tuple[List[Any], Dict[str, Any], Dict[str, ast.Lambda], Dict[str, Any]]:
    """Collect specialization payloads and callable bindings for *call_node*.

    :return: ``(args, kwargs, lambda_bindings, callable_bindings)``, where the
             two binding maps are keyed by callee parameter name.
    """
    known_lambdas: Dict[str, ast.Lambda] = {}
    known_callables: Dict[str, Any] = {}

    bound_nodes = self.callable_resolver.call_parameter_nodes(call_node)
    for parameter, argument in bound_nodes.items():
        as_lambda = self.lambda_resolver.resolve_known_lambda_node(argument)
        if as_lambda is not None:
            # A lambda binding wins; the callable lookup is skipped.
            known_lambdas[parameter] = as_lambda
        else:
            as_callable = self.callable_resolver.resolve_known_callable(argument)
            if as_callable is not None:
                known_callables[parameter] = _unwrap_inline_callable(as_callable)

    positional: List[Any] = []
    for argument in call_node.args:
        positional.append(self._specialize_or_unparse(argument, unparse))

    named: Dict[str, Any] = {}
    for keyword in call_node.keywords:
        # ``**mapping`` unpackings carry no keyword name and are skipped.
        if keyword.arg is not None:
            named[keyword.arg] = self._specialize_or_unparse(keyword.value, unparse)

    return positional, named, known_lambdas, known_callables
b/dace/frontend/python/schedule_tree/callback_support.py new file mode 100644 index 0000000000..3b11b80a66 --- /dev/null +++ b/dace/frontend/python/schedule_tree/callback_support.py @@ -0,0 +1,239 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +"""Helpers for schedule-tree Python callback outlining. + +Example: + Wrapping ``it = iter(generator)`` as a callback can keep the original code + block for compatibility while also producing an outlined scaffold such as:: + + def __stree_callback_0(): + it = iter(generator) + return it + + it = __stree_callback_0() + + The outlined scaffold is metadata for future callback lowering work; the + schedule tree still preserves the original callback code text as well. +""" + +from __future__ import annotations + +import ast +import inspect +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union + +from dace import data +from dace.properties import CodeBlock + +from dace.sdfg.analysis.schedule_tree import treenodes as tn + +from dace.frontend.python import astutils +from dace.frontend.python.schedule_tree.callable_support import CallableResolver +from dace.frontend.python.schedule_tree.static_evaluation import UNRESOLVED, try_resolve_static_value + +CallbackBody = Union[ast.AST, Sequence[ast.stmt]] + + +class CallbackOutliner: + """Build callback scaffolding and basic name-flow metadata. + + The helper accepts either a single AST node or a list of statements. This + lets the current frontend keep wrapping individual callback statements while + also providing an API that can later outline larger statement groups. 
@staticmethod
def analyze_name_flow(body: CallbackBody) -> Tuple[set[str], set[str]]:
    """Return ``(load_names, store_names)`` for a callback body.

    Every node reachable from the body is inspected. Names read in ``Load``
    context are reported as loads. Stores are names written in ``Store``
    context plus names bound by imports, function/class definitions and
    ``except ... as name`` clauses. Names in any other context (for example
    ``del``) are not reported.
    """
    loads: set[str] = set()
    stores: set[str] = set()
    definition_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)

    for statement in CallbackOutliner._body_nodes(body):
        for item in ast.walk(statement):
            if isinstance(item, ast.Name):
                if isinstance(item.ctx, ast.Store):
                    stores.add(item.id)
                elif isinstance(item.ctx, ast.Load):
                    loads.add(item.id)
                continue
            if isinstance(item, ast.alias):
                # ``import a.b`` binds ``a``; ``import a.b as c`` binds ``c``.
                stores.add(item.asname or item.name.split('.')[0])
                continue
            if isinstance(item, definition_types):
                stores.add(item.name)
                continue
            if isinstance(item, ast.ExceptHandler) and isinstance(item.name, str):
                stores.add(item.name)

    return loads, stores
call_expr = ast.Call(func=ast.Name(id=callback_name, ctx=ast.Load()), + args=[ast.Name(id=name, ctx=ast.Load()) for name in input_names], + keywords=[]) + if not output_names: + call_stmt: ast.stmt = ast.Expr(value=call_expr) + elif len(output_names) == 1: + call_stmt = ast.Assign(targets=[ast.Name(id=output_names[0], ctx=ast.Store())], value=call_expr) + else: + call_stmt = ast.Assign(targets=[ + ast.Tuple(elts=[ast.Name(id=name, ctx=ast.Store()) for name in output_names], ctx=ast.Store()) + ], + value=call_expr) + call_code = CodeBlock([ast.fix_missing_locations(call_stmt)]) + return function_code, call_code + + @staticmethod + def _body_nodes(body: CallbackBody) -> List[ast.stmt]: + if isinstance(body, Sequence) and not isinstance(body, ast.AST): + return [astutils.copy_tree(statement) for statement in body] + if isinstance(body, ast.stmt): + return [astutils.copy_tree(body)] + if isinstance(body, ast.AST): + return [ast.Expr(value=astutils.copy_tree(body))] + return [ast.Pass()] + + +class CallbackHandler: + """Own callback wrapping, callback assignments, and callback fallback policy.""" + + def __init__(self, *, bindings: Dict[str, Any], callback_mutated_global_names: Set[str], + callable_resolver: CallableResolver, evaluation_context: Callable[[], Dict[str, Any]], + append_node: Callable[[tn.ScheduleTreeNode], + None], register_binding: Callable[[str, data.Data, str], + None], fresh_callback_name: Callable[[], str], + fresh_transient_name: Callable[[str], str], render_callback_code: Callable[[ast.AST], str], + collect_scope_declarations: Callable[[ast.AST], + Tuple[set[str], + set[str]]], raise_syntax_error: Callable[[ast.AST, str], + None], + binding_kind_for_descriptor: Callable[[data.Data], + str], pyobject_scalar_descriptor: Callable[[], data.Scalar], + is_pyobject_scalar_descriptor: Callable[[Optional[data.Data]], bool], + is_iterator_protocol_call: Callable[[ast.AST], bool], is_iterator_next_call: Callable[[ast.AST], + bool]) -> None: + self.bindings = 
def reject_mutated_global_uses(self, node: Optional[ast.AST]) -> None:
    """Report every read of a name recorded in ``callback_mutated_global_names``.

    Does nothing when ``node`` is ``None`` or no such names are recorded.
    Each offending ``ast.Name`` in ``Load`` context is passed to
    ``raise_syntax_error`` together with an explanatory message.
    """
    if node is None:
        return
    if not self.callback_mutated_global_names:
        return

    for candidate in ast.walk(node):
        if not isinstance(candidate, ast.Name):
            continue
        if not isinstance(candidate.ctx, ast.Load):
            continue
        if candidate.id not in self.callback_mutated_global_names:
            continue
        message = ('Nested callback functions cannot reassign global names that are used in the enclosing '
                   f'program: {candidate.id}')
        self.raise_syntax_error(candidate, message)
def emit_assignment(self, name: str, value: ast.AST, reason: str, descriptor: data.Data) -> None:
    """Register ``name`` and wrap ``name = value`` as a callback node.

    :param name: Target name of the generated assignment.
    :param value: Right-hand-side expression; it is copied, not reused.
    :param reason: Callback reason forwarded to ``wrap_node``.
    :param descriptor: Descriptor registered for ``name``.

    A warning is issued when a ``'pyobject call'`` to iterator ``next()``
    ends up with a pyobject scalar descriptor.
    """
    if reason == 'pyobject call':
        if self.is_pyobject_scalar_descriptor(descriptor) and self.is_iterator_next_call(value):
            import warnings
            hint = ('Could not infer the result type of iterator next() in schedule-tree lowering; '
                    'annotate the assignment target, e.g. val: dace.float64 = next(gen).')
            warnings.warn(hint)

    self.register_binding(name, descriptor, self.binding_kind_for_descriptor(descriptor))

    store_target = ast.Name(id=name, ctx=ast.Store())
    assignment = ast.Assign(targets=[store_target], value=astutils.copy_tree(value))
    # Borrow the source position of the expression being assigned.
    assignment = ast.copy_location(assignment, value)
    self.wrap_node(assignment, reason)
class _DynamicExpansionError(Exception):
    """Signal that a call's ``*args`` / ``**kwargs`` cannot be expanded statically.

    Raised while rewriting an expression and caught by the statement
    visitors, which either mark the statement for callback lowering or, for
    SDFG-backed calls, report a syntax error.

    :param node: The call whose argument expansion could not be resolved.
    :param is_sdfg_call: Whether the call was identified as an SDFG call.
    """

    def __init__(self, node: ast.Call, *, is_sdfg_call: bool) -> None:
        # Initialize the base class so that an instance that escapes the
        # visitors still carries a readable message and non-empty ``args``.
        super().__init__('call requires dynamic argument expansion')
        self.node = node
        self.is_sdfg_call = is_sdfg_call
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
    """Rewrite a function body with an initially empty set of static bindings.

    The bindings of the enclosing scope are put back once the body has been
    rewritten, so nothing recorded inside the function leaks outward.
    """
    outer_bindings, self._expansion_bindings = self._expansion_bindings, {}
    rewritten_body = self._rewrite_body(node.body)
    node.body = rewritten_body
    self._expansion_bindings = outer_bindings
    return node
def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST:
    """Rewrite the value of an annotated assignment and track its target.

    A bare annotation (no value) only invalidates the target. If the value
    needs a dynamic expansion, the target is invalidated and the statement
    is marked for callback lowering; SDFG-backed calls are reported through
    ``_raise_dynamic_sdfg_error`` first.
    """
    target = node.target
    if node.value is None:
        self._invalidate_target(target)
        return node

    try:
        rewritten_value = self._rewrite_expression(node.value)
    except _DynamicExpansionError as dynamic:
        self._invalidate_target(target)
        if dynamic.is_sdfg_call:
            self._raise_dynamic_sdfg_error(dynamic.node)
        return self._mark_callback(node, 'call expansion')
    else:
        node.value = rewritten_value

    self._update_target_binding(node.target, node.value)
    return node
def visit_For(self, node: ast.For) -> ast.AST:
    """Rewrite a ``for`` loop and drop the static bindings it invalidates.

    The iterable is rewritten first, against the bindings that hold before
    the loop starts. If it needs a dynamic expansion, the whole statement
    is marked for callback lowering; SDFG-backed calls are reported through
    ``_raise_dynamic_sdfg_error`` instead.

    Every name the loop stores (the loop target and any name assigned in
    its body or ``else`` clause) is removed from the static expansion
    bindings *before* the body is rewritten. A binding recorded ahead of
    the loop does not describe such a name inside the loop, where an
    earlier iteration may already have reassigned it, nor after the loop.
    Leaving it in place would let ``f(*name)`` be expanded with the stale
    pre-loop elements.
    """
    try:
        node.iter = self._rewrite_expression(node.iter)
    except _DynamicExpansionError as ex:
        # The statement is lowered as one opaque callback, so whatever it
        # stores can no longer be described statically.
        self._invalidate_target(node)
        if ex.is_sdfg_call:
            self._raise_dynamic_sdfg_error(ex.node)
        return self._mark_callback(node, 'call expansion')

    # Must precede the body rewrite: nested bodies start from a copy of the
    # bindings that are current at this point.
    self._invalidate_target(node)
    node.body = self._rewrite_nested_body(node.body)
    node.orelse = self._rewrite_nested_body(node.orelse)
    return node
def visit_With(self, node: ast.With) -> ast.AST:
    """Rewrite the context expressions and body of a ``with`` statement.

    Items are processed in order: each context expression is rewritten and
    the names bound by its ``as`` clause, if any, are invalidated. When an
    item needs a dynamic expansion the whole statement is marked for
    callback lowering; SDFG-backed calls are reported through
    ``_raise_dynamic_sdfg_error`` first.

    NOTE(review): on the callback path, the ``as`` targets of the failing
    item and of later items are not invalidated -- confirm whether a stale
    static binding for such a name can be observed afterwards.
    """
    for with_item in node.items:
        try:
            with_item.context_expr = self._rewrite_expression(with_item.context_expr)
            bound_names = with_item.optional_vars
            if bound_names is not None:
                self._invalidate_target(bound_names)
        except _DynamicExpansionError as dynamic:
            if dynamic.is_sdfg_call:
                self._raise_dynamic_sdfg_error(dynamic.node)
            return self._mark_callback(node, 'call expansion')

    rewritten_body = self._rewrite_nested_body(node.body)
    node.body = rewritten_body
    return node
def _rewrite_nested_body(self, body: List[ast.stmt]) -> List[ast.stmt]:
    """Rewrite a nested statement list in a scratch copy of the static bindings.

    Bindings recorded while rewriting ``body`` stay local to it, because
    the nested block may execute conditionally or repeatedly. Names that
    the block stores are removed from the enclosing bindings afterwards:
    once the block may have reassigned a name, the value recorded before
    it is no longer a safe basis for static ``*`` / ``**`` expansion.

    :param body: Statements of the nested block.
    :return: The rewritten statements.
    """
    saved = self._expansion_bindings
    self._expansion_bindings = {k: astutils.copy_tree(v) for k, v in saved.items()}
    try:
        rewritten = self._rewrite_body(body)
    finally:
        # Restore the enclosing bindings even if rewriting raised.
        self._expansion_bindings = saved

    # Walk both the incoming and the rewritten statements so that a visitor
    # replacing a statement cannot hide one of its store targets.
    for statement in (*body, *rewritten):
        self._invalidate_target(statement)
    return rewritten
expr_node = self.generic_visit(expr_node) + rewritten = rewrite_sugared_expression(expr_node, outer.callable_resolver) + return rewritten if rewritten is not None else expr_node + + def visit_UnaryOp(self, expr_node: ast.UnaryOp) -> ast.AST: + expr_node = self.generic_visit(expr_node) + rewritten = rewrite_sugared_expression(expr_node, outer.callable_resolver) + return rewritten if rewritten is not None else expr_node + + def visit_Compare(self, expr_node: ast.Compare) -> ast.AST: + expr_node = self.generic_visit(expr_node) + rewritten = rewrite_sugared_expression(expr_node, outer.callable_resolver) + return rewritten if rewritten is not None else expr_node + + def visit_Subscript(self, expr_node: ast.Subscript) -> ast.AST: + expr_node = self.generic_visit(expr_node) + rewritten = rewrite_sugared_expression(expr_node, outer.callable_resolver) + return rewritten if rewritten is not None else expr_node + + return _ExpressionRewriter().visit(astutils.copy_tree(node)) + + def _evaluation_context(self) -> Dict[str, Any]: + context = copy.copy(pybuiltins.__dict__) + context.update(self.global_vars) + context.update(self.callable_bindings) + return context + + @staticmethod + def _is_expanded_call(node: ast.AST) -> bool: + return isinstance(node, ast.Call) and (any(isinstance(arg, ast.Starred) + for arg in node.args) or any(keyword.arg is None + for keyword in node.keywords)) + + @staticmethod + def _has_starred_target(target: ast.AST) -> bool: + return any(isinstance(child, ast.Starred) for child in ast.walk(target)) + + def _expand_call_if_static(self, node: ast.Call) -> Optional[ast.Call]: + args: List[ast.AST] = [] + keywords: List[ast.keyword] = [] + for argument in node.args: + if isinstance(argument, ast.Starred): + expanded = self._resolve_static_sequence_nodes(argument.value) + if expanded is None: + return None + args.extend(astutils.copy_tree(value) for value in expanded) + else: + args.append(astutils.copy_tree(argument)) + + for keyword in node.keywords: + if 
def _normalized_static_expansion_ast(self, node: ast.AST) -> Optional[ast.AST]:
    """Return a flattened copy of a statically known tuple, list or dict.

    * A name yields a copy of the expression recorded for it, if any.
    * A tuple or list yields the same kind of node with every starred
      element replaced by the elements it expands to.
    * A dict yields a dict whose keys are string constants, with ``**``
      entries replaced by the items they expand to.

    ``None`` is returned for any other node and whenever a nested expansion
    or a dict key cannot be resolved statically (dict keys must resolve to
    strings).
    """
    if isinstance(node, ast.Name):
        known = self._expansion_bindings.get(node.id)
        if known is None:
            return None
        return astutils.copy_tree(known)

    if isinstance(node, (ast.Tuple, ast.List)):
        flattened: List[ast.AST] = []
        for item in node.elts:
            if not isinstance(item, ast.Starred):
                flattened.append(astutils.copy_tree(item))
                continue
            unpacked = self._resolve_static_sequence_nodes(item.value)
            if unpacked is None:
                return None
            flattened.extend([astutils.copy_tree(value) for value in unpacked])
        if isinstance(node, ast.Tuple):
            rebuilt = ast.Tuple(elts=flattened, ctx=ast.Load())
        else:
            rebuilt = ast.List(elts=flattened, ctx=ast.Load())
        return ast.copy_location(rebuilt, node)

    if isinstance(node, ast.Dict):
        flat_keys: List[Optional[ast.AST]] = []
        flat_values: List[ast.AST] = []
        for key_node, value_node in zip(node.keys, node.values):
            if key_node is None:
                # ``**mapping`` entry: splice in the items of the mapping.
                items = self._resolve_static_mapping_items(value_node)
                if items is None:
                    return None
                for item_name, item_value in items:
                    flat_keys.append(ast.copy_location(ast.Constant(item_name), value_node))
                    flat_values.append(astutils.copy_tree(item_value))
            else:
                key_name = try_resolve_static_value(key_node, self._evaluation_context())
                if not isinstance(key_name, str):
                    return None
                flat_keys.append(ast.copy_location(ast.Constant(key_name), key_node))
                flat_values.append(astutils.copy_tree(value_node))
        return ast.copy_location(ast.Dict(keys=flat_keys, values=flat_values), node)

    return None
def _resolve_static_mapping_items(self, node: ast.AST) -> Optional[List[Tuple[str, ast.AST]]]:
    """Return the ``(key, value_node)`` items of a statically known mapping.

    ``None`` is returned when ``node`` does not normalize to a dict or when
    one of its keys does not resolve to a string. Value nodes are copies.
    """
    mapping = self._normalized_static_expansion_ast(node)
    if isinstance(mapping, ast.Dict):
        pairs: List[Tuple[str, ast.AST]] = []
        for key_node, value_node in zip(mapping.keys, mapping.values):
            key_name = try_resolve_static_value(key_node, self._evaluation_context())
            if isinstance(key_name, str):
                pairs.append((key_name, astutils.copy_tree(value_node)))
            else:
                return None
        return pairs
    return None
None + assignments.extend(expanded) + return assignments + + starred_index = starred_indices[0] + if len(elements) < len(target.elts) - 1: + return None + + prefix_targets = target.elts[:starred_index] + suffix_targets = target.elts[starred_index + 1:] + prefix_values = elements[:starred_index] + suffix_values = elements[len(elements) - len(suffix_targets):] + middle_values = elements[starred_index:len(elements) - len(suffix_targets)] + + for subtarget, subvalue in zip(prefix_targets, prefix_values): + expanded = self._expand_starred_assignment(subtarget, subvalue) + if expanded is None: + return None + assignments.extend(expanded) + + middle_list = ast.copy_location( + ast.List(elts=[astutils.copy_tree(element) for element in middle_values], ctx=ast.Load()), value) + expanded_middle = self._expand_starred_assignment(target.elts[starred_index], middle_list) + if expanded_middle is None: + return None + assignments.extend(expanded_middle) + + for subtarget, subvalue in zip(suffix_targets, suffix_values): + expanded = self._expand_starred_assignment(subtarget, subvalue) + if expanded is None: + return None + assignments.extend(expanded) + + return assignments + + def _materialize_structured_assignment(self, node: ast.Assign) -> Optional[List[ast.stmt]]: + if not any(isinstance(target, (ast.Tuple, ast.List)) for target in node.targets): + return None + + if not isinstance(node.value, (ast.Tuple, ast.List)): + return None + + normalized_value = self._normalized_static_expansion_ast(node.value) + if not isinstance(normalized_value, (ast.Tuple, ast.List)): + return None + + temp_name = self._fresh_name('__stree_tuple_tmp') + temp_assign = ast.Assign(targets=[ast.Name(id=temp_name, ctx=ast.Store())], value=normalized_value) + temp_assign = ast.copy_location(temp_assign, node) + + rewritten_assign = ast.Assign(targets=[astutils.copy_tree(target) for target in node.targets], + value=ast.Name(id=temp_name, ctx=ast.Load())) + rewritten_assign = 
def _fresh_name(self, prefix: str) -> str:
    """Return ``prefix``, or ``prefix`` plus a counter, avoiding known names.

    A candidate is rejected while it appears in the static expansion
    bindings, the global variables or the callable bindings. The counter is
    shared across prefixes and only advances when a candidate is rejected.
    """

    def is_taken(name: str) -> bool:
        scopes = (self._expansion_bindings, self.global_vars, self.callable_bindings)
        return any(name in scope for scope in scopes)

    name = prefix
    while is_taken(name):
        self._temp_counter += 1
        name = f'{prefix}{self._temp_counter}'
    return name
def visit_Assign(self, node: ast.Assign) -> ast.AST:
    """Outline nested index expressions of an assignment's value and targets.

    The value is processed first, then each target from left to right; the
    statements produced along the way are emitted, in that order, ahead of
    the rewritten assignment.
    """
    hoisted, new_value = self._outline_expression(node.value)

    new_targets: List[ast.AST] = []
    for original_target in node.targets:
        target_statements, new_target = self._outline_expression(original_target)
        hoisted.extend(target_statements)
        new_targets.append(new_target)

    node.value = new_value
    node.targets = new_targets
    return self._prepend_statements(node, hoisted)
self._outline_expression(node.value) + node.value = value + return self._prepend_statements(node, prologue) + + def visit_Expr(self, node: ast.Expr) -> ast.AST: + prologue, value = self._outline_expression(node.value) + node.value = value + return self._prepend_statements(node, prologue) + + def visit_If(self, node: ast.If) -> ast.AST: + prologue, test = self._outline_expression(node.test) + node.test = test + node.body = self._rewrite_body(node.body) + node.orelse = self._rewrite_body(node.orelse) + return self._prepend_statements(node, prologue) + + def visit_While(self, node: ast.While) -> ast.AST: + prologue, test = self._outline_expression(node.test) + if prologue and node.orelse: + return self._mark_callback(node, 'while loop test outlining with else') + node.test = test + node.body = self._rewrite_body(node.body) + node.orelse = self._rewrite_body(node.orelse) + if not prologue: + return node + + guard = ast.If(test=ast.UnaryOp(op=ast.Not(), operand=astutils.copy_tree(test)), body=[ast.Break()], orelse=[]) + guard = ast.fix_missing_locations(ast.copy_location(guard, node.test)) + rewritten = ast.While(test=ast.Constant(value=True), body=prologue + [guard] + node.body, orelse=[]) + rewritten = ast.fix_missing_locations(ast.copy_location(rewritten, node)) + return rewritten + + def visit_For(self, node: ast.For) -> ast.AST: + prologue, iterator = self._outline_expression(node.iter) + node.iter = iterator + node.body = self._rewrite_body(node.body) + node.orelse = self._rewrite_body(node.orelse) + return self._prepend_statements(node, prologue) + + if hasattr(ast, 'AsyncFor'): + + def visit_AsyncFor(self, node: ast.AsyncFor) -> ast.AST: + prologue, iterator = self._outline_expression(node.iter) + node.iter = iterator + node.body = self._rewrite_body(node.body) + node.orelse = self._rewrite_body(node.orelse) + return self._prepend_statements(node, prologue) + + def visit_With(self, node: ast.With) -> ast.AST: + prologue: List[ast.stmt] = [] + for item in 
node.items: + item_prologue, context_expr = self._outline_expression(item.context_expr) + prologue.extend(item_prologue) + item.context_expr = context_expr + node.body = self._rewrite_body(node.body) + return self._prepend_statements(node, prologue) + + if hasattr(ast, 'AsyncWith'): + + def visit_AsyncWith(self, node: ast.AsyncWith) -> ast.AST: + prologue: List[ast.stmt] = [] + for item in node.items: + item_prologue, context_expr = self._outline_expression(item.context_expr) + prologue.extend(item_prologue) + item.context_expr = context_expr + node.body = self._rewrite_body(node.body) + return self._prepend_statements(node, prologue) + + def _rewrite_body(self, body: List[ast.stmt]) -> List[ast.stmt]: + result: List[ast.stmt] = [] + for statement in body: + rewritten = self.visit(statement) + if rewritten is None: + continue + if isinstance(rewritten, list): + result.extend(rewritten) + else: + result.append(rewritten) + return result + + def _outline_expression(self, + node: Optional[ast.AST], + *, + in_index_context: bool = False, + hoist_safe: bool = True) -> Tuple[List[ast.stmt], Optional[ast.AST]]: + if node is None: + return [], None + + if isinstance(node, ast.Subscript): + prologue, value = self._outline_expression(node.value, + in_index_context=in_index_context, + hoist_safe=hoist_safe) + slice_prologue, slice_node = self._outline_subscript_slice(node.slice, hoist_safe=hoist_safe) + rewritten = ast.copy_location(ast.Subscript(value=value, slice=slice_node, ctx=node.ctx), node) + prologue.extend(slice_prologue) + if hoist_safe and in_index_context and self._should_outline_index_expression(rewritten): + return self._outline_to_temp(rewritten, prologue) + return prologue, rewritten + + if isinstance(node, ast.Call): + prologue, func = self._outline_expression(node.func, + in_index_context=in_index_context, + hoist_safe=hoist_safe) + args: List[ast.AST] = [] + for arg in node.args: + arg_prologue, rewritten_arg = self._outline_expression(arg, + 
in_index_context=in_index_context, + hoist_safe=hoist_safe) + prologue.extend(arg_prologue) + args.append(rewritten_arg) + keywords: List[ast.keyword] = [] + for keyword in node.keywords: + kw_prologue, rewritten_value = self._outline_expression(keyword.value, + in_index_context=in_index_context, + hoist_safe=hoist_safe) + prologue.extend(kw_prologue) + keywords.append(ast.keyword(arg=keyword.arg, value=rewritten_value)) + rewritten = ast.copy_location(ast.Call(func=func, args=args, keywords=keywords), node) + if hoist_safe and in_index_context and self._should_outline_index_expression(rewritten): + return self._outline_to_temp(rewritten, prologue) + return prologue, rewritten + + if isinstance(node, ast.BinOp): + prologue, left = self._outline_expression(node.left, + in_index_context=in_index_context, + hoist_safe=hoist_safe) + right_prologue, right = self._outline_expression(node.right, + in_index_context=in_index_context, + hoist_safe=hoist_safe) + prologue.extend(right_prologue) + return prologue, ast.copy_location(ast.BinOp(left=left, op=astutils.copy_tree(node.op), right=right), node) + + if isinstance(node, ast.UnaryOp): + prologue, operand = self._outline_expression(node.operand, + in_index_context=in_index_context, + hoist_safe=hoist_safe) + return prologue, ast.copy_location(ast.UnaryOp(op=astutils.copy_tree(node.op), operand=operand), node) + + if isinstance(node, ast.BoolOp): + prologue: List[ast.stmt] = [] + values: List[ast.AST] = [] + for index, value in enumerate(node.values): + value_prologue, rewritten_value = self._outline_expression(value, + in_index_context=in_index_context, + hoist_safe=hoist_safe and index == 0) + prologue.extend(value_prologue) + values.append(rewritten_value) + return prologue, ast.copy_location(ast.BoolOp(op=astutils.copy_tree(node.op), values=values), node) + + if isinstance(node, ast.Compare): + prologue, left = self._outline_expression(node.left, + in_index_context=in_index_context, + hoist_safe=hoist_safe) + 
comparators: List[ast.AST] = [] + for index, comparator in enumerate(node.comparators): + comparator_prologue, rewritten_comparator = self._outline_expression(comparator, + in_index_context=in_index_context, + hoist_safe=hoist_safe + and index == 0) + prologue.extend(comparator_prologue) + comparators.append(rewritten_comparator) + return prologue, ast.copy_location( + ast.Compare(left=left, ops=astutils.copy_tree(node.ops), comparators=comparators), node) + + if isinstance(node, ast.IfExp): + prologue, test = self._outline_expression(node.test, + in_index_context=in_index_context, + hoist_safe=hoist_safe) + body_prologue, body = self._outline_expression(node.body, + in_index_context=in_index_context, + hoist_safe=False) + orelse_prologue, orelse = self._outline_expression(node.orelse, + in_index_context=in_index_context, + hoist_safe=False) + prologue.extend(body_prologue) + prologue.extend(orelse_prologue) + return prologue, ast.copy_location(ast.IfExp(test=test, body=body, orelse=orelse), node) + + if isinstance(node, ast.Attribute): + prologue, value = self._outline_expression(node.value, + in_index_context=in_index_context, + hoist_safe=hoist_safe) + return prologue, ast.copy_location(ast.Attribute(value=value, attr=node.attr, ctx=node.ctx), node) + + if isinstance(node, ast.Tuple): + prologue: List[ast.stmt] = [] + elements: List[ast.AST] = [] + for element in node.elts: + element_prologue, rewritten_element = self._outline_expression(element, + in_index_context=in_index_context, + hoist_safe=hoist_safe) + prologue.extend(element_prologue) + elements.append(rewritten_element) + return prologue, ast.copy_location(ast.Tuple(elts=elements, ctx=node.ctx), node) + + if isinstance(node, ast.List): + prologue: List[ast.stmt] = [] + elements: List[ast.AST] = [] + for element in node.elts: + element_prologue, rewritten_element = self._outline_expression(element, + in_index_context=in_index_context, + hoist_safe=hoist_safe) + prologue.extend(element_prologue) + 
elements.append(rewritten_element) + return prologue, ast.copy_location(ast.List(elts=elements, ctx=node.ctx), node) + + return [], astutils.copy_tree(node) + + def _outline_subscript_slice(self, node: ast.AST, *, hoist_safe: bool = True) -> Tuple[List[ast.stmt], ast.AST]: + if isinstance(node, ast.Slice): + prologue, lower = self._outline_expression(node.lower, in_index_context=True, hoist_safe=hoist_safe) + upper_prologue, upper = self._outline_expression(node.upper, in_index_context=True, hoist_safe=hoist_safe) + step_prologue, step = self._outline_expression(node.step, in_index_context=True, hoist_safe=hoist_safe) + prologue.extend(upper_prologue) + prologue.extend(step_prologue) + return prologue, ast.copy_location(ast.Slice(lower=lower, upper=upper, step=step), node) + + if isinstance(node, ast.Tuple): + prologue: List[ast.stmt] = [] + elements: List[ast.AST] = [] + for element in node.elts: + element_prologue, rewritten_element = self._outline_subscript_slice(element, hoist_safe=hoist_safe) + prologue.extend(element_prologue) + elements.append(rewritten_element) + return prologue, ast.copy_location(ast.Tuple(elts=elements, ctx=node.ctx), node) + + return self._outline_expression(node, in_index_context=True, hoist_safe=hoist_safe) + + def _outline_to_temp(self, value: ast.AST, prologue: List[ast.stmt]) -> Tuple[List[ast.stmt], ast.AST]: + temp_name = self._fresh_name('__stree_idx') + assign = ast.Assign(targets=[ast.Name(id=temp_name, ctx=ast.Store())], value=astutils.copy_tree(value)) + assign = ast.fix_missing_locations(ast.copy_location(assign, value)) + prologue.append(assign) + return prologue, ast.copy_location(ast.Name(id=temp_name, ctx=ast.Load()), value) + + def _prepend_statements(self, node: ast.stmt, prologue: List[ast.stmt]) -> ast.AST: + if not prologue: + return node + return prologue + [node] + + def _mark_callback(self, statement: ast.stmt, reason: str) -> ast.stmt: + setattr(statement, _CALLBACK_REASON_ATTR, reason) + return 
ast.fix_missing_locations(statement) + + def _should_outline_index_expression(self, node: ast.AST) -> bool: + if not isinstance(node, (ast.Call, ast.Subscript)): + return False + if isinstance(node, ast.Call) and ast.unparse(node.func) == 'slice': + context = self._evaluation_context() + for arg in node.args: + if try_resolve_static_value(arg, context) is UNRESOLVED: + return True + for keyword in node.keywords: + if try_resolve_static_value(keyword.value, context) is UNRESOLVED: + return True + return False + return try_resolve_static_value(node, self._evaluation_context()) is UNRESOLVED + + def _seed_used_names(self, node: ast.AST) -> None: + for child in ast.walk(node): + if isinstance(child, ast.Name): + self._used_names.add(child.id) + elif isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + self._used_names.add(child.name) + elif isinstance(child, ast.arg): + self._used_names.add(child.arg) + + def _fresh_name(self, prefix: str) -> str: + candidate = prefix + while candidate in self._used_names: + self._temp_counter += 1 + candidate = f'{prefix}{self._temp_counter}' + self._used_names.add(candidate) + return candidate + + def _evaluation_context(self) -> Dict[str, Any]: + context = copy.copy(self.global_vars) + context.update(self.callable_bindings) + return context + + +class _DescriptorTrackingEnvironment: + + def __init__(self, + global_vars: Dict[str, Any], + *, + known_descriptors: Optional[Dict[str, data.Data]] = None, + seed_bindings: Optional[Dict[str, Any]] = None, + callable_bindings: Optional[Dict[str, Any]] = None) -> None: + self.global_vars = copy.copy(global_vars) + self.callable_bindings = dict(callable_bindings or {}) + self.static_bindings: Dict[str, Any] = {} + self.sequence_lengths: Dict[str, int] = {} + self.descriptor_bindings: Dict[str, data.Data] = {} + + for name, value in self.global_vars.items(): + self.static_bindings[name] = value + if isinstance(value, (list, tuple)): + self.sequence_lengths[name] = 
class _DescriptorTrackingEnvironment:
    """Per-scope record of what is statically known about each name.

    Three tables are kept, all keyed by variable name:

    * ``static_bindings`` - values known at rewrite time (seeded from the globals),
    * ``sequence_lengths`` - lengths of names bound to lists or tuples,
    * ``descriptor_bindings`` - data descriptors of names bound to data containers.
    """

    def __init__(self,
                 global_vars: Dict[str, Any],
                 *,
                 known_descriptors: Optional[Dict[str, data.Data]] = None,
                 seed_bindings: Optional[Dict[str, Any]] = None,
                 callable_bindings: Optional[Dict[str, Any]] = None) -> None:
        """Seed the tables from globals, explicit descriptors and seed bindings.

        :param global_vars: Names visible to static evaluation; copied shallowly.
        :param known_descriptors: Descriptors to register by name (deep-copied).
        :param seed_bindings: Objects whose optional ``descriptor`` and
                              ``structure`` attributes provide a descriptor and
                              a sequence length for the name.
        :param callable_bindings: Extra name-to-value bindings for evaluation.
        """
        self.global_vars = copy.copy(global_vars)
        self.callable_bindings = dict(callable_bindings or {})
        self.static_bindings: Dict[str, Any] = {}
        self.sequence_lengths: Dict[str, int] = {}
        self.descriptor_bindings: Dict[str, data.Data] = {}

        for name, value in self.global_vars.items():
            self.static_bindings[name] = value
            if isinstance(value, (list, tuple)):
                self.sequence_lengths[name] = len(value)
            if isinstance(value, data.Data):
                self.descriptor_bindings[name] = copy.deepcopy(value)

        for name, descriptor in (known_descriptors or {}).items():
            self.descriptor_bindings[name] = copy.deepcopy(descriptor)

        for name, binding in (seed_bindings or {}).items():
            descriptor = getattr(binding, 'descriptor', None)
            if isinstance(descriptor, data.Data):
                self.descriptor_bindings[name] = copy.deepcopy(descriptor)
            structure = getattr(binding, 'structure', None)
            if isinstance(structure, (list, tuple)):
                self.sequence_lengths[name] = len(structure)

    def child(self, *, cleared_names: Sequence[str] = ()) -> '_DescriptorTrackingEnvironment':
        """Return an independent copy for a nested scope.

        :param cleared_names: Names whose static value and sequence length must
                              be unknown in the child (their descriptors are kept).
        """
        cloned = _DescriptorTrackingEnvironment(self.global_vars, callable_bindings=self.callable_bindings)
        cloned.static_bindings = copy.copy(self.static_bindings)
        cloned.sequence_lengths = dict(self.sequence_lengths)
        cloned.descriptor_bindings = {
            name: copy.deepcopy(descriptor)
            for name, descriptor in self.descriptor_bindings.items()
        }
        for name in cleared_names:
            cloned.static_bindings.pop(name, None)
            cloned.sequence_lengths.pop(name, None)
        return cloned

    def evaluation_context(self) -> Dict[str, Any]:
        """Return the namespace used for static evaluation in this scope.

        Layers, lowest priority first: builtins, callable bindings, static
        bindings. ``global_vars`` is deliberately not layered in again: every
        global was copied into ``static_bindings`` on construction, so adding it
        here would resurrect the global value of a name that has since been
        invalidated (rebound to an unknown value, or shadowed by a parameter).
        """
        context = copy.copy(pybuiltins.__dict__)
        context.update(self.callable_bindings)
        context.update(self.static_bindings)
        return context

    def descriptor_for_base(self, node: ast.AST) -> Optional[data.Data]:
        """Return a copy of the descriptor bound to a plain name that has a ``shape``, else ``None``."""
        if isinstance(node, ast.Name):
            descriptor = self.descriptor_bindings.get(node.id)
            if isinstance(descriptor, data.Data) and getattr(descriptor, 'shape', None) is not None:
                return copy.deepcopy(descriptor)
        return None

    def sequence_length_for_base(self, node: ast.AST) -> Optional[int]:
        """Return the statically known length of a list/tuple expression, else ``None``."""
        if isinstance(node, ast.Name) and node.id in self.sequence_lengths:
            return self.sequence_lengths[node.id]
        if isinstance(node, (ast.List, ast.Tuple)):
            return len(node.elts)
        value = try_resolve_static_value(node, self.evaluation_context())
        if isinstance(value, (list, tuple)):
            return len(value)
        return None

    def evaluate_descriptor(self, node: ast.AST) -> Optional[data.Data]:
        """Statically evaluate ``node`` and return a copy if the result is a data descriptor."""
        descriptor = try_resolve_static_value(node, self.evaluation_context())
        if isinstance(descriptor, data.Data):
            return copy.deepcopy(descriptor)
        return None

    def update_target_binding(self, target: ast.AST, value: ast.AST) -> None:
        """Record what assigning ``value`` to ``target`` tells us.

        For a plain-name target the static value and sequence length are updated
        or forgotten, and the descriptor is copied when ``value`` is a name with
        a known descriptor. Any other target only invalidates the names it stores.
        """
        if isinstance(target, ast.Name):
            resolved = try_resolve_static_value(value, self.evaluation_context())
            if resolved is UNRESOLVED:
                self.static_bindings.pop(target.id, None)
            else:
                self.static_bindings[target.id] = resolved

            if isinstance(value, (ast.Tuple, ast.List)):
                self.sequence_lengths[target.id] = len(value.elts)
            elif isinstance(resolved, (list, tuple)):
                self.sequence_lengths[target.id] = len(resolved)
            else:
                self.sequence_lengths.pop(target.id, None)

            # NOTE(review): an existing descriptor of ``target`` is kept when the
            # value has no known descriptor; callers that bind an annotation
            # descriptor right before this call rely on that.
            if isinstance(value, ast.Name) and value.id in self.descriptor_bindings:
                self.descriptor_bindings[target.id] = copy.deepcopy(self.descriptor_bindings[value.id])
            return

        self.invalidate_target(target)

    def invalidate_target(self, target: ast.AST) -> None:
        """Forget everything about the names stored by ``target``."""
        self.invalidate_names(self.stored_names(target))

    def invalidate_names(self, names: Sequence[str]) -> None:
        """Drop static value, sequence length and descriptor of every given name."""
        for name in names:
            self.static_bindings.pop(name, None)
            self.sequence_lengths.pop(name, None)
            self.descriptor_bindings.pop(name, None)

    @staticmethod
    def stored_names(target: ast.AST) -> set[str]:
        """Return every ``ast.Name`` under ``target`` that appears in store context."""
        names: set[str] = set()
        for child in ast.walk(target):
            if isinstance(child, ast.Name) and isinstance(child.ctx, ast.Store):
                names.add(child.id)
        return names

    def assigned_names(self, body: Sequence[ast.stmt]) -> set[str]:
        """Return the names that the statements in ``body`` may (re)bind or unbind.

        Besides plain stores this covers ``del`` targets, ``except ... as name``
        at any nesting depth, ``def``/``class`` names and import aliases. Nested
        function bodies are walked too, which can only over-approximate.
        """
        names: set[str] = set()
        for statement in body:
            for child in ast.walk(statement):
                if isinstance(child, ast.Name):
                    if isinstance(child.ctx, (ast.Store, ast.Del)):
                        names.add(child.id)
                elif isinstance(child, ast.ExceptHandler):
                    if child.name:
                        names.add(child.name)
                elif isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                    names.add(child.name)
                elif isinstance(child, ast.alias):
                    # ``import a.b`` binds ``a``; ``import a.b as c`` binds ``c``.
                    names.add(child.asname or child.name.split('.')[0])
        return names
class ScheduleTreeNegativeIndexNormalizer(ast.NodeTransformer):
    """Rewrite definitely negative indices into extent-relative expressions.

    Examples:
        ``A[-1]`` becomes ``A[N - 1]`` when ``A`` has known extent ``N``.

        ``t[-i]`` becomes ``t[3 - i]`` for a statically known 3-element tuple or
        list when ``i`` is known positive.
    """

    def __init__(self,
                 global_vars: Dict[str, Any],
                 *,
                 known_descriptors: Optional[Dict[str, data.Data]] = None,
                 seed_bindings: Optional[Dict[str, Any]] = None,
                 callable_bindings: Optional[Dict[str, Any]] = None) -> None:
        """Create the transformer and its root tracking environment.

        :param global_vars: Names visible to static evaluation.
        :param known_descriptors: Data descriptors known by variable name.
        :param seed_bindings: Pre-existing bindings forwarded to the environment.
        :param callable_bindings: Extra name-to-value bindings for evaluation.
        """
        self.global_vars = copy.copy(global_vars)
        self.callable_bindings = dict(callable_bindings or {})
        self._env = _DescriptorTrackingEnvironment(global_vars,
                                                   known_descriptors=known_descriptors,
                                                   seed_bindings=seed_bindings,
                                                   callable_bindings=callable_bindings)

    def visit_Module(self, node: ast.Module) -> ast.AST:
        """Rewrite the module body inside its own child environment."""
        node.body = self._rewrite_in_child_scope(node.body)
        return node

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
        """Rewrite a function body with its parameters treated as unknown."""
        return self._visit_function_scope(node)

    if hasattr(ast, 'AsyncFunctionDef'):
        visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Assign(self, node: ast.Assign) -> ast.AST:
        """Rewrite both sides, then record what each target is now bound to."""
        node.value = self.visit(node.value)
        node.targets = [self.visit(target) for target in node.targets]
        for target in node.targets:
            self._env.update_target_binding(target, node.value)
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST:
        """Rewrite an annotated assignment and track the annotated descriptor.

        For a bare declaration (``x: T``) the target is invalidated first and
        the annotation descriptor is bound afterwards, so the declaration is
        not thrown away by its own invalidation.
        """
        if node.value is not None:
            node.value = self.visit(node.value)
        node.target = self.visit(node.target)
        descriptor = self._env.evaluate_descriptor(node.annotation)
        binds_descriptor = isinstance(node.target, ast.Name) and descriptor is not None

        if node.value is None:
            self._env.invalidate_target(node.target)
            if binds_descriptor:
                self._env.descriptor_bindings[node.target.id] = descriptor
            return node

        # With a value, the descriptor is bound before the binding update, as
        # before: a value that is a name with a known descriptor overrides it.
        if binds_descriptor:
            self._env.descriptor_bindings[node.target.id] = descriptor
        self._env.update_target_binding(node.target, node.value)
        return node

    def visit_AugAssign(self, node: ast.AugAssign) -> ast.AST:
        """Rewrite both sides; the target's previous knowledge no longer holds."""
        node.target = self.visit(node.target)
        node.value = self.visit(node.value)
        self._env.invalidate_target(node.target)
        return node

    def visit_Return(self, node: ast.Return) -> ast.AST:
        """Rewrite the returned expression, if any."""
        if node.value is not None:
            node.value = self.visit(node.value)
        return node

    def visit_Expr(self, node: ast.Expr) -> ast.AST:
        """Rewrite an expression statement."""
        node.value = self.visit(node.value)
        return node

    def visit_If(self, node: ast.If) -> ast.AST:
        """Rewrite each branch in a child scope, then forget every name a branch assigns."""
        node.test = self.visit(node.test)
        node.body = self._rewrite_in_child_scope(node.body)
        node.orelse = self._rewrite_in_child_scope(node.orelse)
        self._env.invalidate_names(self._env.assigned_names(node.body + node.orelse))
        return node

    def visit_While(self, node: ast.While) -> ast.AST:
        """Rewrite the loop in child scopes, then forget every name it assigns."""
        node.test = self.visit(node.test)
        node.body = self._rewrite_in_child_scope(node.body)
        node.orelse = self._rewrite_in_child_scope(node.orelse)
        self._env.invalidate_names(self._env.assigned_names(node.body + node.orelse))
        return node

    def visit_For(self, node: ast.For) -> ast.AST:
        """Rewrite the loop in child scopes, then forget the loop target and every assigned name."""
        node.target = self.visit(node.target)
        node.iter = self.visit(node.iter)
        node.body = self._rewrite_in_child_scope(node.body)
        node.orelse = self._rewrite_in_child_scope(node.orelse)
        self._env.invalidate_target(node.target)
        self._env.invalidate_names(self._env.assigned_names(node.body + node.orelse))
        return node

    if hasattr(ast, 'AsyncFor'):
        visit_AsyncFor = visit_For

    def visit_With(self, node: ast.With) -> ast.AST:
        """Rewrite the items and body, then forget ``as`` targets and every assigned name."""
        for item in node.items:
            item.context_expr = self.visit(item.context_expr)
            if item.optional_vars is not None:
                item.optional_vars = self.visit(item.optional_vars)
        node.body = self._rewrite_in_child_scope(node.body)
        invalidated = self._env.assigned_names(node.body)
        for item in node.items:
            if item.optional_vars is not None:
                invalidated.update(self._env.stored_names(item.optional_vars))
        self._env.invalidate_names(invalidated)
        return node

    if hasattr(ast, 'AsyncWith'):
        visit_AsyncWith = visit_With

    def visit_Try(self, node: ast.Try) -> ast.AST:
        """Rewrite every part of a ``try`` in its own child scope and forget all names they bind."""
        node.body = self._rewrite_in_child_scope(node.body)
        node.orelse = self._rewrite_in_child_scope(node.orelse)
        node.finalbody = self._rewrite_in_child_scope(node.finalbody)
        for handler in node.handlers:
            handler.body = self._rewrite_in_child_scope(handler.body)
        invalidated = self._env.assigned_names(node.body + node.orelse + node.finalbody)
        for handler in node.handlers:
            invalidated.update(self._env.assigned_names(handler.body))
            if handler.name:
                invalidated.add(handler.name)
        self._env.invalidate_names(invalidated)
        return node

    if hasattr(ast, 'TryStar'):
        # ``try/except*`` has the same fields; without this alias the generic
        # visitor would treat assignments in its branches as unconditional.
        visit_TryStar = visit_Try

    def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
        """Rewrite nested subscripts first, then this subscript's own index."""
        node = self.generic_visit(node)
        rewritten = self._rewrite_subscript(node)
        return ast.fix_missing_locations(ast.copy_location(rewritten, node))

    def _visit_function_scope(self, node: ast.AST) -> ast.AST:
        """Rewrite a function body in a child scope where all parameters are unknown."""
        args = getattr(node, 'args', None)
        cleared_names = self._parameter_names(args) if args is not None else []
        node.body = self._rewrite_in_child_scope(node.body, cleared_names=cleared_names)
        return node

    @staticmethod
    def _parameter_names(args: ast.arguments) -> List[str]:
        """Return the names of all parameters: positional-only, regular, ``*args``, keyword-only, ``**kwargs``."""
        names = [parameter.arg for parameter in list(getattr(args, 'posonlyargs', [])) + list(args.args)]
        if args.vararg is not None:
            names.append(args.vararg.arg)
        names.extend(parameter.arg for parameter in args.kwonlyargs)
        if args.kwarg is not None:
            names.append(args.kwarg.arg)
        return names

    def _rewrite_body(self, body: List[ast.stmt]) -> List[ast.stmt]:
        """Visit each statement and flatten ``None`` / statement / list results."""
        result: List[ast.stmt] = []
        for statement in body:
            rewritten = self.visit(statement)
            if rewritten is None:
                continue
            if isinstance(rewritten, list):
                result.extend(rewritten)
            else:
                result.append(rewritten)
        return result

    def _rewrite_in_child_scope(self, body: List[ast.stmt], *, cleared_names: Sequence[str] = ()) -> List[ast.stmt]:
        """Rewrite ``body`` with a temporary child environment; the parent is always restored."""
        saved_env = self._env
        self._env = self._env.child(cleared_names=cleared_names)
        try:
            return self._rewrite_body(body)
        finally:
            self._env = saved_env

    def _rewrite_subscript(self, node: ast.Subscript) -> ast.Subscript:
        """Rewrite the index of ``node`` using a descriptor shape or a known sequence length."""
        descriptor = self._env.descriptor_for_base(node.value)
        if descriptor is not None:
            rewritten_slice = self._rewrite_descriptor_slice(node.slice, tuple(descriptor.shape))
            return ast.copy_location(ast.Subscript(value=node.value, slice=rewritten_slice, ctx=node.ctx), node)

        sequence_length = self._env.sequence_length_for_base(node.value)
        if sequence_length is not None:
            rewritten_slice = self._rewrite_index_with_extent(node.slice, ast.Constant(sequence_length))
            return ast.copy_location(ast.Subscript(value=node.value, slice=rewritten_slice, ctx=node.ctx), node)

        return node

    def _rewrite_descriptor_slice(self, slice_node: ast.AST, shape: Tuple[Any, ...]) -> ast.AST:
        """Rewrite a (possibly multi-dimensional) index against ``shape``."""
        if isinstance(slice_node, ast.Tuple):
            extents = self._slice_extents(slice_node.elts, shape)
            if extents is None:
                return slice_node
            elements = [
                self._rewrite_index_with_extent(element, extent) if extent is not None else element
                for element, extent in zip(slice_node.elts, extents)
            ]
            return ast.copy_location(ast.Tuple(elts=elements, ctx=slice_node.ctx), slice_node)

        if not shape:
            return slice_node
        return self._rewrite_index_with_extent(slice_node, self._extent_ast(shape[0]))

    def _slice_extents(self, elements: Sequence[ast.AST], shape: Tuple[Any, ...]) -> Optional[List[Optional[ast.AST]]]:
        """Pair each index element with the extent of the dimension it addresses.

        ``None`` entries mark new-axis and ellipsis elements. The whole result
        is ``None`` when there are several ellipses or more indices than
        dimensions.
        """
        if sum(1 for element in elements if self._is_ellipsis(element)) > 1:
            return None

        consumed = sum(1 for element in elements if not self._is_newaxis(element) and not self._is_ellipsis(element))
        ellipsis_dims = max(len(shape) - consumed, 0)
        dim_index = 0
        extents: List[Optional[ast.AST]] = []

        for element in elements:
            if self._is_newaxis(element):
                extents.append(None)
                continue
            if self._is_ellipsis(element):
                extents.append(None)
                # The ellipsis stands for all dimensions not indexed explicitly.
                dim_index += ellipsis_dims
                continue
            if dim_index >= len(shape):
                return None
            extents.append(self._extent_ast(shape[dim_index]))
            dim_index += 1
        return extents

    def _rewrite_index_with_extent(self, node: ast.AST, extent: ast.AST) -> ast.AST:
        """Rewrite one index (slice bounds, list/tuple elements or scalar) against ``extent``."""
        if isinstance(node, ast.Slice):
            lower = self._rewrite_negative_expression(node.lower, extent)
            upper = self._rewrite_negative_expression(node.upper, extent)
            return ast.copy_location(ast.Slice(lower=lower, upper=upper, step=node.step), node)

        if isinstance(node, ast.List):
            return ast.copy_location(
                ast.List(elts=[self._rewrite_index_with_extent(element, extent) for element in node.elts],
                         ctx=node.ctx), node)

        if isinstance(node, ast.Tuple):
            return ast.copy_location(
                ast.Tuple(elts=[self._rewrite_index_with_extent(element, extent) for element in node.elts],
                          ctx=node.ctx), node)

        return self._rewrite_negative_expression(node, extent)

    def _rewrite_negative_expression(self, node: Optional[ast.AST], extent: ast.AST) -> Optional[ast.AST]:
        """Return ``extent - magnitude`` when ``node`` is provably negative, else ``node`` unchanged."""
        if node is None:
            return None

        resolved = try_resolve_static_value(node, self._env.evaluation_context())
        if not self._is_definitely_negative_value(resolved):
            return node

        magnitude = self._negative_magnitude(node, resolved)
        return ast.copy_location(ast.BinOp(left=astutils.copy_tree(extent), op=ast.Sub(), right=magnitude), node)

    @staticmethod
    def _is_definitely_negative_value(value: Any) -> bool:
        """Return ``True`` only when ``value`` is known to be negative.

        Unresolved values and booleans are never negative. An ``is_negative``
        attribute equal to ``True`` is trusted. Otherwise ``value < 0`` must
        compare equal to ``True``; the result is forced through ``bool`` so that
        an element-wise comparison (e.g. of an array) is rejected here instead
        of raising in the caller.
        """
        if value is UNRESOLVED or isinstance(value, bool):
            return False
        if getattr(value, 'is_negative', None) is True:
            return True
        try:
            # ``== True`` is intentional: a symbolic relational is truthy as an
            # object but does not compare equal to True.
            return bool((value < 0) == True)  # noqa: E712
        except Exception:
            return False

    @staticmethod
    def _negative_magnitude(node: ast.AST, resolved: Any) -> ast.AST:
        """Return an expression for the absolute value of a negative index."""
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return astutils.copy_tree(node.operand)
        if isinstance(resolved, numbers.Integral) and not isinstance(resolved, bool):
            return ast.Constant(value=abs(int(resolved)))
        return ast.UnaryOp(op=ast.USub(), operand=astutils.copy_tree(node))

    @staticmethod
    def _is_newaxis(node: ast.AST) -> bool:
        """Return whether ``node`` is the literal ``None``."""
        return isinstance(node, ast.Constant) and node.value is None

    @staticmethod
    def _is_ellipsis(node: ast.AST) -> bool:
        """Return whether ``node`` is the literal ``...``."""
        return isinstance(node, ast.Constant) and node.value is Ellipsis

    @staticmethod
    def _extent_ast(extent: Any) -> ast.AST:
        """Return ``extent`` as an expression AST, parsing its string form when it is not one already."""
        if isinstance(extent, ast.AST):
            return astutils.copy_tree(extent)
        return ast.parse(str(extent), mode='eval').body


def desugar_schedule_tree_expansions(parsed_ast: ast.AST,
                                     *,
                                     filename: str,
                                     global_vars: Dict[str, Any],
                                     known_descriptors: Optional[Dict[str, data.Data]] = None,
                                     seed_bindings: Optional[Dict[str, Any]] = None,
                                     callable_bindings: Optional[Dict[str, Any]] = None) -> ast.AST:
    """Rewrite schedule-tree-specific syntax before AST lowering.

    The passes run on a copy of ``parsed_ast`` in this order: expansion
    desugaring, negative-index normalization, tuple-assignment lowering and
    subscript-index outlining.
    """
    expanded = ScheduleTreeExpansionDesugarer(filename, global_vars,
                                              callable_bindings=callable_bindings).visit(astutils.copy_tree(parsed_ast))
    canonical = ScheduleTreeNegativeIndexNormalizer(global_vars,
                                                    known_descriptors=known_descriptors,
                                                    seed_bindings=seed_bindings,
                                                    callable_bindings=callable_bindings).visit(expanded)
    tuple_lowered = lower_tuple_assignments(canonical)
    outlined = ScheduleTreeSubscriptIndexDesugarer(global_vars,
                                                   callable_bindings=callable_bindings).visit(tuple_lowered)
    return ast.fix_missing_locations(outlined)


def callback_reason(node: ast.AST) -> Optional[str]:
    """Return the callback reason attached by schedule-tree desugaring, if any."""
    return getattr(node, _CALLBACK_REASON_ATTR, None)
# Callback that infers a data descriptor for an expression, or ``None``.
DescriptorInference = Callable[[ast.AST], Optional[data.Data]]
# Callback that infers a scalar descriptor for an expression given an optional hint.
ScalarDescriptorInference = Callable[[ast.AST, Optional[data.Data]], Optional[data.Data]]
# Factory producing the namespace used for static evaluation.
EvaluationContextFactory = Callable[[], Dict[str, Any]]


@dataclass
class StaticDictBinding:
    """Per-key descriptors of a dictionary whose keys are all statically known."""
    entries: Dict[Any, data.Data]


@dataclass(frozen=True)
class DictSupportContext:
    """Bundle of the callbacks the dict helpers need from the frontend."""
    infer_descriptor: DescriptorInference
    infer_scalar_descriptor: ScalarDescriptorInference
    evaluation_context: EvaluationContextFactory


class DictSupportLibrary:
    """Shared dict descriptor and binding helpers for the direct frontend."""

    def infer_literal_descriptor(self, context: DictSupportContext, node: ast.Dict) -> PythonDict:
        """Forward to :func:`infer_dict_literal_descriptor` with the callbacks in ``context``."""
        return infer_dict_literal_descriptor(node, context.infer_descriptor, context.infer_scalar_descriptor)

    def infer_literal_binding(self, context: DictSupportContext, node: ast.Dict) -> Optional[StaticDictBinding]:
        """Forward to :func:`infer_dict_literal_binding` with the callbacks in ``context``."""
        return infer_dict_literal_binding(node, context.infer_descriptor, context.infer_scalar_descriptor,
                                          context.evaluation_context)

    def infer_subscript_descriptor(self,
                                   context: DictSupportContext,
                                   descriptor: data.Data,
                                   slice_node: ast.AST,
                                   binding: Optional[StaticDictBinding] = None) -> Optional[data.Data]:
        """Forward to :func:`infer_dict_subscript_descriptor` with the callbacks in ``context``."""
        return infer_dict_subscript_descriptor(descriptor, slice_node, context.evaluation_context, binding)

    def infer_assignment_binding(self, context: DictSupportContext, descriptor: data.Data,
                                 binding: Optional[StaticDictBinding], slice_node: ast.AST,
                                 value_node: ast.AST) -> Optional[tuple[PythonDict, Optional[StaticDictBinding]]]:
        """Forward to :func:`infer_dict_assignment_binding` with the callbacks in ``context``."""
        return infer_dict_assignment_binding(descriptor, binding, slice_node, value_node, context.infer_descriptor,
                                             context.infer_scalar_descriptor, context.evaluation_context)

    def infer_assignment_descriptor(self, context: DictSupportContext, descriptor: data.Data, slice_node: ast.AST,
                                    value_node: ast.AST) -> Optional[PythonDict]:
        """Forward to :func:`infer_dict_assignment_descriptor` with the callbacks in ``context``."""
        return infer_dict_assignment_descriptor(descriptor, slice_node, value_node, context.infer_descriptor,
                                                context.infer_scalar_descriptor, context.evaluation_context)


def _is_hashable(value: Any) -> bool:
    """Return whether ``value`` can be used as a dictionary key."""
    try:
        hash(value)
    except Exception:  # a user-defined ``__hash__`` may raise anything
        return False
    return True


def infer_dict_literal_descriptor(node: ast.Dict, infer_descriptor: DescriptorInference,
                                  infer_scalar_descriptor: ScalarDescriptorInference) -> PythonDict:
    """Infer a transient ``PythonDict`` descriptor for a dict literal.

    A literal containing ``**mapping`` unpacking (a ``None`` key in the AST)
    yields an untyped transient ``PythonDict``.
    """
    key_descriptors = []
    value_descriptors = []
    for key, value in zip(node.keys, node.values):
        if key is None:
            return PythonDict(transient=True)
        key_descriptors.append(infer_descriptor(key) or infer_scalar_descriptor(key, None))
        value_descriptors.append(infer_descriptor(value) or infer_scalar_descriptor(value, None))
    return PythonDict(merge_python_dict_component_descriptors(key_descriptors, transient=True),
                      merge_python_dict_component_descriptors(value_descriptors, transient=True),
                      transient=True)


def infer_dict_literal_binding(node: ast.Dict, infer_descriptor: DescriptorInference,
                               infer_scalar_descriptor: ScalarDescriptorInference,
                               evaluation_context: EvaluationContextFactory) -> Optional[StaticDictBinding]:
    """Build a per-key binding for a dict literal.

    :return: ``None`` when the literal uses ``**`` unpacking or any key is not
             statically resolvable or not hashable.
    """
    entries: Dict[Any, data.Data] = {}
    for key, value in zip(node.keys, node.values):
        if key is None:
            return None
        key_value = try_resolve_static_value(key, evaluation_context())
        if key_value is UNRESOLVED:
            return None
        if not _is_hashable(key_value):
            return None
        entries[key_value] = _infer_value_descriptor(value, infer_descriptor, infer_scalar_descriptor)
    return StaticDictBinding(entries=entries)


def infer_dict_subscript_descriptor(descriptor: data.Data,
                                    slice_node: ast.AST,
                                    evaluation_context: EvaluationContextFactory,
                                    binding: Optional[StaticDictBinding] = None) -> Optional[data.Data]:
    """Infer the descriptor of ``d[key]`` for a ``PythonDict`` descriptor.

    :return: A transient copy of the per-key entry (with a binding) or of the
             dictionary's value type (without one). ``None`` when ``descriptor``
             is not a ``PythonDict``, the key is not statically resolvable, or a
             binding is given and the key is unhashable or absent from it.
    """
    if not isinstance(descriptor, PythonDict):
        return None
    key_value = try_resolve_static_value(slice_node, evaluation_context())
    if key_value is UNRESOLVED:
        return None
    if binding is not None:
        # An unhashable key cannot be looked up; report "unknown" instead of
        # letting the lookup raise TypeError.
        if not _is_hashable(key_value):
            return None
        entry = binding.entries.get(key_value)
        if entry is None:
            return None
        result = copy.deepcopy(entry)
        result.transient = True
        return result
    result = copy.deepcopy(descriptor.value_type)
    result.transient = True
    return result


def infer_dict_assignment_binding(
        descriptor: data.Data, binding: Optional[StaticDictBinding], slice_node: ast.AST, value_node: ast.AST,
        infer_descriptor: DescriptorInference, infer_scalar_descriptor: ScalarDescriptorInference,
        evaluation_context: EvaluationContextFactory) -> Optional[tuple[PythonDict, Optional[StaticDictBinding]]]:
    """Infer the dictionary descriptor and binding after ``d[key] = value``.

    :return: ``None`` when ``descriptor`` is not a ``PythonDict`` or no key
             descriptor can be inferred. Otherwise the widened descriptor plus
             the binding, which is updated only when the key is statically
             known, hashable and already present; in the other cases it is
             dropped (or passed through unchanged when it was ``None``).
    """
    if not isinstance(descriptor, PythonDict):
        return None

    key_descriptor = _descriptor_from_key(slice_node, infer_descriptor, infer_scalar_descriptor, evaluation_context)
    if key_descriptor is None:
        return None

    value_descriptor = _infer_value_descriptor(value_node, infer_descriptor, infer_scalar_descriptor)
    updated_descriptor = PythonDict(merge_python_dict_component_descriptors((descriptor.key_type, key_descriptor),
                                                                            transient=True),
                                    merge_python_dict_component_descriptors((descriptor.value_type, value_descriptor),
                                                                            transient=True),
                                    transient=True)

    key_value = try_resolve_static_value(slice_node, evaluation_context())
    if key_value is UNRESOLVED or binding is None:
        return (updated_descriptor, None if key_value is UNRESOLVED else binding)

    # An unhashable key cannot be tracked per key (and the membership test
    # below would raise TypeError), so the binding is dropped.
    if not _is_hashable(key_value) or key_value not in binding.entries:
        return (updated_descriptor, None)

    updated_binding = copy.deepcopy(binding)
    updated_binding.entries[key_value] = copy.deepcopy(value_descriptor)
    updated_binding.entries[key_value].transient = True
    return (updated_descriptor, updated_binding)


def infer_dict_assignment_descriptor(descriptor: data.Data, slice_node: ast.AST, value_node: ast.AST,
                                     infer_descriptor: DescriptorInference,
                                     infer_scalar_descriptor: ScalarDescriptorInference,
                                     evaluation_context: EvaluationContextFactory) -> Optional[PythonDict]:
    """Return only the descriptor part of :func:`infer_dict_assignment_binding` (no binding tracking)."""
    updated = infer_dict_assignment_binding(descriptor, None, slice_node, value_node, infer_descriptor,
                                            infer_scalar_descriptor, evaluation_context)
    return None if updated is None else updated[0]


def _infer_value_descriptor(value_node: ast.AST, infer_descriptor: DescriptorInference,
                            infer_scalar_descriptor: ScalarDescriptorInference) -> data.Data:
    """Infer a transient descriptor for a value, falling back to a ``pyobject`` scalar."""
    descriptor = infer_descriptor(value_node) or infer_scalar_descriptor(value_node, None)
    if descriptor is None:
        return data.Scalar(dtypes.pyobject(), transient=True)
    descriptor = copy.deepcopy(descriptor)
    descriptor.transient = True
    return descriptor


def _descriptor_from_key(slice_node: ast.AST, infer_descriptor: DescriptorInference,
                         infer_scalar_descriptor: ScalarDescriptorInference,
                         evaluation_context: EvaluationContextFactory) -> Optional[data.Data]:
    """Infer a transient descriptor for a key: static value first, then the inference callbacks."""
    descriptor = _descriptor_from_static_key(slice_node, evaluation_context)
    if descriptor is not None:
        return descriptor
    descriptor = infer_descriptor(slice_node) or infer_scalar_descriptor(slice_node, None)
    if descriptor is None:
        return None
    descriptor = copy.deepcopy(descriptor)
    descriptor.transient = True
    return descriptor


def _descriptor_from_static_key(slice_node: ast.AST,
                                evaluation_context: EvaluationContextFactory) -> Optional[data.Data]:
    """Create a transient descriptor from a statically resolved key value, or ``None``."""
    key_value = try_resolve_static_value(slice_node, evaluation_context())
    if key_value is UNRESOLVED:
        return None
    try:
        descriptor = create_datadescriptor(key_value)
    except Exception:  # best effort: any failure means "no static descriptor"
        return None
    descriptor.transient = True
    return descriptor
def rewrite_sugared_expression(node: ast.AST, callable_resolver: CallableResolver) -> Optional[ast.AST]:
    """Translate one sugared expression node into an explicit dunder-call AST.

    Handled forms: calls, binary operators, unary operators listed in ``_UNARY_DUNDERS``,
    single-operator comparisons, and subscripts in load context.

    :param node: Expression node to rewrite.
    :param callable_resolver: Resolver used to look up operand values statically.
    :return: The rewritten AST, or ``None`` when the node is not a handled form or no
             rewrite applies.
    """
    rewritten: Optional[ast.AST] = None

    if isinstance(node, ast.Call):
        rewritten = _rewrite_call(node, callable_resolver)
    elif isinstance(node, ast.BinOp):
        rewritten = _rewrite_binop(node.left, node.op, node.right, callable_resolver)
    elif isinstance(node, ast.UnaryOp):
        dunder = _UNARY_DUNDERS.get(type(node.op))
        if dunder is not None:
            rewritten = _call_on_operand(node.operand, dunder, (), callable_resolver, template=node)
    elif isinstance(node, ast.Compare):
        # Chained comparisons (``a < b < c``) are left untouched.
        if len(node.ops) == 1 and len(node.comparators) == 1:
            rewritten = _rewrite_compare(node.left, node.ops[0], node.comparators[0], callable_resolver, node)
    elif isinstance(node, ast.Subscript):
        # Store/Del subscripts are handled by the dedicated assignment/delete helpers.
        if isinstance(node.ctx, ast.Load):
            rewritten = _rewrite_subscript(node, callable_resolver)

    return rewritten
def rewrite_augassign(target: ast.AST, op: ast.operator, value: ast.AST,
                      callable_resolver: CallableResolver) -> Optional[ast.stmt]:
    """Rewrite ``target op= value`` into an explicit dunder-based statement.

    For subscript targets the result is a ``__setitem__`` call statement whose stored value
    is built from the rewritten element read (or, failing that, from a binary-operator
    rewrite of the element). For other targets the result is a plain assignment whose
    right-hand side is the binary-operator rewrite, preferring the in-place dunder.

    :return: The replacement statement, or ``None`` when no rewrite applies.
    """
    if not isinstance(target, ast.Subscript):
        loaded = _load_context_copy(target)
        if loaded is None:
            return None
        combined = _rewrite_binop(loaded, op, value, callable_resolver, prefer_inplace=True)
        if combined is None:
            return None
        assignment = ast.Assign(targets=[astutils.copy_tree(target)], value=combined)
        return ast.copy_location(assignment, target)

    # Load-context twin of the target, used to read the current element value.
    element = ast.Subscript(value=astutils.copy_tree(target.value),
                            slice=astutils.copy_tree(target.slice),
                            ctx=ast.Load())
    ast.copy_location(element, target)

    getter = _rewrite_subscript(element, callable_resolver)
    combined = None
    if getter is not None:
        combined = _rewrite_augassign_value(getter, op, value, template=target)
    if combined is None:
        combined = _rewrite_binop(element, op, value, callable_resolver, prefer_inplace=True)
    if combined is None:
        return None
    return rewrite_subscript_assignment(target, combined, callable_resolver)
def _rewrite_builtin_call(node: ast.Call, builtin_value, callable_resolver: CallableResolver) -> Optional[ast.Call]:
    """Rewrite a call to a supported builtin into a dunder call on the relevant operand.

    Supported: the one-argument builtins in ``_UNARY_CALL_DUNDERS``, ``format``, ``round``,
    ``divmod`` (with ``__rdivmod__`` fallback), ``isinstance`` and ``issubclass``. Calls that
    use keyword arguments are never rewritten.

    :param node: The call node being inspected.
    :param builtin_value: The statically resolved callee. It may be any callable object, not
                          necessarily a builtin.
    :param callable_resolver: Resolver forwarded to ``_call_on_operand``.
    :return: The rewritten call, or ``None`` when the callee is not a supported builtin, the
             argument count does not fit, or the operand offers no usable dunder.
    """
    # Every rewrite below only applies to purely positional calls.
    if node.keywords:
        return None
    args = node.args
    argc = len(args)

    try:
        unary_dunder = _UNARY_CALL_DUNDERS.get(builtin_value)
    except TypeError:
        # The callee is an arbitrary resolved callable and may be unhashable (e.g. an instance
        # of a class defining ``__eq__`` without ``__hash__``). Such an object cannot be one of
        # the builtins handled here, so there is nothing to rewrite.
        return None
    if unary_dunder is not None and argc == 1:
        return _call_on_operand(args[0], unary_dunder, (), callable_resolver, template=node)

    if builtin_value is pybuiltins.format and argc in {1, 2}:
        # ``format(x)`` is ``format(x, '')``.
        format_spec = ast.Constant(value='') if argc == 1 else args[1]
        return _call_on_operand(args[0], '__format__', (format_spec, ), callable_resolver, template=node)

    if builtin_value is pybuiltins.round and argc in {1, 2}:
        return _call_on_operand(args[0], '__round__', tuple(args[1:]), callable_resolver, template=node)

    if builtin_value is pybuiltins.divmod and argc == 2:
        rewritten = _call_on_operand(args[0], '__divmod__', (args[1], ), callable_resolver, template=node)
        if rewritten is not None:
            return rewritten
        return _call_on_operand(args[1], '__rdivmod__', (args[0], ), callable_resolver, template=node)

    # For the type checks the dunder lives on the second argument (the class).
    if builtin_value is pybuiltins.isinstance and argc == 2:
        return _call_on_operand(args[1], '__instancecheck__', (args[0], ), callable_resolver, template=node)

    if builtin_value is pybuiltins.issubclass and argc == 2:
        return _call_on_operand(args[1], '__subclasscheck__', (args[0], ), callable_resolver, template=node)

    return None
def _rewrite_augassign_value(left: ast.AST, op: ast.operator, right: ast.AST, *,
                             template: ast.AST) -> Optional[ast.Call]:
    """Build ``left.<dunder>(right)`` for an augmented assignment without resolving ``left``.

    The in-place dunder registered for ``op`` is used when present, otherwise the direct one.

    :return: The method-call AST, or ``None`` when ``op`` has no entry in ``_BINARY_DUNDERS``.
    """
    entry = _BINARY_DUNDERS.get(type(op))
    if entry is None:
        return None

    direct_name, _reflected_name, inplace_name = entry
    chosen = inplace_name if inplace_name else direct_name
    if chosen is None:
        return None
    return _build_method_call(left, chosen, (right, ), template=template)
def _build_method_call(operand: ast.AST,
                       method_name: str,
                       args: Tuple[ast.AST, ...],
                       *,
                       template: ast.AST,
                       keywords: Optional[list[ast.keyword]] = None) -> ast.Call:
    """Construct the AST for ``operand.method_name(*args, **keywords)``.

    The operand, arguments and keywords are copied with ``astutils.copy_tree`` rather than
    shared with the input; both the attribute node and the call node take their source
    location from ``template``.
    """
    attribute = ast.Attribute(value=astutils.copy_tree(operand), attr=method_name, ctx=ast.Load())
    ast.copy_location(attribute, template)

    copied_args = [astutils.copy_tree(arg) for arg in args]
    copied_keywords = [astutils.copy_tree(keyword) for keyword in (keywords or [])]

    call = ast.Call(func=attribute, args=copied_args, keywords=copied_keywords)
    return ast.copy_location(call, template)
def _load_context_copy(target: ast.AST) -> Optional[ast.AST]:
    """Return a copy of an assignment target whose context is ``ast.Load``.

    Only ``Name``, ``Attribute`` and ``Subscript`` targets are supported; any other node
    kind yields ``None``. The input node is not modified.
    """
    copied = astutils.copy_tree(target)
    if not isinstance(copied, (ast.Name, ast.Attribute, ast.Subscript)):
        return None
    copied.ctx = ast.Load()
    return copied
+ + The direct Python schedule-tree frontend outlines unresolved subscript + expressions into scalar temporaries such as ``__stree_idx = A[i]`` before a + dynamic ``FrontendMap``. Those nodes are semantically dynamic scope inputs, + so normalize them into ``DynScopeCopyNode`` to match the schedule-tree IR + contract used by SDFG-derived trees. + """ + + _DynamicScopeCopyPromoter().visit(root) + + +class _DynamicScopeCopyPromoter(tn.ScheduleNodeTransformer): + + def visit_scope(self, node: tn.ScheduleTreeScope): + self.generic_visit(node) + + for index, child in enumerate(node.children): + dynamic_inputs = _frontend_dynamic_input_names(child) + if not dynamic_inputs: + continue + + cursor = index - 1 + while cursor >= 0: + sibling = node.children[cursor] + if isinstance(sibling, tn.DynScopeCopyNode): + cursor -= 1 + continue + if not isinstance(sibling, tn.TaskletNode): + break + + replacement = _dynscope_replacement(sibling, dynamic_inputs) + if replacement is None: + break + + replacement.parent = node + node.children[cursor] = replacement + dynamic_inputs.discard(replacement.target) + cursor -= 1 + + return node + + +def _frontend_dynamic_input_names(node: tn.ScheduleTreeNode) -> Set[str]: + if not isinstance(node, tn.MapScope) or not isinstance(node.node, tn.FrontendMap): + return set() + + result: Set[str] = set() + for start, stop, step in node.node.ranges: + for expr in (start, stop, step): + name = _simple_name(expr) + if name is not None: + result.add(name) + return result + + +def _simple_name(expr: str) -> Optional[str]: + try: + parsed = ast.parse(expr, mode='eval') + except SyntaxError: + return None + return parsed.body.id if isinstance(parsed.body, ast.Name) else None + + +def _dynscope_replacement(node: tn.TaskletNode, dynamic_inputs: Set[str]) -> Optional[tn.DynScopeCopyNode]: + if len(node.in_memlets) != 1 or len(node.out_memlets) != 1: + return None + + target = next(iter(node.out_memlets.values())).data + if target not in dynamic_inputs or not 
_is_direct_assignment_to(node, target): + return None + + memlet = copy.deepcopy(next(iter(node.in_memlets.values()))) + return tn.DynScopeCopyNode(target=target, memlet=memlet) + + +def _is_direct_assignment_to(node: tn.TaskletNode, target: str) -> bool: + code = getattr(node.node, 'code', None) + text = getattr(code, 'as_string', None) + if not text: + return False + + try: + parsed = ast.parse(text) + except SyntaxError: + return False + + if len(parsed.body) != 1 or not isinstance(parsed.body[0], ast.Assign) or len(parsed.body[0].targets) != 1: + return False + + assign = parsed.body[0] + if not isinstance(assign.targets[0], ast.Name) or assign.targets[0].id != target: + return False + + return isinstance(assign.value, (ast.Name, ast.Attribute, ast.Subscript)) diff --git a/dace/frontend/python/schedule_tree/expression_support.py b/dace/frontend/python/schedule_tree/expression_support.py new file mode 100644 index 0000000000..91e9a7c491 --- /dev/null +++ b/dace/frontend/python/schedule_tree/expression_support.py @@ -0,0 +1,410 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +"""Generic expression planning helpers for the direct schedule-tree frontend. + +Terminology +----------- +"Expression planning" is not meant as a standardized compiler term here. +Within the schedule-tree frontend it is an internal shorthand for the step +that decides: + +1. which array-valued subexpressions should be materialized into temporaries, +2. in what order they should be materialized, and +3. which lowering path should handle each materialized step. + +Conceptually this is closest to turning nested expressions into a restricted +3-address-code / A-normal-form style representation before schedule-tree +lowering. The goal is not to preserve a single opaque source expression, but to +expose the intermediate array operations that later passes can lower as maps, +library calls, or fallback tasklets. 
+ +Examples +-------- +Nested call arguments: + + Source: + inner(A + 1, B + 2) + + Planned form: + __stree_tmp = A + 1 + __stree_tmp1 = B + 2 + inner(__stree_tmp, __stree_tmp1) + + Lowering effect: + The two temporary assignments can each become explicit elementwise map + scopes through the NumPy lowering layer, while the call itself sees only + simple array arguments. + +Array-valued returns: + + Source: + return A + B + + Planned form: + __stree_tmp = A + B + return __stree_tmp + + Lowering effect: + The returned expression is no longer an opaque return value; it becomes + a normal assignment that can be lowered structurally before the final + ReturnNode. + +Chained matmul: + + Source: + return A @ B @ C + + Planned form: + __stree_tmp = A @ B + __stree_tmp1 = __stree_tmp @ C + return __stree_tmp1 + + Lowering effect: + Each matmul step can lower independently as a schedule-tree library call + instead of treating the whole chain as one opaque expression. +""" + +import ast +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Tuple + +from dace import data, dtypes +from dace.frontend.common import op_repository as oprepo +from dace.frontend.python import astutils +from dace.frontend.python.replacements.utils import broadcast_together +from dace.memlet import Memlet +from dace.sdfg.analysis.schedule_tree import treenodes as tn + +DescriptorInferer = Callable[[ast.AST], Optional[data.Data]] +ExpressionMaterializer = Callable[[ast.AST, data.Data], ast.AST] +DataAccessResolver = Callable[[ast.AST], Optional[Tuple[str, Memlet, data.Data, Optional[data.Data]]]] +InputMemletCollector = Callable[[ast.AST], Dict[str, Memlet]] +OutputTargetResolver = Callable[[ast.AST, ast.AST, Optional[data.Data]], Optional[Tuple[str, Memlet, data.Data]]] +CallableNameResolver = Callable[[ast.AST], str] +CallMaterializationPredicate = Callable[[ast.Call], bool] + + +@dataclass(frozen=True) +class ExpressionPlanningContext: + """Callbacks needed by the 
class GenericExpressionSupportLibrary:
    """Planning and lowering entry points for non-trivial array-valued expressions.

    Two services are offered:

    1. Planning: nested expressions are rewritten into simpler ones by materializing
       selected subexpressions into temporaries (see ``plan_expression``).
    2. Lowering hooks: registered assignment passes get the first chance to lower an
       assignment or infer an expression descriptor before the builder falls back to
       its other paths.

    The only registered pass is ``_OperatorAssignmentPass``.
    """

    def __init__(self) -> None:
        # Passes are consulted in order; the first non-None answer wins.
        self.assignment_passes = (_OperatorAssignmentPass(), )

    def plan_expression(self, context: ExpressionPlanningContext, node: ast.AST, *, materialize_root: bool) -> ast.AST:
        """Return the planned form of ``node`` produced by a fresh ``_ExpressionPlanner``."""
        planner = _ExpressionPlanner(context)
        return planner.rewrite(node, materialize_root=materialize_root)

    def lower_assignment(self, context: ExpressionPlanningContext, target: ast.AST, value: ast.AST,
                         annotated_descriptor: Optional[data.Data]) -> Optional[tn.ScheduleTreeNode]:
        """Return the first non-None lowering offered by the assignment passes, else ``None``."""
        attempts = (candidate.lower_assignment(context, target, value, annotated_descriptor)
                    for candidate in self.assignment_passes)
        return next((lowered for lowered in attempts if lowered is not None), None)

    def infer_expression_descriptor(self, context: ExpressionPlanningContext, node: ast.AST) -> Optional[data.Data]:
        """Return the first non-None descriptor inferred by the assignment passes, else ``None``."""
        attempts = (candidate.infer_expression_descriptor(context, node) for candidate in self.assignment_passes)
        return next((descriptor for descriptor in attempts if descriptor is not None), None)
+ """ + + rewritten = self._rewrite(astutils.copy_tree(node)) + if materialize_root and self._should_materialize(rewritten): + return self._materialize(rewritten) + return rewritten + + def _rewrite(self, node: ast.AST) -> ast.AST: + if isinstance(node, ast.BinOp): + return ast.copy_location( + ast.BinOp(left=self._rewrite_binop_child(node.left, node.right), + op=astutils.copy_tree(node.op), + right=self._rewrite_binop_child(node.right, node.left)), node) + + if isinstance(node, ast.UnaryOp): + return ast.copy_location( + ast.UnaryOp(op=astutils.copy_tree(node.op), operand=self._rewrite_child(node.operand)), node) + + if isinstance(node, ast.BoolOp): + return ast.copy_location( + ast.BoolOp(op=astutils.copy_tree(node.op), + values=[self._rewrite_child(value) for value in node.values]), node) + + if isinstance(node, ast.Compare): + return ast.copy_location( + ast.Compare(left=self._rewrite_child(node.left), + ops=astutils.copy_tree(node.ops), + comparators=[self._rewrite_child(comp) for comp in node.comparators]), node) + + if isinstance(node, ast.IfExp): + return ast.copy_location( + ast.IfExp(test=self._rewrite_child(node.test), + body=self._rewrite_child(node.body), + orelse=self._rewrite_child(node.orelse)), node) + + if isinstance(node, ast.Attribute): + return ast.copy_location(ast.Attribute(value=self._rewrite_child(node.value), attr=node.attr, ctx=node.ctx), + node) + + if isinstance(node, ast.Call): + iterator_protocol_call = self._is_iterator_protocol_call(node) + array_constructor_call = self._is_array_constructor_call(node) + return ast.copy_location( + ast.Call(func=self._rewrite_call_func(node.func, iterator_protocol_call=iterator_protocol_call), + args=[ + self._rewrite_child(arg, + materialize_pyobject_call=iterator_protocol_call, + preserve_array_literal=array_constructor_call and index == 0) + for index, arg in enumerate(node.args) + ], + keywords=[ + ast.keyword(arg=kw.arg, + value=self._rewrite_child(kw.value, + 
materialize_pyobject_call=iterator_protocol_call)) + for kw in node.keywords + ]), node) + + if isinstance(node, ast.Tuple): + return ast.copy_location(ast.Tuple(elts=[self._rewrite_child(elt) for elt in node.elts], ctx=node.ctx), + node) + + if isinstance(node, ast.List): + return ast.copy_location(ast.List(elts=[self._rewrite_child(elt) for elt in node.elts], ctx=node.ctx), node) + + return node + + def _rewrite_call_func(self, func: ast.AST, *, iterator_protocol_call: bool) -> ast.AST: + if isinstance(func, ast.Attribute): + return ast.copy_location( + ast.Attribute(value=self._rewrite_child(func.value, materialize_pyobject_call=iterator_protocol_call), + attr=func.attr, + ctx=func.ctx), func) + return astutils.copy_tree(func) + + def _rewrite_child(self, + node: ast.AST, + *, + materialize_pyobject_call: bool = False, + preserve_array_literal: bool = False) -> ast.AST: + rewritten = self._rewrite(astutils.copy_tree(node)) + if preserve_array_literal: + return rewritten + if self._should_materialize(rewritten): + return self._materialize(rewritten) + if materialize_pyobject_call and self._should_materialize_pyobject_call(rewritten): + return self._materialize(rewritten) + return rewritten + + def _rewrite_binop_child(self, node: ast.AST, sibling: ast.AST) -> ast.AST: + rewritten = self._rewrite(astutils.copy_tree(node)) + array_literal_descriptor = None + if isinstance(rewritten, (ast.List, ast.Tuple)): + array_literal_descriptor = self.context.infer_descriptor( + ast.Call(func=ast.Attribute(value=ast.Name(id='numpy', ctx=ast.Load()), attr='array', ctx=ast.Load()), + args=[astutils.copy_tree(rewritten)], + keywords=[])) + sibling_descriptor = self.context.infer_descriptor(astutils.copy_tree(sibling)) + if (array_literal_descriptor is not None and sibling_descriptor is not None + and not isinstance(sibling_descriptor, data.Scalar)): + return self.context.materialize_expression(rewritten, array_literal_descriptor) + + if self._should_materialize(rewritten): + 
return self._materialize(rewritten) + return rewritten + + def _should_materialize(self, node: ast.AST) -> bool: + """Return whether ``node`` should become a temporary. + + The planner only materializes array-valued expressions that are not + already simple data accesses. Scalars remain inline. Plain accesses such + as ``A`` or ``A[i:j]`` stay inline as well; they are already representable + without introducing extra storage. + """ + + if (isinstance(node, ast.Call) and self.context.should_materialize_call is not None + and self.context.should_materialize_call(node)): + return True + + descriptor = self.context.infer_descriptor(node) + if descriptor is None: + return False + if isinstance(descriptor, data.Scalar): + return isinstance(node, ast.Call) + if self.context.resolve_data_access(node) is not None: + return False + return isinstance(node, (ast.Attribute, ast.BinOp, ast.BoolOp, ast.Call, ast.Compare, ast.IfExp, ast.UnaryOp)) + + def _should_materialize_pyobject_call(self, node: ast.AST) -> bool: + descriptor = self.context.infer_descriptor(node) + return isinstance(node, ast.Call) and isinstance(descriptor, data.Scalar) and isinstance( + descriptor.dtype, dtypes.pyobject) + + def _is_iterator_protocol_call(self, node: ast.Call) -> bool: + if isinstance(node.func, ast.Name): + return node.func.id in {'iter', 'next', '__dace_iterator_init', '__dace_iterator_next'} + return isinstance(node.func, ast.Attribute) and node.func.attr == '__next__' + + def _is_array_constructor_call(self, node: ast.Call) -> bool: + if self.context.resolve_callable_name is not None: + return self.context.resolve_callable_name(node.func) in {'numpy.array', 'numpy.asarray'} + return astutils.rname(node.func) in {'numpy.array', 'numpy.asarray'} + + def _materialize(self, node: ast.AST) -> ast.AST: + descriptor = self.context.infer_descriptor(node) + if (descriptor is None and isinstance(node, ast.Call) and self.context.should_materialize_call is not None + and 
self.context.should_materialize_call(node)): + descriptor = data.Scalar(dtypes.pyobject(), transient=True) + if descriptor is None: + return node + return self.context.materialize_expression(node, descriptor) + + +class _OperatorAssignmentPass: + """Lower materialized binary operator assignments as frontend library calls. + + Uses the operator descriptor-inference registry to handle any binary + operator that has a registered inference function (currently ``@`` / MatMult). + The planner first linearizes chains such as ``A @ B @ C`` into temporary + assignments, then each individual assignment is recognized here. + """ + + # Maps AST operator class -> (registry_name, library_name) + _OP_MAP = {ast.MatMult: ('MatMult', 'MatMul')} + + def lower_assignment(self, context: ExpressionPlanningContext, target: ast.AST, value: ast.AST, + annotated_descriptor: Optional[data.Data]) -> Optional[tn.ScheduleTreeNode]: + if not isinstance(value, ast.BinOp): + return None + entry = self._OP_MAP.get(type(value.op)) + if entry is None: + return None + registry_name, library_name = entry + descriptor = self.infer_expression_descriptor(context, value) + if descriptor is None or isinstance(descriptor, data.Scalar): + return None + output = context.resolve_output_target(target, value, annotated_descriptor) + if output is None: + return None + _, out_memlet, _ = output + in_memlets = context.collect_input_memlets(value) + if len(in_memlets) != 2: + return None + return tn.LibraryCall(node=tn.FrontendLibrary(name=library_name, + properties=self._operator_properties(registry_name)), + in_memlets=in_memlets, + out_memlets={'out': out_memlet}) + + def infer_expression_descriptor(self, context: ExpressionPlanningContext, node: ast.AST) -> Optional[data.Data]: + if not isinstance(node, ast.BinOp): + return None + entry = self._OP_MAP.get(type(node.op)) + if entry is None: + return None + registry_name, _library_name = entry + + left_descriptor = context.infer_descriptor(node.left) + 
right_descriptor = context.infer_descriptor(node.right) + if left_descriptor is None or right_descriptor is None: + return None + + infer_fn = oprepo.Replacements.get_operator_descriptor_inference(registry_name) + if infer_fn is not None: + try: + result = infer_fn(left_descriptor, right_descriptor) + if result is not None: + return result + except Exception: + pass + + return None + + @staticmethod + def _operator_properties(registry_name: str) -> dict: + if registry_name == 'MatMult': + return {'alpha': 1, 'beta': 0} + return {} + + +def _matmul_output_shape(left_shape: Tuple[object, ...], right_shape: Tuple[object, + ...]) -> Optional[Tuple[object, ...]]: + """Infer the result shape for NumPy-style matmul semantics. + + This mirrors the subset of ``numpy.matmul`` shape rules that the direct + schedule-tree frontend currently lowers structurally: vector-vector, + matrix-vector, vector-matrix, matrix-matrix, and batched matrix-matrix. + """ + + if len(left_shape) == 1 and len(right_shape) == 1: + return tuple() + + if len(left_shape) == 2 and len(right_shape) == 1: + return (left_shape[0], ) + + if len(left_shape) == 1 and len(right_shape) == 2: + return (right_shape[1], ) + + if len(left_shape) < 2 or len(right_shape) < 2: + return None + + batch_shape = _broadcast_prefix_shapes(left_shape[:-2], right_shape[:-2]) + if batch_shape is None: + return None + return batch_shape + (left_shape[-2], right_shape[-1]) + + +def _broadcast_prefix_shapes(left_prefix: Tuple[object, ...], right_prefix: Tuple[object, + ...]) -> Optional[Tuple[object, ...]]: + if not left_prefix: + return tuple(right_prefix) + if not right_prefix: + return tuple(left_prefix) + try: + result, _, _, _, _ = broadcast_together(left_prefix, right_prefix) + except Exception: + return None + return tuple(result) diff --git a/dace/frontend/python/schedule_tree/function_inlining.py b/dace/frontend/python/schedule_tree/function_inlining.py new file mode 100644 index 0000000000..9b5c5d01b3 --- /dev/null +++ 
b/dace/frontend/python/schedule_tree/function_inlining.py @@ -0,0 +1,401 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +""" +Bottom-up parallel inlining of ``@dace.program`` calls in schedule trees. + +After the schedule-tree builder emits :class:`FunctionCallScope` placeholders +for every nested ``@dace.program`` call, :func:`resolve_function_calls` +collects them, generates the callee schedule trees (in parallel when there are +multiple independent callees), and inlines the results. +""" + +import ast +import copy +import re +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Dict, List, Optional, Set, Tuple + +from dace.cli.progress import optional_progressbar +from dace import data +from dace.data.pydata import PythonClass +from dace.memlet import Memlet +from dace.properties import CodeBlock +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.utils import find_new_name + +# -------------------------------------------------------------------- # +# Public entry point # +# -------------------------------------------------------------------- # + + +def resolve_function_calls(root: tn.ScheduleTreeRoot) -> None: + """ + Collect all :class:`FunctionCallScope` nodes in *root*, parse each + callee's schedule tree (in parallel where possible), and inline the + results. + + Callee schedule trees may themselves contain nested calls; since + ``to_schedule_tree`` calls ``resolve_function_calls`` recursively, + bottom-up ordering is automatic — leaf-level callees are fully + resolved before their callers. + """ + scopes = _collect_function_call_scopes(root) + if not scopes: + return + + # Build callee schedule trees — keyed by (callee_id, arg_types). + callee_trees = _build_callee_trees(scopes) + + # Inline each call site. 
+ for scope in scopes: + key = _callee_key(scope) + callee_tree = callee_trees[key] + _inline_callee(scope, callee_tree, root) + + +# -------------------------------------------------------------------- # +# Collecting FunctionCallScope nodes # +# -------------------------------------------------------------------- # + + +def _collect_function_call_scopes(root: tn.ScheduleTreeRoot) -> List[tn.FunctionCallScope]: + result: List[tn.FunctionCallScope] = [] + for node in root.preorder_traversal(): + if isinstance(node, tn.FunctionCallScope): + result.append(node) + return result + + +# -------------------------------------------------------------------- # +# Building callee schedule trees (with parallelism) # +# -------------------------------------------------------------------- # + + +def _callee_key(scope: tn.FunctionCallScope) -> int: + """ + Return a hashable key for a call-scope's callee. + + For now this is just the callee identity. When we need to support + multiple specialisations of the same function for different argument + types, this should be extended to include the type signature. 
+ """ + return hash((id(scope._callee_program), _specialization_key(scope))) + + +def _specialization_key(scope: tn.FunctionCallScope) -> Tuple: + return (_specialization_values_key(getattr(scope, '_call_args', + [])), _specialization_kwargs_key(getattr(scope, '_call_kwargs', {})), + tuple( + sorted((name, ast.dump(lambda_node)) + for name, lambda_node in getattr(scope, '_lambda_bindings', {}).items())), + tuple(sorted((name, id(value)) for name, value in getattr(scope, '_callable_bindings', {}).items()))) + + +def _specialization_values_key(values: List[object]) -> Tuple: + return tuple(_specialization_value_key(value) for value in values) + + +def _specialization_kwargs_key(values: Dict[str, object]) -> Tuple: + return tuple(sorted((name, _specialization_value_key(value)) for name, value in values.items())) + + +def _specialization_value_key(value: object) -> Tuple[str, str]: + if isinstance(value, data.Data): + return ('descriptor', repr(value)) + return (type(value).__name__, repr(value)) + + +def _build_callee_trees(scopes: List[tn.FunctionCallScope]) -> Dict[int, tn.ScheduleTreeRoot]: + """ + For every unique callee referenced by *scopes*, build its schedule + tree. When there are multiple independent callees, parse them in + parallel. Returns a mapping from :func:`_callee_key` to the parsed + :class:`ScheduleTreeRoot`. + """ + # De-duplicate by callee identity. + unique: Dict[int, tn.FunctionCallScope] = {} + for scope in scopes: + key = _callee_key(scope) + if key not in unique: + unique[key] = scope + + if len(unique) == 1: + # Only one callee — no need for thread overhead. 
+ scope = next(iter(unique.values())) + tree = _parse_callee(scope) + return {_callee_key(scope): tree} + + results: Dict[int, tn.ScheduleTreeRoot] = {} + with ThreadPoolExecutor() as pool: + futures = {pool.submit(_parse_callee, scope): key for key, scope in unique.items()} + completed = optional_progressbar(as_completed(futures), title='Parsing nested DaCe functions', n=len(futures)) + for future in completed: + results[futures[future]] = future.result() + return results + + +def _parse_callee(scope: tn.FunctionCallScope) -> tn.ScheduleTreeRoot: + """ + Parse a callee ``DaceProgram`` into its schedule tree. + + The callee's ``to_schedule_tree`` method triggers preprocessing + + schedule-tree building + recursive ``resolve_function_calls``, so + leaf-level callees are fully inlined before we return. + """ + callee = scope._callee_program + call_args = tuple(getattr(scope, '_call_args', [])) + call_kwargs = dict(getattr(scope, '_call_kwargs', {})) + lambda_bindings = dict(getattr(scope, '_lambda_bindings', {})) + callable_bindings = dict(getattr(scope, '_callable_bindings', {})) + + if hasattr(callee, '__schedule_tree__'): + return callee.__schedule_tree__(*call_args, + lambda_bindings=lambda_bindings, + callable_bindings=callable_bindings, + **call_kwargs) + + return callee._generate_schedule_tree(call_args, + call_kwargs, + lambda_bindings=lambda_bindings, + callable_bindings=callable_bindings) + + +# -------------------------------------------------------------------- # +# Inlining a callee tree into a FunctionCallScope # +# -------------------------------------------------------------------- # + + +def _inline_callee(scope: tn.FunctionCallScope, callee_tree: tn.ScheduleTreeRoot, + caller_root: tn.ScheduleTreeRoot) -> None: + """ + Inline *callee_tree* into *scope*, renaming containers to match the + caller's namespace, merging descriptors, and rewriting return nodes. 
+ """ + arguments = scope.call.arguments # callee_param -> caller_expr + callee_arg_names = set(callee_tree.arg_names) + captured_names = set(getattr(scope, '_captured_names', set())) + + # 1. Build rename map: callee name -> caller name. + rename_map = _build_rename_map(arguments, callee_tree, caller_root, callee_arg_names, captured_names) + + # 2. Deep-copy callee body. + body = copy.deepcopy(callee_tree.children) + + # 3. Rename all data references in the cloned body. + renamer = _ContainerRenamer(rename_map) + body = [renamer.visit(child) for child in body] + body = [child for child in body if child is not None] + + # 4. Propagate callee argument descriptor upgrades back to the caller. + _propagate_argument_descriptor_updates(arguments, callee_tree, caller_root) + + # 5. Merge callee transient containers and symbols into the caller. + for cname, desc in callee_tree.containers.items(): + new_name = rename_map.get(cname, cname) + if new_name not in caller_root.containers: + caller_root.containers[new_name] = copy.deepcopy(desc) + for sname, stype in callee_tree.symbols.items(): + caller_root.symbols.setdefault(sname, stype) + + # 6. Merge callee constants and callbacks. + for cname, cval in callee_tree.constants.items(): + caller_root.constants.setdefault(cname, cval) + for cbname, cbval in callee_tree.callback_mapping.items(): + caller_root.callback_mapping.setdefault(cbname, cbval) + + # 7. Handle return values. + return_targets = getattr(scope, '_return_targets', None) + body = _rewrite_returns(body, return_targets, rename_map) + + # 8. Populate the scope. 
+ scope.children = body + for child in body: + child.parent = scope + + +def _propagate_argument_descriptor_updates(arguments: Dict[str, str], callee_tree: tn.ScheduleTreeRoot, + caller_root: tn.ScheduleTreeRoot) -> None: + for callee_param, caller_expr in arguments.items(): + callee_descriptor = callee_tree.containers.get(callee_param) + if not isinstance(callee_descriptor, PythonClass): + continue + caller_descriptor = caller_root.containers.get(caller_expr) + if caller_descriptor is None or isinstance(caller_descriptor, PythonClass): + continue + caller_root.containers[caller_expr] = copy.deepcopy(callee_descriptor) + + +# -------------------------------------------------------------------- # +# Rename-map construction # +# -------------------------------------------------------------------- # + + +def _build_rename_map(arguments: Dict[str, str], callee_tree: tn.ScheduleTreeRoot, caller_root: tn.ScheduleTreeRoot, + callee_arg_names: Set[str], captured_names: Set[str]) -> Dict[str, str]: + """ + Build ``{callee_name: caller_name}`` for every container in the + callee's schedule tree. + + * Arguments are mapped via *arguments* (callee_param -> caller_expr). + * Transients that collide with caller names get fresh names. + """ + rename: Dict[str, str] = {} + occupied = set(caller_root.containers.keys()) + + # Map callee parameters to caller arguments. + for callee_param, caller_expr in arguments.items(): + rename[callee_param] = caller_expr + + # Explicit global/nonlocal captures must keep the caller-visible name. + for captured_name in captured_names: + rename[captured_name] = captured_name + occupied.add(captured_name) + + # Handle callee-internal transients (everything not in arg_names). + for cname in callee_tree.containers: + if cname in rename: + continue + if cname in callee_arg_names: + # Argument not in the mapping (e.g. default-valued) — keep as-is. 
class _ContainerRenamer(tn.ScheduleNodeTransformer):
    """Rename data-container references throughout a schedule sub-tree.

    Textual renames (code blocks and reference source expressions) are applied
    in a single simultaneous pass. Applying them one after another would let
    an earlier rename feed a later one: with ``{'x': 'y', 'y': 'y_0'}`` the
    text ``x + y`` must become ``y + y_0``, not ``y_0 + y_0``.
    """

    def __init__(self, rename_map: Dict[str, str]) -> None:
        """
        :param rename_map: Mapping from old container name to new name.
                           Identity entries are ignored.
        """
        self._map = {k: v for k, v in rename_map.items() if k != v}
        self._pattern = self._compile_pattern(self._map)

    # -- helpers --------------------------------------------------------

    @staticmethod
    def _compile_pattern(mapping: Dict[str, str]) -> Optional['re.Pattern']:
        """Build one whole-word pattern matching any key of *mapping*."""
        if not mapping:
            return None
        # Longest names first so that a name is preferred over its own prefix.
        alternatives = '|'.join(re.escape(name) for name in sorted(mapping, key=len, reverse=True))
        return re.compile(r'\b(?:' + alternatives + r')\b')

    def _rename_text(self, text: str) -> str:
        """Rename every whole-word occurrence of a mapped name in *text*."""
        if self._pattern is None:
            return text
        # A callable replacement inserts the new name literally and looks at
        # each original occurrence exactly once.
        return self._pattern.sub(lambda match: self._map[match.group(0)], text)

    def _rename(self, name: str) -> str:
        """Return the new name for *name*, or *name* itself if unmapped."""
        return self._map.get(name, name)

    def _rename_memlet(self, memlet: Optional[Memlet]) -> Optional[Memlet]:
        """Rename the container of *memlet* in place and return it."""
        if memlet is None:
            return None
        if memlet.data in self._map:
            memlet.data = self._map[memlet.data]
        return memlet

    def _rename_memlet_dict(self, d):
        """Return a renamed deep copy of a memlet dict or set.

        Any other value is returned unchanged.
        """
        if isinstance(d, dict):
            return {k: self._rename_memlet(copy.deepcopy(m)) for k, m in d.items()}
        if isinstance(d, set):
            return {self._rename_memlet(copy.deepcopy(m)) for m in d}
        return d

    def _rename_code_block(self, cb: Optional[CodeBlock]) -> Optional[CodeBlock]:
        """Return a new ``CodeBlock`` with all mapped names renamed."""
        if cb is None:
            return None
        return CodeBlock(self._rename_text(cb.as_string))

    # -- leaf node visitors ---------------------------------------------

    def visit_CopyNode(self, node: tn.CopyNode):
        node.target = self._rename(node.target)
        self._rename_memlet(node.memlet)
        return node

    def visit_DynScopeCopyNode(self, node: tn.DynScopeCopyNode):
        node.target = self._rename(node.target)
        self._rename_memlet(node.memlet)
        return node

    def visit_ViewNode(self, node: tn.ViewNode):
        node.target = self._rename(node.target)
        node.source = self._rename(node.source)
        self._rename_memlet(node.memlet)
        return node

    def visit_RefSetNode(self, node: tn.RefSetNode):
        node.target = self._rename(node.target)
        self._rename_memlet(node.memlet)
        if node.source_expr is not None:
            node.source_expr = self._rename_text(node.source_expr)
        return node

    def visit_TaskletNode(self, node: tn.TaskletNode):
        node.in_memlets = self._rename_memlet_dict(node.in_memlets)
        node.out_memlets = self._rename_memlet_dict(node.out_memlets)
        # Only frontend tasklets carry code that is rewritten here.
        if isinstance(node.node, tn.FrontendTasklet):
            node.node = tn.FrontendTasklet(name=node.node.name, code=self._rename_code_block(node.node.code))
        return node

    def visit_LibraryCall(self, node: tn.LibraryCall):
        node.in_memlets = self._rename_memlet_dict(node.in_memlets)
        node.out_memlets = self._rename_memlet_dict(node.out_memlets)
        return node

    def visit_AssignNode(self, node: tn.AssignNode):
        node.name = self._rename(node.name)
        node.value = self._rename_code_block(node.value)
        return node

    def visit_ReassignExternalNode(self, node: tn.ReassignExternalNode):
        node.value = self._rename_code_block(node.value)
        return node

    def visit_StatementNode(self, node: tn.StatementNode):
        node.code = self._rename_code_block(node.code)
        return node

    def visit_PythonCallbackNode(self, node: tn.PythonCallbackNode):
        node.code = self._rename_code_block(node.code)
        if node.outlined_function_code is not None:
            node.outlined_function_code = self._rename_code_block(node.outlined_function_code)
        if node.outlined_call_code is not None:
            node.outlined_call_code = self._rename_code_block(node.outlined_call_code)
        return node

    def visit_RaiseNode(self, node: tn.RaiseNode):
        if node.exception_type is not None:
            node.exception_type = self._rename_code_block(node.exception_type)
        node.args = [self._rename_code_block(argument) for argument in node.args]
        node.kwargs = {name: self._rename_code_block(value) for name, value in node.kwargs.items()}
        return node

    def visit_ReturnNode(self, node: tn.ReturnNode):
        node.values = [self._rename(v) for v in node.values]
        return node


# -------------------------------------------------------------------- #
#  Return-value rewriting                                                #
# -------------------------------------------------------------------- #


def _rewrite_returns(body: List[tn.ScheduleTreeNode], return_targets: Optional[List[str]],
                     rename_map: Dict[str, str]) -> List[tn.ScheduleTreeNode]:
    """
    Replace :class:`ReturnNode` instances in *body* with assignments to
    *return_targets*.

    Targets and returned values are paired positionally; a pair whose value
    already has the target's name produces no assignment. If *return_targets*
    is ``None`` or empty (the call was used as a bare statement), return nodes
    are dropped. Nested scopes are processed recursively.

    :param body: Nodes to rewrite; nested scopes are updated in place.
    :param return_targets: Caller-side names receiving the returned values.
    :param rename_map: Not used by the rewrite itself; only forwarded to the
                       recursive calls.
    :return: The rewritten list of nodes.
    """
    result: List[tn.ScheduleTreeNode] = []
    for node in body:
        if isinstance(node, tn.ReturnNode):
            if return_targets and node.values:
                for target, value_name in zip(return_targets, node.values):
                    if value_name != target:
                        result.append(tn.AssignNode(name=target, value=CodeBlock(value_name)))
            # else: bare call — drop the return
        elif isinstance(node, tn.ScheduleTreeScope):
            node.children = _rewrite_returns(node.children, return_targets, rename_map)
            result.append(node)
        else:
            result.append(node)
    return result
+""" + +import ast +import inspect +from typing import Any, Dict, Optional + +import sympy + +from dace import dtypes, symbolic +from dace.frontend.python import astutils + + +def extract_lambda_ast(func) -> Optional[ast.Lambda]: + """Recover the AST for a Python lambda when its source is available.""" + if not inspect.isfunction(func) or getattr(func, '__name__', None) != '': + return None + + try: + src_ast, _, _, _ = astutils.function_to_ast(func) + except Exception: + return None + + target_lineno = getattr(func.__code__, 'co_firstlineno', None) + candidates = [node for node in ast.walk(src_ast) if isinstance(node, ast.Lambda)] + if target_lineno is not None: + exact = [node for node in candidates if getattr(node, 'lineno', None) == target_lineno] + if exact: + candidates = exact + if len(candidates) != 1: + return None + + return ast.fix_missing_locations(astutils.copy_tree(candidates[0])) + + +def inline_lambda_call(lambda_node: ast.Lambda, call_node: ast.Call) -> ast.AST: + """Inline a call to ``lambda_node`` by substituting actual arguments.""" + bindings = _bind_lambda_arguments(lambda_node.args, call_node) + + class _LambdaInliner(ast.NodeTransformer): + + def visit_Name(self, node: ast.Name) -> ast.AST: + if isinstance(node.ctx, ast.Load) and node.id in bindings: + return astutils.copy_tree(bindings[node.id]) + return node + + def visit_Lambda(self, node: ast.Lambda) -> ast.AST: + return node + + return ast.fix_missing_locations(_LambdaInliner().visit(astutils.copy_tree(lambda_node.body))) + + +def _bind_lambda_arguments(args: ast.arguments, call_node: ast.Call) -> Dict[str, ast.AST]: + if args.vararg is not None or args.kwarg is not None or args.kwonlyargs: + raise TypeError('Only simple positional/keyword lambda arguments are supported') + + parameters = list(args.posonlyargs) + list(args.args) + defaults = list(args.defaults) + default_offset = len(parameters) - len(defaults) + + bindings: Dict[str, ast.AST] = {} + positional = list(call_node.args) + 
if len(positional) > len(parameters): + raise TypeError('Too many positional arguments for lambda call') + + for parameter, actual in zip(parameters, positional): + bindings[parameter.arg] = astutils.copy_tree(actual) + + remaining_keywords = {kw.arg: kw.value for kw in call_node.keywords if kw.arg is not None} + for index, parameter in enumerate(parameters[len(positional):], start=len(positional)): + if parameter.arg in remaining_keywords: + bindings[parameter.arg] = astutils.copy_tree(remaining_keywords.pop(parameter.arg)) + continue + default_index = index - default_offset + if default_index >= 0: + bindings[parameter.arg] = astutils.copy_tree(defaults[default_index]) + continue + raise TypeError(f'Missing argument {parameter.arg!r} for lambda call') + + for parameter in parameters[:len(positional)]: + if parameter.arg in remaining_keywords: + raise TypeError(f'Multiple values for argument {parameter.arg!r} in lambda call') + + if remaining_keywords: + raise TypeError(f'Unexpected keyword arguments in lambda call: {sorted(remaining_keywords)}') + + return bindings + + +class LambdaResolver: + """Resolve known lambda values and inline their call sites. + + The resolver keeps track of lambda values that are visible by name in the + current lowering scope and can recover AST for global or closure-backed + lambdas when source is available. + + Example: + Given a binding ``f = lambda a, b: a + b``, calling + ``inline_known_lambda_calls(...)`` on the AST for ``f(A, B)`` returns an + expression equivalent to ``A + B``. 
class LambdaResolver:
    """Resolve known lambda values and inline their call sites.

    The resolver keeps track of lambda values that are visible by name in the
    current lowering scope and can recover AST for global or closure-backed
    lambdas when source is available.

    Example:
        Given a binding ``f = lambda a, b: a + b``, calling
        ``inline_known_lambda_calls(...)`` on the AST for ``f(A, B)`` returns an
        expression equivalent to ``A + B``.
    """

    def __init__(self,
                 globals_env: Dict[str, Any],
                 lambda_bindings: Dict[str, ast.Lambda],
                 callable_bindings: Dict[str, Any],
                 *,
                 cache: Optional[Dict[str, Optional[ast.Lambda]]] = None) -> None:
        """
        :param globals_env: Name -> value environment; extended in place with
                            values captured by resolved lambdas.
        :param lambda_bindings: Name -> lambda AST; updated in place.
        :param callable_bindings: Name -> callable object.
        :param cache: Optional name -> lambda AST (or ``None``) cache for
                      names looked up in ``globals_env``; a fresh dict is used
                      when omitted.
        """
        # All mappings are stored by reference, not copied.
        self.globals = globals_env
        self.lambda_bindings = lambda_bindings
        self.callable_bindings = callable_bindings
        self._global_lambda_cache = cache if cache is not None else {}

    def update_binding(self, name: str, value: ast.AST) -> None:
        """Bind *name* to the lambda that the AST *value* resolves to.

        If *value* does not resolve to a known lambda, any existing lambda
        binding for *name* is removed.
        """
        lambda_node = self.resolve_known_lambda_node(value)
        if lambda_node is None:
            self.lambda_bindings.pop(name, None)
            return
        self.lambda_bindings[name] = lambda_node

    def bind_value(self, name: str, value: Any) -> None:
        """Bind *name* to the lambda recovered from the runtime object *value*.

        On success the lambda is also stored in the global cache; on failure
        only the lambda binding is removed (the cache is left as it is).
        """
        lambda_node = self.resolve_global_lambda_node(value)
        if lambda_node is None:
            self.lambda_bindings.pop(name, None)
            return
        self.lambda_bindings[name] = lambda_node
        # The cache gets its own copy so the two entries never alias.
        self._global_lambda_cache[name] = astutils.copy_tree(lambda_node)

    def resolve_known_lambda_node(self, node: ast.AST) -> Optional[ast.Lambda]:
        """Return a copy of the lambda AST that *node* denotes, if known.

        A literal lambda resolves to itself. A name is looked up, in order, in
        ``lambda_bindings``, ``callable_bindings``, the global cache, and
        finally ``globals``; the result of the ``globals`` lookup (including
        ``None``) is written to the cache. Any other node yields ``None``.
        """
        if isinstance(node, ast.Lambda):
            return astutils.copy_tree(node)
        if not isinstance(node, ast.Name):
            return None
        if node.id in self.lambda_bindings:
            return astutils.copy_tree(self.lambda_bindings[node.id])
        if node.id in self.callable_bindings:
            # Callable bindings are resolved on every lookup and never cached.
            lambda_node = self.resolve_global_lambda_node(self.callable_bindings[node.id])
            return astutils.copy_tree(lambda_node) if lambda_node is not None else None
        if node.id in self._global_lambda_cache:
            # A cached ``None`` records an earlier failed lookup.
            cached = self._global_lambda_cache[node.id]
            return astutils.copy_tree(cached) if cached is not None else None
        value = self.globals.get(node.id)
        lambda_node = self.resolve_global_lambda_node(value) if value is not None else None
        self._global_lambda_cache[node.id] = astutils.copy_tree(lambda_node) if lambda_node is not None else None
        return astutils.copy_tree(lambda_node) if lambda_node is not None else None

    def resolve_global_lambda_node(self, value: Any) -> Optional[ast.Lambda]:
        """Recover the lambda AST of the runtime object *value*.

        Values captured by the lambda are added to ``self.globals`` (existing
        entries are kept), and those accepted by
        ``_can_inline_lambda_capture`` are substituted into the returned AST.

        :return: The lambda AST, or ``None`` if ``extract_lambda_ast`` fails.
        """
        lambda_node = extract_lambda_ast(value)
        if lambda_node is None:
            return None

        lambda_globals = _resolve_lambda_environment(value)
        for name, captured_value in lambda_globals.items():
            self.globals.setdefault(name, captured_value)

        inline_globals = {
            name: captured_value
            for name, captured_value in lambda_globals.items() if _can_inline_lambda_capture(captured_value)
        }
        if not inline_globals:
            return astutils.copy_tree(lambda_node)

        return _rewrite_lambda_free_names(lambda_node, inline_globals)

    def inline_known_lambda_calls(self, node: ast.AST) -> ast.AST:
        """Return a copy of *node* with calls to known lambdas inlined.

        Calls whose callee is not a known lambda, or whose arguments cannot be
        bound to it, are kept as calls (with their operands still processed).
        """
        # The nested transformer has its own ``self``; keep a handle on the resolver.
        resolver = self

        class _KnownLambdaInliner(ast.NodeTransformer):

            def visit_Call(self, call_node: ast.Call) -> ast.AST:
                # Rewrite callee and arguments first so nested calls are inlined too.
                rewritten = ast.Call(
                    func=self.visit(call_node.func),
                    args=[self.visit(arg) for arg in call_node.args],
                    keywords=[ast.keyword(arg=kw.arg, value=self.visit(kw.value)) for kw in call_node.keywords])
                lambda_node = resolver.resolve_known_lambda_node(rewritten.func)
                if lambda_node is None:
                    return ast.copy_location(rewritten, call_node)
                try:
                    inlined = inline_lambda_call(lambda_node, rewritten)
                except TypeError:
                    # Binding failed; keep the call.
                    return ast.copy_location(rewritten, call_node)
                # The inlined body may itself contain calls to known lambdas.
                return ast.copy_location(self.visit(inlined), call_node)

        return ast.fix_missing_locations(_KnownLambdaInliner().visit(astutils.copy_tree(node)))


def _resolve_lambda_environment(value: Any) -> Dict[str, Any]:
    """Collect the global and closure values that the lambda *value* refers to.

    ``inspect.getclosurevars`` is used when it succeeds, with nonlocals taking
    precedence over globals of the same name. Otherwise the closure cells and
    ``__globals__`` are inspected by hand.
    """
    try:
        closure_vars = inspect.getclosurevars(value)
    except Exception:
        closure_vars = None

    if closure_vars is not None:
        resolved = dict(closure_vars.globals)
        resolved.update(closure_vars.nonlocals)
        return resolved

    # Manual fallback.
    resolved = {}
    globals_env = getattr(value, '__globals__', {})
    closure = getattr(value, '__closure__', None)
    freevars = getattr(getattr(value, '__code__', None), 'co_freevars', ())
    if closure is not None:
        for name, cell in zip(freevars, closure):
            try:
                resolved[name] = cell.cell_contents
            except ValueError:
                # Reading ``cell_contents`` failed; record the name with ``None``.
                resolved[name] = None

    # Add globals for names the lambda loads that the closure did not provide.
    for name in _lambda_loaded_names(extract_lambda_ast(value)):
        if name not in resolved and name in globals_env:
            resolved[name] = globals_env[name]
    return resolved
_lambda_loaded_names(extract_lambda_ast(value)): + if name not in resolved and name in globals_env: + resolved[name] = globals_env[name] + return resolved + + +def _can_inline_lambda_capture(value: Any) -> bool: + if isinstance(value, ast.AST): + return True + if isinstance(value, (symbolic.symbol, sympy.Basic)): + return True + if dtypes.isconstant(value): + return True + if isinstance(value, tuple): + return all(_can_inline_lambda_capture(element) for element in value) + if isinstance(value, list): + return all(_can_inline_lambda_capture(element) for element in value) + if isinstance(value, dict): + return all( + _can_inline_lambda_capture(key) and _can_inline_lambda_capture(element) for key, element in value.items()) + return False + + +def _rewrite_lambda_free_names(lambda_node: ast.Lambda, env: Dict[str, Any]) -> ast.Lambda: + rewriter = _LambdaFreeNameRewriter(env) + return ast.fix_missing_locations(rewriter.visit(astutils.copy_tree(lambda_node))) + + +def _value_to_ast(value: Any, template_node: ast.AST) -> Optional[ast.AST]: + if isinstance(value, ast.AST): + return ast.copy_location(astutils.copy_tree(value), template_node) + + if isinstance(value, symbolic.symbol): + return ast.copy_location(ast.Name(id=value.name, ctx=ast.Load()), template_node) + + if isinstance(value, sympy.Basic): + return ast.copy_location(ast.parse(symbolic.symstr(value), mode='eval').body, template_node) + + if isinstance(value, list): + elements = [_value_to_ast(element, template_node) for element in value] + if any(element is None for element in elements): + return None + return ast.copy_location(ast.List(elts=elements, ctx=ast.Load()), template_node) + + if isinstance(value, tuple): + elements = [_value_to_ast(element, template_node) for element in value] + if any(element is None for element in elements): + return None + return ast.copy_location(ast.Tuple(elts=elements, ctx=ast.Load()), template_node) + + if isinstance(value, dict): + keys = [] + values = [] + for key, item in 
value.items(): + key_ast = _value_to_ast(key, template_node) + value_ast = _value_to_ast(item, template_node) + if key_ast is None or value_ast is None: + return None + keys.append(key_ast) + values.append(value_ast) + return ast.copy_location(ast.Dict(keys=keys, values=values), template_node) + + if dtypes.isconstant(value): + return astutils.create_constant(value, template_node) + + return None + + +class _LambdaFreeNameRewriter(ast.NodeTransformer): + + def __init__(self, env: Dict[str, Any]) -> None: + self.env = env + self.scope_stack = [] + + def visit_Lambda(self, node: ast.Lambda) -> ast.AST: + self.scope_stack.append(_lambda_parameter_names(node.args)) + try: + node.body = self.visit(node.body) + return node + finally: + self.scope_stack.pop() + + def visit_Name(self, node: ast.Name) -> ast.AST: + if not isinstance(node.ctx, ast.Load): + if self.scope_stack and isinstance(node.ctx, ast.Store): + self.scope_stack[-1].add(node.id) + return node + + if any(node.id in scope for scope in reversed(self.scope_stack)): + return node + if node.id not in self.env: + return node + + replacement = _value_to_ast(self.env[node.id], node) + return replacement if replacement is not None else node + + def visit_ListComp(self, node: ast.ListComp) -> ast.AST: + return self._visit_comprehension(node, 'elt') + + def visit_SetComp(self, node: ast.SetComp) -> ast.AST: + return self._visit_comprehension(node, 'elt') + + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> ast.AST: + return self._visit_comprehension(node, 'elt') + + def visit_DictComp(self, node: ast.DictComp) -> ast.AST: + pushed = self._push_comprehension_scopes(node.generators) + try: + node.key = self.visit(node.key) + node.value = self.visit(node.value) + return node + finally: + self._pop_comprehension_scopes(pushed) + + def _visit_comprehension(self, node: ast.AST, field: str) -> ast.AST: + pushed = self._push_comprehension_scopes(node.generators) + try: + setattr(node, field, self.visit(getattr(node, 
field))) + return node + finally: + self._pop_comprehension_scopes(pushed) + + def _push_comprehension_scopes(self, generators) -> int: + pushed = 0 + for generator in generators: + generator.iter = self.visit(generator.iter) + bound_names = _store_target_names(generator.target) + self.scope_stack.append(bound_names) + pushed += 1 + generator.ifs = [self.visit(condition) for condition in generator.ifs] + return pushed + + def _pop_comprehension_scopes(self, pushed: int) -> None: + for _ in range(pushed): + self.scope_stack.pop() + + +class _LambdaLoadedNameCollector(ast.NodeVisitor): + + def __init__(self) -> None: + self.scope_stack = [] + self.loaded_names = set() + + def visit_Lambda(self, node: ast.Lambda) -> None: + self.scope_stack.append(_lambda_parameter_names(node.args)) + try: + self.visit(node.body) + finally: + self.scope_stack.pop() + + def visit_Name(self, node: ast.Name) -> None: + if isinstance(node.ctx, ast.Load) and not any(node.id in scope for scope in reversed(self.scope_stack)): + self.loaded_names.add(node.id) + + def visit_ListComp(self, node: ast.ListComp) -> None: + self._visit_comprehension(node.generators, lambda: self.visit(node.elt)) + + def visit_SetComp(self, node: ast.SetComp) -> None: + self._visit_comprehension(node.generators, lambda: self.visit(node.elt)) + + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> None: + self._visit_comprehension(node.generators, lambda: self.visit(node.elt)) + + def visit_DictComp(self, node: ast.DictComp) -> None: + self._visit_comprehension(node.generators, lambda: (self.visit(node.key), self.visit(node.value))) + + def _visit_comprehension(self, generators, visit_result) -> None: + pushed = 0 + for generator in generators: + self.visit(generator.iter) + self.scope_stack.append(_store_target_names(generator.target)) + pushed += 1 + for condition in generator.ifs: + self.visit(condition) + try: + visit_result() + finally: + for _ in range(pushed): + self.scope_stack.pop() + + +def 
_lambda_loaded_names(lambda_node: Optional[ast.Lambda]) -> set[str]: + if lambda_node is None: + return set() + collector = _LambdaLoadedNameCollector() + collector.visit(astutils.copy_tree(lambda_node)) + return collector.loaded_names + + +def _lambda_parameter_names(args: ast.arguments) -> set[str]: + names = {arg.arg for arg in args.posonlyargs} + names.update(arg.arg for arg in args.args) + names.update(arg.arg for arg in args.kwonlyargs) + if args.vararg is not None: + names.add(args.vararg.arg) + if args.kwarg is not None: + names.add(args.kwarg.arg) + return names + + +def _store_target_names(target: ast.AST) -> set[str]: + names = set() + for child in ast.walk(target): + if isinstance(child, ast.Name) and isinstance(child.ctx, ast.Store): + names.add(child.id) + return names diff --git a/dace/frontend/python/schedule_tree/match_support.py b/dace/frontend/python/schedule_tree/match_support.py new file mode 100644 index 0000000000..8611c4d21b --- /dev/null +++ b/dace/frontend/python/schedule_tree/match_support.py @@ -0,0 +1,138 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +"""Helpers for lowering Python ``match`` statements to simpler AST forms.""" + +import ast +from typing import Dict, List, Optional, Tuple +from dace.frontend.python import astutils + + +class UnsupportedMatchPatternError(TypeError): + """Raised when a match pattern cannot be lowered to an if-chain.""" + + +_Bindings = List[Tuple[str, ast.AST]] + + +def lower_match_to_statements(node: ast.Match, subject_expr: ast.AST) -> List[ast.stmt]: + """Lower a supported ``match`` node to equivalent ``if`` statements. + + Supported patterns are limited to value, singleton, wildcard, capture, + alias, guarded cases, ``or`` patterns without bindings, and fixed-length + sequence patterns without starred items. Class and mapping patterns, and + more advanced structural patterns, still fall back to callbacks. 
+ """ + lowered: List[ast.stmt] = [] + + for case in reversed(node.cases): + condition, bindings = _lower_pattern(case.pattern, subject_expr) + binding_map = {name: astutils.copy_tree(expr) for name, expr in bindings} + + if case.guard is not None: + guard = _substitute_capture_loads(case.guard, binding_map) + condition = guard if condition is None else ast.BoolOp(op=ast.And(), values=[condition, guard]) + + body: List[ast.stmt] = [ + ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=astutils.copy_tree(expr)) + for name, expr in bindings + ] + body.extend(astutils.copy_tree(case.body)) + + if condition is None: + lowered = body + else: + lowered = [ast.If(test=condition, body=body, orelse=lowered)] + + return [ast.fix_missing_locations(stmt) for stmt in lowered] + + +def _lower_pattern(pattern: ast.pattern, subject_expr: ast.AST) -> Tuple[Optional[ast.AST], _Bindings]: + if isinstance(pattern, ast.MatchValue): + return _eq_condition(subject_expr, pattern.value), [] + + if isinstance(pattern, ast.MatchSingleton): + return ast.Compare(left=astutils.copy_tree(subject_expr), + ops=[ast.Is()], + comparators=[ast.Constant(pattern.value)]), [] + + if isinstance(pattern, ast.MatchSequence): + if any(isinstance(subpattern, ast.MatchStar) for subpattern in pattern.patterns): + raise UnsupportedMatchPatternError('sequence patterns with starred items are not supported yet') + + conditions: List[ast.AST] = [_fixed_length_sequence_condition(subject_expr, len(pattern.patterns))] + bindings: _Bindings = [] + for index, subpattern in enumerate(pattern.patterns): + element_expr = ast.Subscript(value=astutils.copy_tree(subject_expr), + slice=ast.Constant(index), + ctx=ast.Load()) + element_condition, element_bindings = _lower_pattern(subpattern, element_expr) + if element_condition is not None: + conditions.append(element_condition) + bindings.extend(element_bindings) + return _combine_conditions(conditions), bindings + + if isinstance(pattern, ast.MatchAs): + if 
pattern.pattern is None: + if pattern.name is None: + return None, [] + return None, [(pattern.name, astutils.copy_tree(subject_expr))] + + condition, bindings = _lower_pattern(pattern.pattern, subject_expr) + if pattern.name is not None: + bindings = bindings + [(pattern.name, astutils.copy_tree(subject_expr))] + return condition, bindings + + if isinstance(pattern, ast.MatchOr): + alternatives = [_lower_pattern(alt, subject_expr) for alt in pattern.patterns] + if any(bindings for _, bindings in alternatives): + raise UnsupportedMatchPatternError('or-patterns with bindings are not supported yet') + if any(condition is None for condition, _ in alternatives): + return None, [] + return ast.BoolOp(op=ast.Or(), values=[condition for condition, _ in alternatives]), [] + + if isinstance(pattern, (ast.MatchClass, ast.MatchMapping, ast.MatchStar)): + raise UnsupportedMatchPatternError(f'Unsupported match pattern: {type(pattern).__name__}') + + raise UnsupportedMatchPatternError(f'Unsupported match pattern: {type(pattern).__name__}') + + +def _eq_condition(left: ast.AST, right: ast.AST) -> ast.Compare: + return ast.Compare(left=astutils.copy_tree(left), ops=[ast.Eq()], comparators=[astutils.copy_tree(right)]) + + +def _fixed_length_sequence_condition(subject_expr: ast.AST, length: int) -> ast.AST: + return _combine_conditions([ + ast.BoolOp(op=ast.Or(), + values=[ + ast.Call(func=ast.Name(id='isinstance', ctx=ast.Load()), + args=[astutils.copy_tree(subject_expr), + ast.Name(id='tuple', ctx=ast.Load())], + keywords=[]), + ast.Call(func=ast.Name(id='isinstance', ctx=ast.Load()), + args=[astutils.copy_tree(subject_expr), + ast.Name(id='list', ctx=ast.Load())], + keywords=[]), + ]), + ast.Compare(left=ast.Call(func=ast.Name(id='len', ctx=ast.Load()), + args=[astutils.copy_tree(subject_expr)], + keywords=[]), + ops=[ast.Eq()], + comparators=[ast.Constant(length)]), + ]) + + +def _combine_conditions(conditions: List[ast.AST]) -> ast.AST: + if len(conditions) == 1: + return 
conditions[0] + return ast.BoolOp(op=ast.And(), values=conditions) + + +def _substitute_capture_loads(node: ast.AST, bindings: Dict[str, ast.AST]) -> ast.AST: + + class _CaptureSubstituter(ast.NodeTransformer): + + def visit_Name(self, inner: ast.Name) -> ast.AST: + if isinstance(inner.ctx, ast.Load) and inner.id in bindings: + return astutils.copy_tree(bindings[inner.id]) + return inner + + return ast.fix_missing_locations(_CaptureSubstituter().visit(astutils.copy_tree(node))) diff --git a/dace/frontend/python/schedule_tree/numpy_support.py b/dace/frontend/python/schedule_tree/numpy_support.py new file mode 100644 index 0000000000..eea03eb732 --- /dev/null +++ b/dace/frontend/python/schedule_tree/numpy_support.py @@ -0,0 +1,1552 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +"""NumPy-oriented lowering helpers for the direct schedule-tree frontend.""" + +import ast +import copy +import numbers +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +import numpy as np + +from dace import data, dtypes, subsets, symbolic +from dace.data.pydata import PythonDict, PythonList, PythonTuple +from dace.frontend.python import astutils, memlet_parser +from dace.frontend.python.replacements.array_creation import arange_promoted_symbol_name +from dace.frontend.python.replacements.utils import broadcast_to, broadcast_together +from dace.frontend.python.schedule_tree.static_evaluation import UNRESOLVED, try_resolve_static_value +from dace.frontend.python.schedule_tree.type_inference import _Binding +from dace.memlet import Memlet +from dace.properties import CodeBlock +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.type_inference import infer_expr_type + +OutputTargetResolver = Callable[[ast.AST, ast.AST, Optional[data.Data]], Optional[Tuple[str, Memlet, data.Data]]] +TaskletNameFactory = Callable[[ast.AST], str] +EvaluationContextFactory = Callable[[], Dict[str, 
Any]] +FreshSymbolFactory = Callable[[str], symbolic.symbol] +SymbolRegistrar = Callable[[str], symbolic.symbol] +FreshNameFactory = Callable[[str], str] +NodeAppender = Callable[[tn.ScheduleTreeNode], None] +BindingRegistrar = Callable[[str, data.Data, str], None] + + +@dataclass(frozen=True) +class NumpyLoweringContext: + bindings: Dict[str, _Binding] + evaluation_context: EvaluationContextFactory + resolve_output_target: OutputTargetResolver + tasklet_name: TaskletNameFactory + fresh_symbol: FreshSymbolFactory + register_symbol: SymbolRegistrar + fresh_name: FreshNameFactory + append_node: NodeAppender + register_binding: BindingRegistrar + + +@dataclass(frozen=True) +class _AdvancedIndexBlueprint: + output_shape: Tuple[Any, ...] + output_subset: subsets.Range + source_memlet: Memlet + index_memlets: Tuple[Memlet, ...] + + +@dataclass(frozen=True) +class _ResolvedAccess: + node: ast.AST + name: str + descriptor: data.Data + subset: subsets.Range + logical_ranges: Tuple[Tuple[Any, Any, Any], ...] + array_connector: str + index_connectors: Tuple[str, ...] + output_shape: Tuple[Any, ...] + new_axes: Tuple[int, ...] = tuple() + scalar_dims: Tuple[int, ...] = tuple() + blueprint: Optional[_AdvancedIndexBlueprint] = None + + +@dataclass(frozen=True) +class _BooleanGatherPlan: + source_name: str + source_descriptor: data.Data + result_descriptor: data.Array + input_memlets: Dict[str, Memlet] + nnz_symbol: symbolic.symbol + mask_expr: Optional[ast.AST] = None + + +@dataclass(frozen=True) +class _ExpressionAnalysis: + tasklet_value: ast.AST + typing_value: ast.AST + accesses: Tuple[_ResolvedAccess, ...] + result_shape: Tuple[Any, ...] + result_dtype: dtypes.typeclass + + +@dataclass(frozen=True) +class _IterationPlan: + original_subset: subsets.Range + squeezed_subset: subsets.Range + non_singleton_dims: Tuple[int, ...] + params: Tuple[str, ...] + ranges: Tuple[Tuple[str, str, str], ...] 
+ + +@dataclass(frozen=True) +class _AdvancedTarget: + name: str + output_shape: Tuple[Any, ...] + output_memlet: Memlet + target_expr: ast.AST + input_memlets: Dict[str, Memlet] + guard_expr: Optional[ast.AST] = None + + +class NumpySupportLibrary: + """Ordered NumPy-specific lowering and inference helpers.""" + + def __init__(self) -> None: + self.assignment_passes = (_ArangePass(), _BooleanMaskReadPass(), _ElementwiseAssignmentPass()) + + def lower_assignment(self, context: NumpyLoweringContext, target: ast.AST, value: ast.AST, + annotated_descriptor: Optional[data.Data]) -> Optional[tn.ScheduleTreeNode]: + for lowering_pass in self.assignment_passes: + lowered = lowering_pass.lower_assignment(context, target, value, annotated_descriptor) + if lowered is not None: + return lowered + return None + + def infer_expression_descriptor(self, context: NumpyLoweringContext, value: ast.AST) -> Optional[data.Data]: + for lowering_pass in self.assignment_passes: + descriptor = lowering_pass.infer_expression_descriptor(context, value) + if descriptor is not None: + return descriptor + return None + + +class _ArangePass: + """Lower numpy.arange calls with explicit symbolic properties.""" + + _LIBRARY_NAME = 'numpy.arange' + + def lower_assignment(self, context: NumpyLoweringContext, target: ast.AST, value: ast.AST, + annotated_descriptor: Optional[data.Data]) -> Optional[tn.ScheduleTreeNode]: + if not _is_arange_call(value): + return None + + output = context.resolve_output_target(target, value, annotated_descriptor) + if output is None: + return None + _, output_memlet, output_descriptor = output + if not isinstance(output_descriptor, data.Array) or len(output_descriptor.shape) != 1: + return None + + resolved_args = _resolve_arange_arguments(context, value) + if resolved_args is None: + return None + start, stop, step = resolved_args + + properties: Dict[str, Any] = { + 'start': str(start), + 'stop': str(stop), + 'step': str(step), + 'dtype': str(output_descriptor.dtype), 
+ } + return tn.LibraryCall(node=tn.FrontendLibrary(name=self._LIBRARY_NAME, properties=properties), + in_memlets={}, + out_memlets={'out': output_memlet}) + + def infer_expression_descriptor(self, context: NumpyLoweringContext, value: ast.AST) -> Optional[data.Data]: + return None + + +class _BooleanMaskReadPass: + """Lower RHS boolean-mask gathers as frontend library calls.""" + + _LIBRARY_NAME = 'boolean_mask_gather' + + def lower_assignment(self, context: NumpyLoweringContext, target: ast.AST, value: ast.AST, + annotated_descriptor: Optional[data.Data]) -> Optional[tn.ScheduleTreeNode]: + output = context.resolve_output_target(target, value, annotated_descriptor) + if output is None: + return None + + _, output_memlet, output_descriptor = output + plan = _resolve_boolean_gather(value, context, output_descriptor) + if plan is None: + return None + + properties: Dict[str, Any] = { + 'nnz_symbol': str(plan.nnz_symbol), + 'upper_bound': str(plan.result_descriptor.total_size), + } + if plan.mask_expr is not None: + properties['mask_expr'] = _unparse(plan.mask_expr) + + return tn.LibraryCall(node=tn.FrontendLibrary(name=self._LIBRARY_NAME, properties=properties), + in_memlets=plan.input_memlets, + out_memlets={'out': output_memlet}) + + def infer_expression_descriptor(self, context: NumpyLoweringContext, value: ast.AST) -> Optional[data.Data]: + plan = _resolve_boolean_gather(value, context, None) + if plan is None: + return None + return plan.result_descriptor + + +def _is_arange_call(node: ast.AST) -> bool: + return isinstance(node, ast.Call) and astutils.rname(node.func) in {'numpy.arange', 'dace.arange'} + + +def _resolve_arange_arguments(context: NumpyLoweringContext, node: ast.Call) -> Optional[Tuple[Any, Any, Any]]: + if len(node.args) == 1: + start_node, stop_node, step_node = ast.Constant(value=0), node.args[0], ast.Constant(value=1) + elif len(node.args) == 2: + start_node, stop_node, step_node = node.args[0], node.args[1], ast.Constant(value=1) + elif 
len(node.args) >= 3: + start_node, stop_node, step_node = node.args[0], node.args[1], node.args[2] + else: + return None + + resolved = tuple(_resolve_arange_argument(context, arg) for arg in (start_node, stop_node, step_node)) + if any(value is None for value in resolved): + return None + return resolved + + +def _resolve_arange_argument(context: NumpyLoweringContext, node: ast.AST) -> Optional[Any]: + value = try_resolve_static_value(node, context.evaluation_context()) + if value is not UNRESOLVED and not isinstance(value, data.Data): + return value + if isinstance(node, ast.Name): + binding = context.bindings.get(node.id) + if binding is not None and isinstance(binding.descriptor, data.Scalar): + return _promote_arange_scalar_expression(context, node) + analysis = _ElementwiseExpressionAnalyzer(context).analyze(node) + if analysis is None or analysis.result_shape: + return None + return _promote_arange_scalar_expression(context, node, analysis) + + +def _promote_arange_scalar_expression(context: NumpyLoweringContext, + node: ast.AST, + analysis: Optional[_ExpressionAnalysis] = None) -> symbolic.symbol: + analysis = analysis or _ElementwiseExpressionAnalyzer(context).analyze(node) + symbol_name = arange_promoted_symbol_name(_unparse(node)) + symbol_value = context.fresh_symbol(symbol_name) + + transient_name = context.fresh_name('__stree_arange_arg') + transient_dtype = analysis.result_dtype if analysis is not None else dtypes.int64 + transient_descriptor = data.Scalar(transient_dtype, transient=True) + context.register_binding(transient_name, transient_descriptor, 'scalar') + + input_memlets: Dict[str, Memlet] = {} + tasklet_value = astutils.copy_tree(node) + if analysis is not None: + tasklet_value = analysis.tasklet_value + for access in analysis.accesses: + access_memlet = _build_scalar_input_memlet(access) + if access_memlet is not None: + input_memlets.update(access_memlet) + + tasklet = tn.FrontendTasklet(name=f'{transient_name}_tasklet', 
code=CodeBlock(f'out = {_unparse(tasklet_value)}')) + context.append_node( + tn.TaskletNode(node=tasklet, + in_memlets=input_memlets, + out_memlets={'out': Memlet.from_array(transient_name, transient_descriptor)})) + context.append_node(tn.AssignNode(name=str(symbol_value), value=CodeBlock(transient_name))) + return symbol_value + + +class _ElementwiseAssignmentPass: + """Lower NumPy-style elementwise assignments to explicit map scopes.""" + + def lower_assignment(self, context: NumpyLoweringContext, target: ast.AST, value: ast.AST, + annotated_descriptor: Optional[data.Data]) -> Optional[tn.ScheduleTreeNode]: + if isinstance(target, (ast.Tuple, ast.List)): + return self._lower_multi_output_ufunc_assignment(context, target, value) + + boolean_target = _resolve_boolean_target(context, target, value) + if boolean_target is not None: + return self._lower_boolean_target_assignment(context, boolean_target, target, value) + + advanced_target = _resolve_integer_target(context, target) + + analysis = _ElementwiseExpressionAnalyzer(context).analyze(value) + scalar_only_value = analysis is None and _is_trivial_scalar(value, context) + if analysis is None and not scalar_only_value: + return None + + if advanced_target is not None: + if isinstance(value, ast.BinOp) and _ast_equivalent(value.left, target): + return self._lower_integer_target_augassign(context, advanced_target, target, value) + if analysis is not None and analysis.result_shape and not _is_shape_compatible_shape( + advanced_target.output_shape, analysis.result_shape): + return None + iteration_plan = _build_iteration_plan_from_shape(advanced_target.output_shape) + if iteration_plan is None: + return None + + input_memlets: Dict[str, Memlet] = {} + if analysis is not None: + for access in analysis.accesses: + access_memlets = _build_input_memlets(access, iteration_plan) + if access_memlets is None: + return None + input_memlets.update(access_memlets) + input_memlets.update(advanced_target.input_memlets) + + 
tasklet_value = analysis.tasklet_value if analysis is not None else astutils.copy_tree(value) + tasklet = tn.FrontendTasklet( + name=context.tasklet_name(target), + code=CodeBlock(f'{_unparse(advanced_target.target_expr)} = {_unparse(tasklet_value)}')) + tasklet_node = tn.TaskletNode(node=tasklet, + in_memlets=input_memlets, + out_memlets={'out': advanced_target.output_memlet}) + map_scope = tn.MapScope(node=tn.FrontendMap(params=list(iteration_plan.params), + ranges=list(iteration_plan.ranges)), + children=[]) + _register_iteration_symbols(context, iteration_plan) + tasklet_node.parent = map_scope + map_scope.children.append(tasklet_node) + return map_scope + + output = context.resolve_output_target(target, value, annotated_descriptor) + if output is None: + return None + + target_name, target_memlet, _ = output + if not isinstance(target_memlet.subset, subsets.Range): + return None + if analysis is not None and analysis.result_shape and not _is_shape_compatible( + target_memlet.subset, analysis.result_shape): + return None + + if isinstance(target, ast.Subscript) and target_memlet.subset.num_elements() == 1: + input_memlets: Dict[str, Memlet] = {} + if analysis is not None: + if analysis.result_shape: + return None + for access in analysis.accesses: + access_memlet = _build_scalar_input_memlet(access) + if access_memlet is None: + return None + input_memlets.update(access_memlet) + + tasklet_value = analysis.tasklet_value if analysis is not None else astutils.copy_tree(value) + tasklet = tn.FrontendTasklet(name=context.tasklet_name(target), + code=CodeBlock(f'out = {_unparse(tasklet_value)}')) + return tn.TaskletNode( + node=tasklet, + in_memlets=input_memlets, + out_memlets={'out': Memlet(data=target_name, subset=copy.deepcopy(target_memlet.subset))}) + + iteration_plan = _build_iteration_plan(target_memlet.subset) + if iteration_plan is None: + return None + + input_memlets: Dict[str, Memlet] = {} + if analysis is not None: + for access in analysis.accesses: + 
access_memlets = _build_input_memlets(access, iteration_plan) + if access_memlets is None: + return None + input_memlets.update(access_memlets) + + output_memlet = _build_output_memlet(target_name, iteration_plan) + tasklet_value = analysis.tasklet_value if analysis is not None else astutils.copy_tree(value) + tasklet = tn.FrontendTasklet(name=context.tasklet_name(target), + code=CodeBlock(f'out = {_unparse(tasklet_value)}')) + tasklet_node = tn.TaskletNode(node=tasklet, in_memlets=input_memlets, out_memlets={'out': output_memlet}) + + map_scope = tn.MapScope(node=tn.FrontendMap(params=list(iteration_plan.params), + ranges=list(iteration_plan.ranges)), + children=[]) + _register_iteration_symbols(context, iteration_plan) + tasklet_node.parent = map_scope + map_scope.children.append(tasklet_node) + return map_scope + + def _lower_multi_output_ufunc_assignment(self, context: NumpyLoweringContext, target: ast.AST, + value: ast.AST) -> Optional[tn.ScheduleTreeNode]: + if not isinstance(value, ast.Call): + return None + + ufunc = _resolve_plain_multi_output_ufunc(value, context) + if ufunc is None or len(target.elts) != ufunc.nout: + return None + + analysis = _ElementwiseExpressionAnalyzer(context).analyze(value) + if analysis is None: + return None + + outputs: List[Tuple[str, Memlet, data.Data]] = [] + for element in target.elts: + output = context.resolve_output_target(element, value, None) + if output is None: + return None + if analysis.result_shape and not isinstance(output[1].subset, subsets.Range): + return None + if analysis.result_shape and not _is_shape_compatible(output[1].subset, analysis.result_shape): + return None + outputs.append(output) + + output_connectors = [f'out{index}' for index in range(len(outputs))] + tasklet = tn.FrontendTasklet( + name=context.tasklet_name(target), + code=CodeBlock(f'{", ".join(output_connectors)} = {_unparse(analysis.tasklet_value)}')) + + if not analysis.result_shape: + input_memlets: Dict[str, Memlet] = {} + for access 
in analysis.accesses: + access_memlet = _build_scalar_input_memlet(access) + if access_memlet is None: + return None + input_memlets.update(access_memlet) + out_memlets = { + connector: copy.deepcopy(output_memlet) + for connector, (_, output_memlet, _) in zip(output_connectors, outputs) + } + return tn.TaskletNode(node=tasklet, in_memlets=input_memlets, out_memlets=out_memlets) + + iteration_plan = _build_iteration_plan(outputs[0][1].subset) + if iteration_plan is None: + return None + + input_memlets: Dict[str, Memlet] = {} + for access in analysis.accesses: + access_memlets = _build_input_memlets(access, iteration_plan) + if access_memlets is None: + return None + input_memlets.update(access_memlets) + + out_memlets = { + connector: _build_output_memlet_for_subset(output_name, output_memlet.subset, iteration_plan) + for connector, (output_name, output_memlet, _) in zip(output_connectors, outputs) + } + tasklet_node = tn.TaskletNode(node=tasklet, in_memlets=input_memlets, out_memlets=out_memlets) + map_scope = tn.MapScope(node=tn.FrontendMap(params=list(iteration_plan.params), + ranges=list(iteration_plan.ranges)), + children=[]) + _register_iteration_symbols(context, iteration_plan) + tasklet_node.parent = map_scope + map_scope.children.append(tasklet_node) + return map_scope + + def _lower_integer_target_augassign(self, context: NumpyLoweringContext, advanced_target: _AdvancedTarget, + target: ast.AST, value: ast.BinOp) -> Optional[tn.ScheduleTreeNode]: + rhs_analysis = _ElementwiseExpressionAnalyzer(context).analyze(value.right) + if rhs_analysis is None and not _is_trivial_scalar(value.right, context): + return None + + iteration_plan = _build_iteration_plan_from_shape(advanced_target.output_shape) + if iteration_plan is None: + return None + + input_memlets: Dict[str, Memlet] = dict(advanced_target.input_memlets) + if rhs_analysis is not None: + if rhs_analysis.result_shape and not _is_shape_compatible_shape(advanced_target.output_shape, + 
rhs_analysis.result_shape): + return None + for access in rhs_analysis.accesses: + access_memlets = _build_input_memlets(access, iteration_plan) + if access_memlets is None: + return None + input_memlets.update(access_memlets) + rhs_tasklet = rhs_analysis.tasklet_value + else: + rhs_tasklet = astutils.copy_tree(value.right) + + input_memlets['cur'] = Memlet(data=advanced_target.name, + subset=copy.deepcopy(advanced_target.output_memlet.subset), + volume=advanced_target.output_memlet.volume) + tasklet_value = ast.copy_location( + ast.BinOp(left=ast.Name(id='cur', ctx=ast.Load()), op=astutils.copy_tree(value.op), right=rhs_tasklet), + value) + tasklet = tn.FrontendTasklet( + name=context.tasklet_name(target), + code=CodeBlock(f'{_unparse(advanced_target.target_expr)} = {_unparse(tasklet_value)}')) + tasklet_node = tn.TaskletNode(node=tasklet, + in_memlets=input_memlets, + out_memlets={'out': advanced_target.output_memlet}) + map_scope = tn.MapScope(node=tn.FrontendMap(params=list(iteration_plan.params), + ranges=list(iteration_plan.ranges)), + children=[]) + _register_iteration_symbols(context, iteration_plan) + tasklet_node.parent = map_scope + map_scope.children.append(tasklet_node) + return map_scope + + def _lower_boolean_target_assignment(self, context: NumpyLoweringContext, boolean_target: _AdvancedTarget, + target: ast.AST, value: ast.AST) -> Optional[tn.ScheduleTreeNode]: + is_augassign = isinstance(value, ast.BinOp) and _ast_equivalent(value.left, target) + rhs_node = value.right if is_augassign else value + rhs_analysis = _ElementwiseExpressionAnalyzer(context).analyze(rhs_node) + if rhs_analysis is None and not _is_trivial_scalar(rhs_node, context): + return None + + iteration_plan = _build_iteration_plan_from_shape(boolean_target.output_shape) + if iteration_plan is None: + return None + + input_memlets: Dict[str, Memlet] = dict(boolean_target.input_memlets) + rhs_tasklet = ast.copy_location(ast.Constant(value=None), rhs_node) + if rhs_analysis is not 
None: + if rhs_analysis.result_shape and not _is_shape_compatible_shape(boolean_target.output_shape, + rhs_analysis.result_shape): + return None + rhs_tasklet = rhs_analysis.tasklet_value + for access in rhs_analysis.accesses: + access_memlets = _build_input_memlets(access, iteration_plan) + if access_memlets is None: + return None + input_memlets.update(access_memlets) + else: + rhs_tasklet = astutils.copy_tree(rhs_node) + + if is_augassign: + current_memlet = Memlet(data=boolean_target.name, subset=copy.deepcopy(boolean_target.output_memlet.subset)) + input_memlets['cur'] = current_memlet + rhs_tasklet = ast.copy_location( + ast.BinOp(left=ast.Name(id='cur', ctx=ast.Load()), op=astutils.copy_tree(value.op), right=rhs_tasklet), + value) + + if boolean_target.guard_expr is None: + return None + code = f'if {_unparse(boolean_target.guard_expr)}:\n out = {_unparse(rhs_tasklet)}' + output_memlet = copy.deepcopy(boolean_target.output_memlet) + output_memlet.dynamic = True + tasklet = tn.FrontendTasklet(name=context.tasklet_name(target), code=CodeBlock(code)) + tasklet_node = tn.TaskletNode(node=tasklet, in_memlets=input_memlets, out_memlets={'out': output_memlet}) + map_scope = tn.MapScope(node=tn.FrontendMap(params=list(iteration_plan.params), + ranges=list(iteration_plan.ranges)), + children=[]) + _register_iteration_symbols(context, iteration_plan) + tasklet_node.parent = map_scope + map_scope.children.append(tasklet_node) + return map_scope + + def infer_expression_descriptor(self, context: NumpyLoweringContext, value: ast.AST) -> Optional[data.Data]: + analysis = _ElementwiseExpressionAnalyzer(context).analyze(value) + if analysis is None: + return None + if not analysis.result_shape: + return data.Scalar(analysis.result_dtype, transient=True) + return data.Array(analysis.result_dtype, list(analysis.result_shape), transient=True) + + +class _ElementwiseExpressionAnalyzer: + """Recognizes scalarized NumPy expressions over array accesses.""" + + def __init__(self, 
context: NumpyLoweringContext, start_index: int = 0) -> None: + self.context = context + self.start_index = start_index + self.accesses: List[_ResolvedAccess] = [] + self.access_map: Dict[Tuple[str, str, Tuple[str, ...]], _ResolvedAccess] = {} + + def analyze(self, node: ast.AST) -> Optional[_ExpressionAnalysis]: + rewritten = self._rewrite(astutils.copy_tree(node)) + if rewritten is None or not self.accesses: + return None + + tasklet_value, typing_value = rewritten + result_shape = _broadcast_shape(tuple(access.output_shape for access in self.accesses)) + if result_shape is None: + return None + + scalar_types = _scalar_type_environment(self.context, self.accesses) + try: + result_dtype = infer_expr_type(_unparse(typing_value), scalar_types) + except Exception: + result_dtype = None + + if result_dtype is None: + result_dtype = self.accesses[0].descriptor.dtype + + return _ExpressionAnalysis(tasklet_value=ast.fix_missing_locations(tasklet_value), + typing_value=ast.fix_missing_locations(typing_value), + accesses=tuple(self.accesses), + result_shape=result_shape, + result_dtype=result_dtype) + + def _rewrite(self, node: ast.AST) -> Optional[Tuple[ast.AST, ast.AST]]: + access = self._resolve_array_access(node) + if access is not None: + return (_tasklet_expr_for_access(access), + ast.copy_location(ast.Name(id=access.array_connector, ctx=ast.Load()), node)) + + if isinstance(node, ast.Constant): + copied = astutils.copy_tree(node) + return (copied, astutils.copy_tree(copied)) + + if isinstance(node, ast.Name): + if not _is_scalar_leaf(node, self.context): + return None + copied = astutils.copy_tree(node) + return (copied, astutils.copy_tree(copied)) + + if isinstance(node, ast.Attribute): + if not _is_scalar_leaf(node, self.context): + return None + copied = astutils.copy_tree(node) + return (copied, astutils.copy_tree(copied)) + + if isinstance(node, ast.Subscript): + if not _is_scalar_leaf(node, self.context): + return None + copied = astutils.copy_tree(node) + 
return (copied, astutils.copy_tree(copied)) + + if isinstance(node, ast.BinOp): + left = self._rewrite(node.left) + right = self._rewrite(node.right) + if left is None or right is None: + return None + return (ast.copy_location(ast.BinOp(left=left[0], op=astutils.copy_tree(node.op), right=right[0]), node), + ast.copy_location(ast.BinOp(left=left[1], op=astutils.copy_tree(node.op), right=right[1]), node)) + + if isinstance(node, ast.UnaryOp): + operand = self._rewrite(node.operand) + if operand is None: + return None + return (ast.copy_location(ast.UnaryOp(op=astutils.copy_tree(node.op), operand=operand[0]), node), + ast.copy_location(ast.UnaryOp(op=astutils.copy_tree(node.op), operand=operand[1]), node)) + + if isinstance(node, ast.BoolOp): + values = [self._rewrite(value) for value in node.values] + if any(value is None for value in values): + return None + return (ast.copy_location(ast.BoolOp(op=astutils.copy_tree(node.op), values=[value[0] for value in values]), + node), + ast.copy_location(ast.BoolOp(op=astutils.copy_tree(node.op), values=[value[1] for value in values]), + node)) + + if isinstance(node, ast.Compare): + left = self._rewrite(node.left) + comparators = [self._rewrite(comp) for comp in node.comparators] + if left is None or any(comp is None for comp in comparators): + return None + return (ast.copy_location( + ast.Compare(left=left[0], + ops=astutils.copy_tree(node.ops), + comparators=[comp[0] for comp in comparators]), node), + ast.copy_location( + ast.Compare(left=left[1], + ops=astutils.copy_tree(node.ops), + comparators=[comp[1] for comp in comparators]), node)) + + if isinstance(node, ast.IfExp): + test = self._rewrite(node.test) + body = self._rewrite(node.body) + orelse = self._rewrite(node.orelse) + if test is None or body is None or orelse is None: + return None + return (ast.copy_location(ast.IfExp(test=test[0], body=body[0], orelse=orelse[0]), node), + ast.copy_location(ast.IfExp(test=test[1], body=body[1], orelse=orelse[1]), node)) + + 
if isinstance(node, ast.Call): + if not _is_supported_call(node, self.context): + return None + args = [self._rewrite(arg) for arg in node.args] + if any(arg is None for arg in args): + return None + keywords: List[Tuple[ast.keyword, ast.keyword]] = [] + for keyword in node.keywords: + rewritten_value = self._rewrite(keyword.value) + if rewritten_value is None: + return None + keywords.append( + (ast.keyword(arg=keyword.arg, + value=rewritten_value[0]), ast.keyword(arg=keyword.arg, value=rewritten_value[1]))) + return (ast.copy_location( + ast.Call(func=astutils.copy_tree(node.func), + args=[arg[0] for arg in args], + keywords=[kw[0] for kw in keywords]), node), + ast.copy_location( + ast.Call(func=astutils.copy_tree(node.func), + args=[arg[1] for arg in args], + keywords=[kw[1] for kw in keywords]), node)) + + return None + + def _resolve_array_access(self, node: ast.AST) -> Optional[_ResolvedAccess]: + if isinstance(node, ast.Name): + binding = _ensure_binding_for_name(node.id, self.context) + if binding is None or binding.descriptor is None or not _is_numpy_arraylike(binding.descriptor): + return None + subset = subsets.Range.from_array(binding.descriptor) + return self._register_basic_access(node, node.id, binding.descriptor, subset) + + if isinstance(node, ast.Subscript) and isinstance(node.value, ast.Name): + binding = _ensure_binding_for_name(node.value.id, self.context) + if binding is None or binding.descriptor is None or not _is_numpy_arraylike(binding.descriptor): + return None + try: + subset, new_axes, arrdims = memlet_parser.parse_memlet_subset(binding.descriptor, node, + self.context.evaluation_context()) + except Exception: + return None + + if arrdims: + return self._register_advanced_access(node, node.value.id, binding.descriptor, subset, new_axes, + arrdims) + return self._register_basic_access(node, node.value.id, binding.descriptor, subset, new_axes) + + return None + + def _register_basic_access(self, + node: ast.AST, + name: str, + 
descriptor: data.Data, + subset: subsets.Range, + new_axes: Sequence[int] = ()) -> _ResolvedAccess: + axis_key = tuple(str(axis) for axis in new_axes) + key = (name, str(subset), axis_key) + existing = self.access_map.get(key) + if existing is not None: + return existing + + scalar_dims = _scalar_indexed_dims(node, len(descriptor.shape), self.context) + logical_ranges = _logical_ranges_from_basic_access(node, subset, descriptor, self.context) + output_shape = tuple(_shape_from_ranges(logical_ranges)) + if output_shape: + logical_shape = list(output_shape) + for axis in sorted(new_axes): + logical_shape.insert(axis, 1) + output_shape = tuple(logical_shape) + access = _ResolvedAccess(node=node, + name=name, + descriptor=_clone_descriptor(descriptor), + subset=copy.deepcopy(subset), + logical_ranges=logical_ranges, + array_connector=f'in{self.start_index + len(self.accesses)}', + index_connectors=tuple(), + output_shape=output_shape, + new_axes=tuple(new_axes), + scalar_dims=scalar_dims, + blueprint=None) + self.accesses.append(access) + self.access_map[key] = access + return access + + def _register_advanced_access(self, node: ast.AST, name: str, descriptor: data.Data, subset: subsets.Range, + new_axes: Sequence[int], arrdims: Dict[int, Any]) -> Optional[_ResolvedAccess]: + if any(_is_boolean_index(index_name, self.context) for index_name in arrdims.values()): + return None + + blueprint = _build_advanced_blueprint(name, subset, new_axes, arrdims, self.context) + if blueprint is None: + return None + + key = (name, str(subset), tuple(str(index) for index in arrdims.values())) + existing = self.access_map.get(key) + if existing is not None: + return existing + + access_index = self.start_index + len(self.accesses) + index_connectors = tuple(f'idx{access_index}_{i}' for i in range(len(blueprint.index_memlets))) + access = _ResolvedAccess(node=node, + name=name, + descriptor=_clone_descriptor(descriptor), + subset=copy.deepcopy(subset), + 
logical_ranges=tuple(copy.deepcopy(blueprint.source_memlet.subset).ranges), + array_connector=f'in{access_index}', + index_connectors=index_connectors, + output_shape=blueprint.output_shape, + blueprint=blueprint) + self.accesses.append(access) + self.access_map[key] = access + return access + + +def _build_iteration_plan(target_subset: subsets.Range) -> Optional[_IterationPlan]: + if target_subset.num_elements() == 1: + return None + + squeezed_subset = copy.deepcopy(target_subset) + non_singleton_dims = tuple(range(len(target_subset.ranges))) + params = tuple(f'__i{i}' for i in range(len(target_subset.ranges))) + ranges = tuple(_frontend_range_tuple(dim) for dim in target_subset.ranges) + return _IterationPlan(original_subset=copy.deepcopy(target_subset), + squeezed_subset=squeezed_subset, + non_singleton_dims=non_singleton_dims, + params=params, + ranges=ranges) + + +def _register_iteration_symbols(context: NumpyLoweringContext, iteration_plan: _IterationPlan) -> None: + for param in iteration_plan.params: + context.register_symbol(param) + + +def _build_iteration_plan_from_shape(shape: Sequence[Any]) -> Optional[_IterationPlan]: + if not shape: + return None + return _build_iteration_plan(subsets.Range([(0, dim - 1, 1) for dim in shape])) + + +def _build_output_memlet(target_name: str, iteration_plan: _IterationPlan) -> Memlet: + param_iter = iter(iteration_plan.params) + indices: List[Any] = [] + for dim, (start, _, _) in enumerate(iteration_plan.original_subset.ranges): + if dim in iteration_plan.non_singleton_dims: + indices.append(symbolic.symbol(next(param_iter), dtypes.int64)) + else: + indices.append(start) + return Memlet(data=target_name, subset=subsets.Range.from_indices(indices)) + + +def _build_output_memlet_for_subset(target_name: str, target_subset: subsets.Range, + iteration_plan: _IterationPlan) -> Memlet: + param_iter = iter(iteration_plan.params) + indices: List[Any] = [] + for dim, (start, _, _) in enumerate(target_subset.ranges): + if dim in 
iteration_plan.non_singleton_dims: + indices.append(symbolic.symbol(next(param_iter), dtypes.int64)) + else: + indices.append(start) + return Memlet(data=target_name, subset=subsets.Range.from_indices(indices)) + + +def _build_input_memlets(access: _ResolvedAccess, iteration_plan: _IterationPlan) -> Optional[Dict[str, Memlet]]: + if access.blueprint is not None: + return _build_advanced_input_memlets(access, iteration_plan) + + if access.new_axes: + return _build_newaxis_input_memlets(access, iteration_plan) + + if isinstance(access.node, ast.Subscript) and access.subset.num_elements() == 1: + return {access.array_connector: Memlet(data=access.name, subset=copy.deepcopy(access.subset))} + + try: + _, all_idx_tuples, _, _, inp_idx = broadcast_to(iteration_plan.squeezed_subset.size(), access.output_shape) + except Exception: + return None + + input_indices = [part.strip() for part in inp_idx.split(',')] if inp_idx else [] + missing_dimensions = list(iteration_plan.squeezed_subset.ranges[:len(all_idx_tuples) - len(input_indices)]) + fake_subset = subsets.Range(missing_dimensions + list(access.logical_ranges)) + + offset_indices_to_ignore = set() + for index, idx in enumerate(input_indices): + if not symbolic.issymbolic(symbolic.pystr_to_symbolic(idx)): + offset_indices_to_ignore.add(index) + offset_indices = [index for index in range(len(fake_subset)) if index not in offset_indices_to_ignore] + fake_subset.offset(iteration_plan.squeezed_subset, True, indices=offset_indices) + + idx_and_subset = reversed(list(zip(reversed(input_indices), reversed(fake_subset.ranges)))) + subset_indices = [_compose_input_index(idx, subset) for idx, subset in idx_and_subset] + subset_indices = _reinsert_scalar_dims(access, subset_indices) + return {access.array_connector: Memlet(data=access.name, subset=subsets.Range.from_indices(subset_indices))} + + +def _build_scalar_input_memlet(access: _ResolvedAccess) -> Optional[Dict[str, Memlet]]: + if access.blueprint is not None or 
access.new_axes or access.output_shape: + return None + if not isinstance(access.subset, subsets.Range) or access.subset.num_elements() != 1: + return None + return {access.array_connector: Memlet(data=access.name, subset=copy.deepcopy(access.subset))} + + +def _build_newaxis_input_memlets(access: _ResolvedAccess, + iteration_plan: _IterationPlan) -> Optional[Dict[str, Memlet]]: + try: + _, _, _, _, inp_idx = broadcast_to(iteration_plan.squeezed_subset.size(), access.output_shape) + except Exception: + return None + + input_indices = [part.strip() for part in inp_idx.split(',')] if inp_idx else [] + logical_ranges = list(copy.deepcopy(access.logical_ranges)) + for axis in sorted(access.new_axes): + logical_ranges.insert(axis, (0, 0, 1)) + + if len(input_indices) != len(logical_ranges): + return None + + subset_indices = [ + _compose_input_index(idx, subset) for dim, (idx, subset) in enumerate(zip(input_indices, logical_ranges)) + if dim not in access.new_axes + ] + subset_indices = _reinsert_scalar_dims(access, subset_indices) + return {access.array_connector: Memlet(data=access.name, subset=subsets.Range.from_indices(subset_indices))} + + +def _reinsert_scalar_dims(access: _ResolvedAccess, subset_indices: Sequence[Any]) -> List[Any]: + if not access.scalar_dims: + return list(subset_indices) + + scalar_dims = set(access.scalar_dims) + subset_iter = iter(subset_indices) + result: List[Any] = [] + for dim, rng in enumerate(access.subset.ranges): + if dim in scalar_dims: + result.append(rng[0]) + else: + result.append(next(subset_iter)) + return result + + +def _build_advanced_input_memlets(access: _ResolvedAccess, + iteration_plan: _IterationPlan) -> Optional[Dict[str, Memlet]]: + if access.blueprint is None: + return None + + mapping = _build_access_symbol_mapping(iteration_plan, access.blueprint.output_subset, access.output_shape) + if mapping is None: + return None + + result = { + access.array_connector: + Memlet(data=access.name, + 
subset=_substitute_subset(access.blueprint.source_memlet.subset, mapping), + volume=access.blueprint.source_memlet.volume) + } + for connector, memlet in zip(access.index_connectors, access.blueprint.index_memlets): + result[connector] = Memlet(data=memlet.data, + subset=_substitute_subset(memlet.subset, mapping), + volume=memlet.volume) + return result + + +def _build_access_symbol_mapping(iteration_plan: _IterationPlan, output_subset: subsets.Range, + operand_shape: Tuple[Any, ...]) -> Optional[Dict[Any, Any]]: + varying_dims = [ + index for index, (start, end, step) in enumerate(output_subset.ranges) + if step == 1 and start == end and symbolic.issymbolic(start) + ] + if not varying_dims: + return {} + try: + _, _, _, _, operand_idx = broadcast_to(iteration_plan.squeezed_subset.size(), operand_shape) + except Exception: + return None + operand_indices = [part.strip() for part in operand_idx.split(',')] if operand_idx else [] + if len(operand_indices) != len(output_subset.ranges): + return None + varying_operand_indices = [operand_indices[index] for index in varying_dims] + symbols = [str(output_subset.ranges[index][0]) for index in varying_dims] + if len(symbols) != len(varying_operand_indices): + return None + return { + symbolic.symbol(symbol): symbolic.pystr_to_symbolic(index) + for symbol, index in zip(symbols, varying_operand_indices) + } + + +def _build_advanced_blueprint(name: str, subset: subsets.Range, new_axes: Sequence[int], arrdims: Dict[int, Any], + context: NumpyLoweringContext) -> Optional[_AdvancedIndexBlueprint]: + output_shape = _compute_output_shape_from_advanced_indexing(subset, new_axes, arrdims, context) + if output_shape is None: + return None + + ndrange = subset.ndrange() + output_ndrange = [(symbolic.symbol(f'__i{i}', dtypes.int64), symbolic.symbol(f'__i{i}', dtypes.int64), + 1) if rng[0] != rng[1] else (0, 0, 1) for i, rng in enumerate(ndrange)] + input_subset = subsets.Range([(rb + ind * rs, rb + ind * rs, 1) + for (rb, _, rs), (ind, 
_, _) in zip(ndrange, output_ndrange)]) + index_memlets: List[Memlet] = [] + + output_shape_marks = [size if index not in arrdims else None for index, size in enumerate(subset.size())] + output_shape_marks = [None if rng[0] == rng[1] else size for size, rng in zip(output_shape_marks, subset.ndrange())] + output_ndrange_marks: List[Optional[Tuple[Any, Any, Any]]] = [ + None if output_shape_marks[i] is None else rng for i, rng in enumerate(output_ndrange) + ] + + advanced_dims = [ + index for index, size in enumerate(output_shape_marks) + if size is None and (index == 0 or output_shape_marks[index - 1] is not None) + ] + prefix_dims = len(advanced_dims) > 1 + if prefix_dims: + output_shape_marks = [None] + [size for size in output_shape_marks if size is not None] + output_ndrange_marks = [None] + [rng for rng in output_ndrange_marks if rng is not None] + dim_position = 0 + else: + dim_position = advanced_dims[0] + + for new_axis in reversed(new_axes): + if prefix_dims: + output_shape_marks.insert(new_axis + 1, 1) + output_ndrange_marks.insert(new_axis + 1, (0, 0, 1)) + else: + output_shape_marks.insert(new_axis, 1) + output_ndrange_marks.insert(new_axis, (0, 0, 1)) + if new_axis <= dim_position: + dim_position += 1 + + output_shape_marks = [ + size for index, size in enumerate(output_shape_marks) + if size is not None or index == 0 or output_shape_marks[index - 1] is not None + ] + output_ndrange_marks = [ + rng for index, rng in enumerate(output_ndrange_marks) + if rng is not None or index == 0 or output_ndrange_marks[index - 1] is not None + ] + + advidx_shape: Optional[Tuple[Any, ...]] = None + out_idx: Optional[str] = None + advidx_arrays: Dict[int, Tuple[str, Sequence[Any]]] = {} + for index, idxarrname in arrdims.items(): + if not isinstance(idxarrname, str): + return None + descriptor = _index_descriptor(idxarrname, context) + if descriptor is None: + return None + advidx_arrays[index] = (idxarrname, descriptor.shape) + if advidx_shape is not None: + 
advidx_shape, _, out_idx, *_ = broadcast_together(descriptor.shape, advidx_shape) + else: + advidx_shape = tuple(descriptor.shape) + out_idx = ', '.join(f'__i{i}' for i in range(len(descriptor.shape))) + if advidx_shape is None or out_idx is None: + return None + + out_idx = out_idx.replace('__i', '__ind') + advidx_index: List[Tuple[Any, Any, Any]] = [] + for index_name, size in zip((part.strip() for part in out_idx.split(',')), advidx_shape): + sym = symbolic.symbol(index_name, dtypes.int64) + advidx_index.append((sym, sym, 1)) + + for dim, (idxarrname, shape) in advidx_arrays.items(): + _, _, _, arr_idx, _ = broadcast_together(shape, advidx_shape) + arr_idx = arr_idx.replace('__i', '__ind').split(',') + arr_subset = subsets.Range([(symbolic.symbol(index.strip(), + dtypes.int64), symbolic.symbol(index.strip(), dtypes.int64), 1) + for index in arr_idx]) + index_memlets.append(Memlet(data=idxarrname, subset=arr_subset, volume=1)) + input_subset[dim] = ndrange[dim] + + output_ndrange_final = output_ndrange_marks[:dim_position] + advidx_index + output_ndrange_marks[dim_position + 1:] + return _AdvancedIndexBlueprint(output_shape=tuple(output_shape), + output_subset=subsets.Range(output_ndrange_final), + source_memlet=Memlet(data=name, subset=input_subset, volume=1), + index_memlets=tuple(index_memlets)) + + +def _compute_output_shape_from_advanced_indexing(subset: subsets.Range, new_axes: Sequence[int], arrdims: Dict[int, + Any], + context: NumpyLoweringContext) -> Optional[List[Any]]: + output_shape = [size if index not in arrdims else None for index, size in enumerate(subset.size())] + if arrdims: + output_shape = [None if rng[0] == rng[1] else size for size, rng in zip(output_shape, subset.ndrange())] + + advanced_dims = [ + index for index, size in enumerate(output_shape) + if size is None and (index == 0 or output_shape[index - 1] is not None) + ] + prefix_dims = len(advanced_dims) > 1 + if prefix_dims: + output_shape = [None] + [size for size in output_shape if 
size is not None] + dim_position = 0 + else: + dim_position = advanced_dims[0] + + for new_axis in new_axes: + if prefix_dims: + output_shape.insert(new_axis + 1, 1) + else: + output_shape.insert(new_axis, 1) + if new_axis <= dim_position: + dim_position += 1 + + output_shape = [ + size for index, size in enumerate(output_shape) + if size is not None or index == 0 or output_shape[index - 1] is not None + ] + + chunk_shape: Optional[Tuple[Any, ...]] = None + for arrname in arrdims.values(): + if not isinstance(arrname, str): + return None + descriptor = _index_descriptor(arrname, context) + if descriptor is None: + return None + if chunk_shape is None: + chunk_shape = tuple(descriptor.shape) + else: + try: + chunk_shape, *_ = broadcast_together(descriptor.shape, chunk_shape) + except Exception: + return None + + if chunk_shape is None: + return None + return output_shape[:dim_position] + list(chunk_shape) + output_shape[dim_position + 1:] + + +def _varying_subset_symbols(subset: subsets.Range) -> List[str]: + result: List[str] = [] + for start, end, step in subset.ranges: + if step == 1 and start == end and symbolic.issymbolic(start): + result.append(str(start)) + return result + + +def _substitute_subset(subset: subsets.Range, mapping: Dict[Any, Any]) -> subsets.Range: + replaced = [] + for start, end, step in subset.ranges: + replaced.append((_substitute_expr(start, mapping), _substitute_expr(end, + mapping), _substitute_expr(step, mapping))) + return subsets.Range(replaced) + + +def _substitute_expr(expr: Any, mapping: Dict[Any, Any]) -> Any: + if isinstance(expr, symbolic.SymExpr): + return symbolic.SymExpr(expr.expr.subs(mapping, simultaneous=True), expr.approx.subs(mapping, + simultaneous=True)) + if hasattr(expr, 'subs'): + return expr.subs(mapping, simultaneous=True) + return expr + + +def _compose_input_index(index_expr: str, subset: Tuple[Any, Any, Any]) -> Any: + start, _, _ = subset + if index_expr == '0': + return start + symbolic_index = 
symbolic.pystr_to_symbolic(index_expr) + if symbolic.pystr_to_symbolic(str(start)) == 0: + return symbolic_index + return symbolic_index + start + + +def _broadcast_shape(shapes: Sequence[Tuple[Any, ...]]) -> Optional[Tuple[Any, ...]]: + concrete_shapes = [shape for shape in shapes if shape] + if not concrete_shapes: + return tuple() + result = concrete_shapes[0] + for shape in concrete_shapes[1:]: + try: + result, _, _, _, _ = broadcast_together(result, shape) + except Exception: + return None + return tuple(result) + + +def _is_shape_compatible(target_subset: subsets.Range, source_shape: Tuple[Any, ...]) -> bool: + try: + broadcast_to(target_subset.size(), source_shape) + except Exception: + return False + return True + + +def _is_shape_compatible_shape(target_shape: Sequence[Any], source_shape: Tuple[Any, ...]) -> bool: + try: + broadcast_to(tuple(target_shape), source_shape) + except Exception: + return False + return True + + +def _scalar_type_environment(context: NumpyLoweringContext, + accesses: Sequence[_ResolvedAccess]) -> Dict[str, dtypes.typeclass]: + result = {access.array_connector: access.descriptor.dtype for access in accesses} + for name, binding in context.bindings.items(): + if binding.descriptor is not None and isinstance(binding.descriptor, data.Scalar): + result[name] = binding.descriptor.dtype + for name, value in context.evaluation_context().items(): + if isinstance(value, symbolic.symbol): + result[name] = value.dtype + return result + + +def _tasklet_expr_for_access(access: _ResolvedAccess) -> ast.AST: + base = ast.Name(id=access.array_connector, ctx=ast.Load()) + if not access.index_connectors: + return base + if len(access.index_connectors) == 1: + return ast.Subscript(value=base, slice=ast.Name(id=access.index_connectors[0], ctx=ast.Load()), ctx=ast.Load()) + return ast.Subscript(value=base, + slice=ast.Tuple( + elts=[ast.Name(id=connector, ctx=ast.Load()) for connector in access.index_connectors], + ctx=ast.Load()), + ctx=ast.Load()) + 
+ +def _resolve_boolean_gather(value: ast.AST, context: NumpyLoweringContext, + output_descriptor: Optional[data.Data]) -> Optional[_BooleanGatherPlan]: + if not isinstance(value, ast.Subscript) or not isinstance(value.value, ast.Name): + return None + + source_name = value.value.id + binding = context.bindings.get(source_name) + if binding is None or binding.descriptor is None or not _is_numpy_arraylike(binding.descriptor): + return None + source_descriptor = _clone_descriptor(binding.descriptor) + + try: + subset, new_axes, arrdims = memlet_parser.parse_memlet_subset(source_descriptor, value, + context.evaluation_context()) + except Exception: + subset, new_axes, arrdims = None, [], {} + + if subset is not None and arrdims: + bool_indices = [index_name for index_name in arrdims.values() if _is_boolean_index(index_name, context)] + if len(bool_indices) != 1 or len(arrdims) != 1 or new_axes: + return None + mask_name = bool_indices[0] + mask_descriptor = _index_descriptor(mask_name, context) + if mask_descriptor is None or tuple(mask_descriptor.shape) != tuple(source_descriptor.shape): + return None + result_descriptor = _boolean_gather_descriptor(source_descriptor, output_descriptor, context) + return _BooleanGatherPlan(source_name=source_name, + source_descriptor=source_descriptor, + result_descriptor=result_descriptor, + input_memlets={ + 'data': Memlet.from_array(source_name, source_descriptor), + 'mask': Memlet.from_array(mask_name, mask_descriptor), + }, + nnz_symbol=result_descriptor.shape[0]) + + mask_analysis = _ElementwiseExpressionAnalyzer(context, start_index=100).analyze(value.slice) + if mask_analysis is None or mask_analysis.result_dtype != dtypes.bool: + return None + if tuple(mask_analysis.result_shape) != tuple(source_descriptor.shape): + return None + if any(access.blueprint is not None for access in mask_analysis.accesses): + return None + + input_memlets = {'data': Memlet.from_array(source_name, source_descriptor)} + for access in 
mask_analysis.accesses: + input_memlets[access.array_connector] = Memlet(data=access.name, subset=copy.deepcopy(access.subset)) + + result_descriptor = _boolean_gather_descriptor(source_descriptor, output_descriptor, context) + return _BooleanGatherPlan(source_name=source_name, + source_descriptor=source_descriptor, + result_descriptor=result_descriptor, + input_memlets=input_memlets, + nnz_symbol=result_descriptor.shape[0], + mask_expr=mask_analysis.tasklet_value) + + +def _boolean_gather_descriptor(source_descriptor: data.Data, output_descriptor: Optional[data.Data], + context: NumpyLoweringContext) -> data.Array: + nnz_symbol = _boolean_gather_symbol(output_descriptor, context) + bound = _shape_product(source_descriptor.shape) + return data.Array(source_descriptor.dtype, [nnz_symbol], transient=True, total_size=bound) + + +def _boolean_gather_symbol(output_descriptor: Optional[data.Data], context: NumpyLoweringContext) -> symbolic.symbol: + if isinstance(output_descriptor, data.Array) and len(output_descriptor.shape) == 1 and symbolic.issymbolic( + output_descriptor.shape[0]): + return symbolic.pystr_to_symbolic(str(output_descriptor.shape[0])) + return context.fresh_symbol('__stree_mask_nnz') + + +def _shape_product(shape: Sequence[Any]) -> Any: + result: Any = 1 + for dim in shape: + result = result * dim + return result + + +def _resolve_integer_target(context: NumpyLoweringContext, target: ast.AST) -> Optional[_AdvancedTarget]: + if not isinstance(target, ast.Subscript) or not isinstance(target.value, ast.Name): + return None + binding = context.bindings.get(target.value.id) + if binding is None or binding.descriptor is None or not _is_numpy_arraylike(binding.descriptor): + return None + try: + subset, new_axes, arrdims = memlet_parser.parse_memlet_subset(binding.descriptor, target, + context.evaluation_context()) + except Exception: + return None + if not arrdims or any(_is_boolean_index(index_name, context) for index_name in arrdims.values()): + return 
None + blueprint = _build_advanced_blueprint(target.value.id, subset, new_axes, arrdims, context) + if blueprint is None: + return None + iteration_plan = _build_iteration_plan_from_shape(blueprint.output_shape) + if iteration_plan is None: + return None + mapping = _build_access_symbol_mapping(iteration_plan, blueprint.output_subset, blueprint.output_shape) + if mapping is None: + return None + access_index = 1000 + index_connectors = tuple(f'outidx_{i}' for i in range(len(blueprint.index_memlets))) + input_memlets = { + connector: Memlet(data=memlet.data, subset=_substitute_subset(memlet.subset, mapping), volume=memlet.volume) + for connector, memlet in zip(index_connectors, blueprint.index_memlets) + } + output_memlet = Memlet(data=target.value.id, + subset=_substitute_subset(blueprint.source_memlet.subset, mapping), + volume=blueprint.source_memlet.volume) + return _AdvancedTarget(name=target.value.id, + output_shape=blueprint.output_shape, + output_memlet=output_memlet, + target_expr=_subscript_expr('out', index_connectors), + input_memlets=input_memlets) + + +def _resolve_boolean_target(context: NumpyLoweringContext, target: ast.AST, + value: ast.AST) -> Optional[_AdvancedTarget]: + if not isinstance(target, ast.Subscript) or not isinstance(target.value, ast.Name): + return None + binding = context.bindings.get(target.value.id) + if binding is None or binding.descriptor is None or not _is_numpy_arraylike(binding.descriptor): + return None + + subset: Optional[subsets.Range] = None + guard_expr: Optional[ast.AST] = None + input_memlets: Dict[str, Memlet] = {} + target_shape: Optional[Tuple[Any, ...]] = None + + try: + subset, new_axes, arrdims = memlet_parser.parse_memlet_subset(binding.descriptor, target, + context.evaluation_context()) + except Exception: + subset, new_axes, arrdims = None, [], {} + + if subset is not None and arrdims: + bool_indices = [index_name for index_name in arrdims.values() if _is_boolean_index(index_name, context)] + if 
len(bool_indices) != 1 or len(arrdims) != 1 or new_axes: + return None + bool_name = bool_indices[0] + mask_desc = _index_descriptor(bool_name, context) + if mask_desc is None or tuple(mask_desc.shape) != tuple(binding.descriptor.shape): + return None + target_shape = tuple(_shape_from_basic_subset(subset)) + iteration_plan = _build_iteration_plan(subset) + if iteration_plan is None: + return None + mask_subset = _build_output_memlet(bool_name, iteration_plan).subset + input_memlets['mask'] = Memlet(data=bool_name, subset=mask_subset) + guard_expr = ast.Name(id='mask', ctx=ast.Load()) + output_memlet = Memlet(data=target.value.id, + subset=_build_output_memlet(target.value.id, iteration_plan).subset) + return _AdvancedTarget(name=target.value.id, + output_shape=target_shape, + output_memlet=output_memlet, + target_expr=ast.Name(id='out', ctx=ast.Store()), + input_memlets=input_memlets, + guard_expr=guard_expr) + + if isinstance(target.slice, ast.Compare): + subset = subsets.Range.from_array(binding.descriptor) + target_shape = tuple(binding.descriptor.shape) + iteration_plan = _build_iteration_plan(subset) + if iteration_plan is None: + return None + mask_analysis = _ElementwiseExpressionAnalyzer(context, start_index=100).analyze(target.slice) + if mask_analysis is None: + return None + if mask_analysis.result_shape and not _is_shape_compatible(subset, mask_analysis.result_shape): + return None + for access in mask_analysis.accesses: + access_memlets = _build_input_memlets(access, iteration_plan) + if access_memlets is None: + return None + input_memlets.update(access_memlets) + output_memlet = Memlet(data=target.value.id, + subset=_build_output_memlet(target.value.id, iteration_plan).subset) + return _AdvancedTarget(name=target.value.id, + output_shape=target_shape, + output_memlet=output_memlet, + target_expr=ast.Name(id='out', ctx=ast.Store()), + input_memlets=input_memlets, + guard_expr=mask_analysis.tasklet_value) + + return None + + +def 
_subscript_expr(base_name: str, connectors: Sequence[str]) -> ast.AST: + if len(connectors) == 1: + return ast.Subscript(value=ast.Name(id=base_name, ctx=ast.Load()), + slice=ast.Name(id=connectors[0], ctx=ast.Load()), + ctx=ast.Store()) + return ast.Subscript(value=ast.Name(id=base_name, ctx=ast.Load()), + slice=ast.Tuple(elts=[ast.Name(id=connector, ctx=ast.Load()) for connector in connectors], + ctx=ast.Load()), + ctx=ast.Store()) + + +def _ast_equivalent(left: ast.AST, right: ast.AST) -> bool: + return _unparse(left) == _unparse(right) + + +def _is_trivial_scalar(node: ast.AST, context: NumpyLoweringContext) -> bool: + if isinstance(node, ast.Constant): + return True + return _is_scalar_leaf(node, context) + + +def _is_supported_call(node: ast.Call, context: NumpyLoweringContext) -> bool: + if astutils.rname(node.func) == 'abs': + return True + value = try_resolve_static_value(node.func, context.evaluation_context()) + if value is UNRESOLVED: + return False + return isinstance(value, np.ufunc) + + +def _resolve_plain_multi_output_ufunc(node: ast.Call, context: NumpyLoweringContext) -> Optional[np.ufunc]: + if node.keywords: + return None + value = try_resolve_static_value(node.func, context.evaluation_context()) + if value is UNRESOLVED or not isinstance(value, np.ufunc): + return None + if value.nout <= 1 or len(node.args) != value.nin: + return None + return value + + +def _is_scalar_leaf(node: ast.AST, context: NumpyLoweringContext) -> bool: + if isinstance(node, ast.Name): + binding = _ensure_binding_for_name(node.id, context) + if binding is not None and binding.descriptor is not None: + return isinstance(binding.descriptor, data.Scalar) + value = try_resolve_static_value(node, context.evaluation_context()) + if value is UNRESOLVED: + return False + if isinstance(value, symbolic.symbol): + return True + if isinstance(value, data.Data): + return isinstance(value, data.Scalar) + return isinstance(value, (numbers.Number, bool, str, bytes, type(None))) or 
symbolic.issymbolic(value) + + +def _shape_from_basic_subset(subset: subsets.Range) -> List[Any]: + squeezed = copy.deepcopy(subset) + squeezed.squeeze(offset=False) + return list(squeezed.size()) + + +def _shape_from_ranges(ranges: Sequence[Tuple[Any, Any, Any]]) -> List[Any]: + if not ranges: + return [] + return list(subsets.Range(list(ranges)).size()) + + +def _logical_ranges_from_basic_access(node: ast.AST, subset: subsets.Range, descriptor: data.Data, + context: NumpyLoweringContext) -> Tuple[Tuple[Any, Any, Any], ...]: + if isinstance(node, ast.Name): + return tuple(copy.deepcopy(subset).ranges) + + scalar_dims = set(_scalar_indexed_dims(node, len(descriptor.shape), context)) + return tuple(copy.deepcopy(rng) for index, rng in enumerate(subset.ranges) if index not in scalar_dims) + + +def _scalar_indexed_dims(node: ast.AST, rank: int, context: NumpyLoweringContext) -> Tuple[int, ...]: + if not isinstance(node, ast.Subscript): + return tuple() + + remaining_dims = list(range(rank)) + scalar_dims: List[int] = [] + for ast_ndslice in astutils.subscript_to_ast_slice_recursive(node): + expanded_dims = _expand_basic_slice_dims(ast_ndslice, len(remaining_dims), context) + next_remaining: List[int] = [] + remaining_iter = iter(remaining_dims) + for kind in expanded_dims: + if kind == 'newaxis': + continue + base_dim = next(remaining_iter) + if kind == 'scalar': + scalar_dims.append(base_dim) + else: + next_remaining.append(base_dim) + next_remaining.extend(remaining_iter) + remaining_dims = next_remaining + + return tuple(scalar_dims) + + +def _expand_basic_slice_dims(ast_ndslice: Sequence[Any], remaining_rank: int, + context: NumpyLoweringContext) -> List[str]: + kinds = [_basic_dim_kind(dim, context) for dim in ast_ndslice] + consumed_dims = sum(1 for kind in kinds if kind not in {'newaxis', 'ellipsis'}) + if 'ellipsis' not in kinds: + kinds.extend(['slice'] * max(0, remaining_rank - consumed_dims)) + return kinds + + expanded: List[str] = [] + for kind in kinds: 
+ if kind != 'ellipsis': + expanded.append(kind) + continue + ellipsis_dims = max(0, remaining_rank - consumed_dims) + expanded.extend(['slice'] * ellipsis_dims) + return expanded + + +def _basic_dim_kind(dim: Any, context: NumpyLoweringContext) -> str: + if isinstance(dim, tuple): + return 'slice' + if dim is None or (isinstance(dim, ast.Constant) and dim.value is None): + return 'newaxis' + if dim is Ellipsis or (isinstance(dim, ast.Constant) and dim.value is Ellipsis): + return 'ellipsis' + + resolved = try_resolve_static_value(dim, context.evaluation_context()) + if isinstance(resolved, slice): + return 'slice' + + return 'scalar' + + +def _is_scalar_subscript(node: ast.AST, subset: subsets.Range) -> bool: + if not isinstance(node, ast.Subscript): + return False + if isinstance(node.slice, ast.Slice): + return False + for (start, end, step), tile in zip(subset.ranges, subset.tile_sizes): + if tile != 1 or step != 1 or start != end: + return False + return True + + +def _is_boolean_index(index_name: Any, context: NumpyLoweringContext) -> bool: + if not isinstance(index_name, str): + return False + descriptor = _index_descriptor(index_name, context) + return descriptor is not None and descriptor.dtype == dtypes.bool + + +def _index_descriptor(index_name: str, context: NumpyLoweringContext) -> Optional[data.Data]: + binding = _ensure_binding_for_name(index_name, context) + if binding is not None and binding.descriptor is not None: + return binding.descriptor + value = context.evaluation_context().get(index_name) + if isinstance(value, data.Data): + return value + return None + + +def _ensure_binding_for_name(name: str, context: NumpyLoweringContext) -> Optional[_Binding]: + binding = context.bindings.get(name) + if binding is not None and binding.descriptor is not None: + return binding + + value = context.evaluation_context().get(name, UNRESOLVED) + if value is UNRESOLVED or symbolic.issymbolic(value): + return None + + try: + descriptor = 
data.create_datadescriptor(value) + except Exception: + return None + + kind = 'scalar' if isinstance(descriptor, data.Scalar) else 'container' + context.register_binding(name, descriptor, kind) + return context.bindings.get(name) + + +def _is_numpy_arraylike(descriptor: data.Data) -> bool: + return not isinstance(descriptor, + (data.Scalar, PythonDict, PythonList, PythonTuple)) and hasattr(descriptor, 'shape') + + +def _frontend_range_tuple(dim: Tuple[Any, Any, Any]) -> Tuple[str, str, str]: + start, end, step = dim + offset = -1 if (step < 0) == True else 1 + return (str(start), str(end + offset), str(step)) + + +def _clone_descriptor(descriptor: data.Data) -> data.Data: + return copy.deepcopy(descriptor) + + +def _unparse(node: ast.AST) -> str: + return astutils.unparse(node) diff --git a/dace/frontend/python/schedule_tree/static_evaluation.py b/dace/frontend/python/schedule_tree/static_evaluation.py new file mode 100644 index 0000000000..a71ffc9dfb --- /dev/null +++ b/dace/frontend/python/schedule_tree/static_evaluation.py @@ -0,0 +1,276 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +"""Static AST evaluation helpers for schedule-tree inference. + +These helpers resolve a narrow subset of Python AST without executing user code. +They are intended for parser-time metadata recovery such as shapes, dtypes, and +simple literal/container reasoning. 
+""" + +import ast +import inspect +import operator +import types +from typing import Any, Dict + +import numpy as np + +from dace import data +from dace.frontend.python import astutils + +UNRESOLVED = object() + +_SAFE_BINARY_OPERATORS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.FloorDiv: operator.floordiv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.BitAnd: operator.and_, + ast.BitOr: operator.or_, + ast.BitXor: operator.xor, + ast.LShift: operator.lshift, + ast.RShift: operator.rshift, +} +_SAFE_UNARY_OPERATORS = { + ast.UAdd: operator.pos, + ast.USub: operator.neg, + ast.Not: operator.not_, + ast.Invert: operator.invert, +} +_SAFE_COMPARE_OPERATORS = { + ast.Eq: operator.eq, + ast.NotEq: operator.ne, + ast.Lt: operator.lt, + ast.LtE: operator.le, + ast.Gt: operator.gt, + ast.GtE: operator.ge, + ast.Is: operator.is_, + ast.IsNot: operator.is_not, + ast.In: lambda left, right: left in right, + ast.NotIn: lambda left, right: left not in right, +} + + +def try_resolve_static_value(node: ast.AST, env: Dict[str, Any]) -> Any: + """Resolve a static value from ``node`` or return ``UNRESOLVED``.""" + return _StaticValueResolver(env).resolve(node) + + +class _StaticValueResolver: + + def __init__(self, env: Dict[str, Any]) -> None: + self.env = env + + def resolve(self, node: ast.AST) -> Any: + if isinstance(node, ast.Constant): + return node.value + + if isinstance(node, ast.Name): + return self.env.get(node.id, UNRESOLVED) + + if isinstance(node, ast.Attribute): + base = self.resolve(node.value) + if base is UNRESOLVED: + return UNRESOLVED + if _supports_simple_object_attribute_lookup(base): + return _resolve_simple_object_attribute(base, node.attr) + if not _supports_attribute_lookup(base): + return UNRESOLVED + try: + return getattr(base, node.attr) + except Exception: + return UNRESOLVED + + if isinstance(node, ast.Tuple): + values = [self.resolve(element) for element in node.elts] + 
if any(value is UNRESOLVED for value in values): + return UNRESOLVED + return tuple(values) + + if isinstance(node, ast.List): + values = [self.resolve(element) for element in node.elts] + if any(value is UNRESOLVED for value in values): + return UNRESOLVED + return values + + if isinstance(node, ast.Set): + values = [self.resolve(element) for element in node.elts] + if any(value is UNRESOLVED for value in values): + return UNRESOLVED + try: + return set(values) + except TypeError: + return UNRESOLVED + + if isinstance(node, ast.Dict): + resolved = {} + for key_node, value_node in zip(node.keys, node.values): + if key_node is None: + return UNRESOLVED + key = self.resolve(key_node) + value = self.resolve(value_node) + if key is UNRESOLVED or value is UNRESOLVED: + return UNRESOLVED + try: + resolved[key] = value + except TypeError: + return UNRESOLVED + return resolved + + if isinstance(node, ast.Slice): + lower = None if node.lower is None else self.resolve(node.lower) + upper = None if node.upper is None else self.resolve(node.upper) + step = None if node.step is None else self.resolve(node.step) + if lower is UNRESOLVED or upper is UNRESOLVED or step is UNRESOLVED: + return UNRESOLVED + return slice(lower, upper, step) + + if isinstance(node, ast.Subscript): + base = self.resolve(node.value) + index = self.resolve(node.slice) + if base is UNRESOLVED or index is UNRESOLVED or not _supports_subscript(base): + return UNRESOLVED + try: + return base[index] + except Exception: + return UNRESOLVED + + if isinstance(node, ast.UnaryOp): + operand = self.resolve(node.operand) + if operand is UNRESOLVED: + return UNRESOLVED + operator_fn = _SAFE_UNARY_OPERATORS.get(type(node.op)) + if operator_fn is None: + return UNRESOLVED + try: + return operator_fn(operand) + except Exception: + return UNRESOLVED + + if isinstance(node, ast.BinOp): + left = self.resolve(node.left) + right = self.resolve(node.right) + if left is UNRESOLVED or right is UNRESOLVED: + return UNRESOLVED + 
operator_fn = _SAFE_BINARY_OPERATORS.get(type(node.op)) + if operator_fn is None: + return UNRESOLVED + try: + return operator_fn(left, right) + except Exception: + return UNRESOLVED + + if isinstance(node, ast.BoolOp): + values = [self.resolve(value) for value in node.values] + if any(value is UNRESOLVED for value in values): + return UNRESOLVED + if isinstance(node.op, ast.And): + result = values[0] + for value in values[1:]: + result = result and value + return result + if isinstance(node.op, ast.Or): + result = values[0] + for value in values[1:]: + result = result or value + return result + return UNRESOLVED + + if isinstance(node, ast.Compare): + left = self.resolve(node.left) + if left is UNRESOLVED: + return UNRESOLVED + current = left + for op, comparator_node in zip(node.ops, node.comparators): + right = self.resolve(comparator_node) + if right is UNRESOLVED: + return UNRESOLVED + operator_fn = _SAFE_COMPARE_OPERATORS.get(type(op)) + if operator_fn is None: + return UNRESOLVED + try: + if not operator_fn(current, right): + return False + except Exception: + return UNRESOLVED + current = right + return True + + if isinstance(node, ast.IfExp): + condition = self.resolve(node.test) + if condition is UNRESOLVED: + return UNRESOLVED + branch = node.body if condition else node.orelse + return self.resolve(branch) + + if isinstance(node, ast.Call): + return self._resolve_builtin_container_call(node) + + return UNRESOLVED + + def _resolve_builtin_container_call(self, node: ast.Call) -> Any: + if node.keywords: + return UNRESOLVED + + call_name = astutils.rname(node.func) + if call_name not in {'tuple', 'list'}: + return UNRESOLVED + + args = [self.resolve(arg) for arg in node.args] + if any(arg is UNRESOLVED for arg in args): + return UNRESOLVED + + if call_name == 'tuple': + if len(args) == 0: + return tuple() + if len(args) != 1: + return UNRESOLVED + try: + return tuple(args[0]) + except TypeError: + return UNRESOLVED + + if len(args) == 0: + return [] + if 
len(args) != 1: + return UNRESOLVED + try: + return list(args[0]) + except TypeError: + return UNRESOLVED + + +def _supports_attribute_lookup(value: Any) -> bool: + if isinstance(value, (types.ModuleType, type, data.Data, np.ndarray, np.generic)): + return True + return type(value).__module__.startswith('numpy') + + +def _supports_simple_object_attribute_lookup(value: Any) -> bool: + if _supports_attribute_lookup(value): + return False + + objtype = type(value) + getattribute = objtype.__dict__.get('__getattribute__') + if getattribute is not None and getattribute is not object.__getattribute__: + return False + if '__getattr__' in objtype.__dict__: + return False + return True + + +def _resolve_simple_object_attribute(value: Any, attr_name: str) -> Any: + try: + static_attr = inspect.getattr_static(value, attr_name) + except AttributeError: + return UNRESOLVED + + if any(hasattr(static_attr, attr) for attr in ('__get__', '__set__', '__delete__')): + return UNRESOLVED + return static_attr + + +def _supports_subscript(value: Any) -> bool: + return isinstance(value, (list, tuple, str, bytes, dict, np.ndarray)) diff --git a/dace/frontend/python/schedule_tree/structure_helpers.py b/dace/frontend/python/schedule_tree/structure_helpers.py new file mode 100644 index 0000000000..b79cbb2f9f --- /dev/null +++ b/dace/frontend/python/schedule_tree/structure_helpers.py @@ -0,0 +1,7 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. 
def descriptor_members(descriptor: data.Data) -> Optional[Mapping[str, data.Data]]:
    """Return the member mapping reachable from *descriptor*, if any.

    The descriptor's own ``members`` attribute takes precedence. Otherwise the ``members`` of its ``stype``
    attribute are used when ``stype`` is set and exposes them. ``None`` is returned when neither exists.
    """
    if not hasattr(descriptor, 'members'):
        inner_type = getattr(descriptor, 'stype', None)
        if inner_type is None or not hasattr(inner_type, 'members'):
            return None
        return inner_type.members
    return descriptor.members
def ensure_nested_member_descriptor(descriptor: data.Data, member_names: Sequence[str],
                                    member: data.Data) -> Optional[data.Data]:
    """Return a clone of the descriptor at *member_names* below *descriptor*, registering it if missing.

    All names except the last must already exist along the path. If the last name is not yet a member of its
    owner, a non-transient clone of *member* is stored under that name (this mutates the owner's member
    mapping in place).

    :param descriptor: Root descriptor where the lookup starts.
    :param member_names: Path of member names; the final entry is the leaf.
    :param member: Descriptor to register as the leaf when it does not exist yet.
    :return: A clone of the leaf descriptor, or ``None`` if the path is empty, an intermediate member is
             missing, or an owner along the path exposes no members.
    """
    if not member_names:
        return None

    *owner_path, leaf_name = member_names

    owner = descriptor
    for step in owner_path:
        available = descriptor_members(owner)
        if available is None or step not in available:
            return None
        owner = available[step]
        # Continue from the member's ``stype`` when that is what exposes the nested members.
        wrapped = getattr(owner, 'stype', None)
        if wrapped is not None and hasattr(wrapped, 'members'):
            owner = wrapped

    leaf_members = descriptor_members(owner)
    if leaf_members is None:
        return None

    if leaf_name not in leaf_members:
        registered = clone_descriptor(member)
        registered.transient = False
        leaf_members[leaf_name] = registered

    return clone_descriptor(leaf_members[leaf_name])
+ current = direct_class_annotation_type(root_class_type) + if current is None: + return None + + for member_name in member_names: + annotation = _class_member_annotation(current, member_name) + current = direct_class_annotation_type(annotation) + if current is None: + return None + + return current + + +def _class_member_annotation(class_type: type[Any], member_name: str) -> Any: + annotations = getattr(class_type, '__annotations__', {}) + annotation = annotations.get(member_name) + if annotation is not None and not isinstance(annotation, str): + return annotation + try: + return inspect.get_annotations(class_type, eval_str=True).get(member_name) + except Exception: + return annotation + + +def python_class_requirement_for_member_assignment(descriptor: data.Data, member_name: str) -> Optional[str]: + """Return a message when ``descriptor.member_name = ...`` needs ``PythonClass`` semantics. + + ``Structure`` models a fixed by-value layout that is marshalled as a C + struct. Direct assignment to a non-array field or creation of a new field + changes the Python object state instead of mutating array contents inside a + fixed layout, so those writes require the by-reference ``PythonClass`` path. + """ + if isinstance(descriptor, PythonClass) or not isinstance(descriptor, data.Structure): + return None + + member = member_descriptor(descriptor, member_name) + if member is None: + return (f'Creating field "{member_name}" on by-value Structure "{descriptor.name}" requires ' + 'PythonClass semantics. Use PythonClass when dynamic field creation must be preserved.') + + if isinstance(member, (data.Array, data.View, data.Reference)): + return None + + return (f'Assigning to non-array field "{member_name}" on by-value Structure "{descriptor.name}" ' + 'requires PythonClass semantics. 
def bind_target_structure(target: ast.AST, structure: Any, bind_name: Callable[[str, Any], None]) -> bool:
    """Walk a destructuring target and invoke *bind_name* for each bound name.

    :param target: Assignment target (name, starred expression, or tuple/list of targets).
    :param structure: Value structure matched against the target; tuples and lists are unpacked element-wise.
    :param bind_name: Callback receiving ``(name, sub_structure)`` for every name in the target.
    :return: ``True`` if the whole target could be matched, ``False`` otherwise. Binding stops at the first
             mismatch, so names visited before the mismatch have already been reported to *bind_name*.
    """
    if isinstance(target, ast.Name):
        bind_name(target.id, structure)
        return True

    if isinstance(target, ast.Starred):
        # A starred target always receives a list.
        if isinstance(structure, tuple):
            structure = list(structure)
        elif not isinstance(structure, list):
            structure = [structure]
        return bind_target_structure(target.value, structure, bind_name)

    if not (isinstance(target, (ast.Tuple, ast.List)) and isinstance(structure, (list, tuple))):
        return False

    subtargets = target.elts
    star_positions = [position for position, element in enumerate(subtargets) if isinstance(element, ast.Starred)]
    if len(star_positions) > 1:
        return False

    def bind_pairwise(targets, structures) -> bool:
        for subtarget, substructure in zip(targets, structures):
            if not bind_target_structure(subtarget, substructure, bind_name):
                return False
        return True

    if not star_positions:
        if len(subtargets) != len(structure):
            return False
        return bind_pairwise(subtargets, structure)

    star = star_positions[0]
    # Every non-starred target needs exactly one element.
    if len(structure) < len(subtargets) - 1:
        return False

    tail_targets = subtargets[star + 1:]
    tail_start = len(structure) - len(tail_targets)

    if not bind_pairwise(subtargets[:star], structure[:star]):
        return False
    if not bind_target_structure(subtargets[star], list(structure[star:tail_start]), bind_name):
        return False
    return bind_pairwise(tail_targets, structure[tail_start:])
class ScheduleTreeTupleAssignmentLowerer(ast.NodeTransformer):
    """Lower tuple/list packing and unpacking into element assignments.

    The pass keeps Python assignment ordering explicit by first materializing the
    literal right-hand side into fresh element temporaries, then assigning
    destructured targets from those temporaries. For example, ``A, B = B, A``
    becomes a metadata-only tuple initializer, two element assignments, then two
    ordinary name assignments from the frozen element values. Destructuring from
    non-literal values, including function returns, is left to the regular DaCe
    frontend lowering.

    The pass records, per name, the packed sequence the name is currently known
    to hold. A record is dropped as soon as the name may have been rebound:

    * after any assignment to the name;
    * after every compound statement (``if``, loops, ``with``, ``try``) whose
      body assigns the name;
    * *before* a loop body (and a ``while`` test) that assigns the name, because
      that code is evaluated again after the rebinding.

    Dropping a record only disables the ``name[i]`` -> element-temporary rewrite
    and packed destructuring for that name; the statement itself is kept.
    """

    def __init__(self) -> None:
        # Name -> packed sequence the name is currently known to refer to.
        self._packed_sequences: Dict[str, _PackedSequence] = {}
        # Identifiers present in the input or generated by this pass (used to avoid clashes).
        self._used_names: set[str] = set()
        # Monotonic counter used to disambiguate generated temporaries.
        self._temp_counter = 0

    # ------------------------------------------------------------------
    # Scopes
    # ------------------------------------------------------------------

    def _rewrite_scope(self, node: ast.AST) -> ast.AST:
        """Rewrite the body of a module or function starting from an empty set of packed sequences."""
        self._seed_used_names(node)
        saved = self._packed_sequences
        self._packed_sequences = {}
        node.body = self._rewrite_body(node.body)
        self._packed_sequences = saved
        return node

    def visit_Module(self, node: ast.Module) -> ast.AST:
        return self._rewrite_scope(node)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
        return self._rewrite_scope(node)

    if hasattr(ast, 'AsyncFunctionDef'):

        def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
            return self._rewrite_scope(node)

    # ------------------------------------------------------------------
    # Simple statements
    # ------------------------------------------------------------------

    def visit_Assign(self, node: ast.Assign) -> ast.AST:
        """Lower ``name = <literal>`` packing and tuple/list destructuring; rewrite other assignments.

        Returns a list of statements when the assignment is lowered, otherwise the (possibly updated) node.
        """
        if len(node.targets) == 1 and isinstance(node.targets[0], ast.Name) and self._is_sequence_literal(node.value):
            return self._pack_named_sequence(node.targets[0].id, node.value, node)

        if any(isinstance(target, (ast.Tuple, ast.List)) for target in node.targets):
            source = self._assignment_source(node.value, node)
            if source is None:
                self._invalidate_targets(node.targets)
                return node

            prefix, source_value = source
            lowered: List[ast.stmt] = list(prefix)
            for target in node.targets:
                if isinstance(target, (ast.Tuple, ast.List)):
                    assignments = self._lower_destructuring_target(target, source_value, node)
                    if assignments is None:
                        # The original statement is kept and assigns *every* target, so none of them may
                        # keep a packed-sequence record (including targets not processed yet).
                        self._invalidate_targets(node.targets)
                        return node
                    lowered.extend(assignments)
                else:
                    lowered.append(
                        ast.copy_location(
                            ast.Assign(targets=[astutils.copy_tree(target)], value=self._source_expr(source_value)),
                            node))
                    self._invalidate_target(target)
            return lowered

        node.value = self._rewrite_expression(node.value)
        for target in node.targets:
            self._invalidate_target(target)
        return node

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST:
        if node.value is not None:
            node.value = self._rewrite_expression(node.value)
        self._invalidate_target(node.target)
        return node

    def visit_AugAssign(self, node: ast.AugAssign) -> ast.AST:
        node.value = self._rewrite_expression(node.value)
        self._invalidate_target(node.target)
        return node

    def visit_Return(self, node: ast.Return) -> ast.AST:
        # Return values are intentionally left untouched.
        return node

    def visit_Expr(self, node: ast.Expr) -> ast.AST:
        node.value = self._rewrite_expression(node.value)
        return node

    # ------------------------------------------------------------------
    # Compound statements
    # ------------------------------------------------------------------

    def visit_If(self, node: ast.If) -> ast.AST:
        node.test = self._rewrite_expression(node.test)
        node.body = self._rewrite_child_body(node.body)
        node.orelse = self._rewrite_child_body(node.orelse)
        self._invalidate_assigned_names(node.body + node.orelse)
        return node

    def visit_While(self, node: ast.While) -> ast.AST:
        # The test and the body run again after the body has executed, so names rebound inside the loop
        # must already be unknown while they are rewritten.
        self._invalidate_assigned_names(node.body)
        node.test = self._rewrite_expression(node.test)
        node.body = self._rewrite_child_body(node.body)
        node.orelse = self._rewrite_child_body(node.orelse)
        self._invalidate_assigned_names(node.body + node.orelse)
        return node

    def _rewrite_for(self, node: ast.AST) -> ast.AST:
        """Shared implementation for ``for`` and ``async for``."""
        # The iterable is evaluated once, before the loop, so it still sees the pre-loop records.
        node.iter = self._rewrite_expression(node.iter)
        self._invalidate_target(node.target)
        # The body runs again after its own assignments have taken effect.
        self._invalidate_assigned_names(node.body)
        node.body = self._rewrite_child_body(node.body)
        node.orelse = self._rewrite_child_body(node.orelse)
        # Statements after the loop must not rely on records for names the loop may have rebound.
        self._invalidate_assigned_names(node.body + node.orelse)
        return node

    def visit_For(self, node: ast.For) -> ast.AST:
        return self._rewrite_for(node)

    if hasattr(ast, 'AsyncFor'):

        def visit_AsyncFor(self, node: ast.AsyncFor) -> ast.AST:
            return self._rewrite_for(node)

    def _rewrite_with(self, node: ast.AST) -> ast.AST:
        """Shared implementation for ``with`` and ``async with``."""
        for item in node.items:
            item.context_expr = self._rewrite_expression(item.context_expr)
            if item.optional_vars is not None:
                self._invalidate_target(item.optional_vars)
        node.body = self._rewrite_child_body(node.body)
        # The body always executes in the enclosing scope, so its assignments are visible afterwards.
        self._invalidate_assigned_names(node.body)
        return node

    def visit_With(self, node: ast.With) -> ast.AST:
        return self._rewrite_with(node)

    if hasattr(ast, 'AsyncWith'):

        def visit_AsyncWith(self, node: ast.AsyncWith) -> ast.AST:
            return self._rewrite_with(node)

    def visit_Try(self, node: ast.Try) -> ast.AST:
        node.body = self._rewrite_child_body(node.body)
        # Handlers, ``else`` and ``finally`` all run after (part of) the ``try`` body.
        self._invalidate_assigned_names(node.body)

        handler_statements: List[ast.stmt] = []
        for handler in node.handlers:
            if handler.name:
                # ``except ... as name`` rebinds ``name`` (a plain string on the handler node).
                self._packed_sequences.pop(handler.name, None)
            handler.body = self._rewrite_child_body(handler.body)
            handler_statements.extend(handler.body)
        node.orelse = self._rewrite_child_body(node.orelse)

        # ``finally`` runs after whichever handler or ``else`` block executed.
        self._invalidate_assigned_names(handler_statements + node.orelse)
        node.finalbody = self._rewrite_child_body(node.finalbody)
        self._invalidate_assigned_names(node.finalbody)
        return node

    # ------------------------------------------------------------------
    # Body rewriting
    # ------------------------------------------------------------------

    def _rewrite_body(self, body: Sequence[ast.stmt]) -> List[ast.stmt]:
        """Visit each statement in order, splicing in statement lists and dropping removed statements."""
        result: List[ast.stmt] = []
        for statement in body:
            rewritten = self.visit(statement)
            if rewritten is None:
                continue
            if isinstance(rewritten, list):
                result.extend(rewritten)
            else:
                result.append(rewritten)
        return result

    def _rewrite_child_body(self, body: Sequence[ast.stmt]) -> List[ast.stmt]:
        """Rewrite a nested body on a copy of the current records; the caller's records are restored after."""
        saved = self._packed_sequences
        self._packed_sequences = copy.deepcopy(saved)
        rewritten = self._rewrite_body(body)
        self._packed_sequences = saved
        return rewritten

    # ------------------------------------------------------------------
    # Packing
    # ------------------------------------------------------------------

    def _pack_named_sequence(self, name: str, value: ast.AST, template: ast.AST) -> List[ast.stmt]:
        """Pack the literal *value* assigned to *name* and remember the result for *name*."""
        statements, sequence = self._pack_sequence(name, value, template)
        self._packed_sequences[name] = sequence
        return statements

    def _pack_sequence(self, name: str, value: ast.AST, template: ast.AST) -> tuple[List[ast.stmt], _PackedSequence]:
        """Emit the container initializer plus one temporary assignment per element of the literal *value*.

        Nested tuple/list literals are packed recursively under their own fresh container name.
        """
        init = ast.copy_location(
            ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=astutils.copy_tree(value)), template)
        setattr(init, _CONTAINER_INIT_ATTR, True)

        statements: List[ast.stmt] = [init]
        elements: List[Union[ast.AST, _PackedSequence]] = []
        for index, element in enumerate(value.elts):
            element_name = self._fresh_name(f'{name}_{index}')
            if self._is_sequence_literal(element):
                nested_statements, nested_sequence = self._pack_sequence(element_name, element, template)
                statements.extend(nested_statements)
                elements.append(nested_sequence)
                continue

            element_target = ast.Name(id=element_name, ctx=ast.Store())
            element_value = self._rewrite_expression(element)
            statements.append(self._element_assignment(element_target, element_value, template))
            elements.append(ast.Name(id=element_name, ctx=ast.Load()))

        sequence = _PackedSequence(container=ast.Name(id=name, ctx=ast.Load()),
                                   elements=elements,
                                   is_list=isinstance(value, ast.List))
        return (statements, sequence)

    def _assignment_source(self, value: ast.AST, template: ast.AST) -> Optional[Tuple[List[ast.stmt], _SourceValue]]:
        """Return ``(setup statements, source)`` for a destructuring right-hand side, or ``None``.

        Only names with a packed-sequence record and tuple/list literals without starred elements qualify.
        """
        if isinstance(value, ast.Name):
            packed = self._packed_sequences.get(value.id)
            if packed is not None:
                return ([], copy.deepcopy(packed))
        if self._is_sequence_literal(value):
            temp_name = self._fresh_name('__stree_unpack_tmp')
            statements, packed = self._pack_sequence(temp_name, value, template)
            self._packed_sequences[temp_name] = packed
            return (statements, packed)
        return None

    # ------------------------------------------------------------------
    # Destructuring
    # ------------------------------------------------------------------

    def _lower_destructuring_target(self, target: ast.AST, source: _SourceValue,
                                    template: ast.AST) -> Optional[List[ast.stmt]]:
        """Return the element assignments binding *target* from *source*, or ``None`` if it cannot be lowered.

        Starred targets and length mismatches against a packed source are not lowered.
        """
        if isinstance(target, ast.Name):
            value = self._source_expr(source)
            if isinstance(source, _PackedSequence):
                self._packed_sequences[target.id] = copy.deepcopy(source)
            else:
                self._packed_sequences.pop(target.id, None)
            return [self._element_assignment(astutils.copy_tree(target), value, template)]

        if not isinstance(target, (ast.Tuple, ast.List)):
            self._invalidate_target(target)
            return [self._element_assignment(astutils.copy_tree(target), self._source_expr(source), template)]

        if any(isinstance(element, ast.Starred) for element in target.elts):
            return None

        elements = self._source_elements(source, len(target.elts))
        if elements is None:
            return None

        result: List[ast.stmt] = []
        for subtarget, subsource in zip(target.elts, elements):
            lowered = self._lower_destructuring_target(subtarget, subsource, template)
            if lowered is None:
                return None
            result.extend(lowered)
        return result

    def _source_elements(self, source: _SourceValue, expected_length: int) -> Optional[List[_SourceValue]]:
        """Split *source* into *expected_length* element sources, or return ``None`` on a length mismatch."""
        if isinstance(source, _PackedSequence):
            if len(source.elements) != expected_length:
                return None
            # Nested packed sequences are dataclass instances rather than AST nodes, so they are copied with
            # ``copy.deepcopy`` like every other packed-sequence copy in this class.
            return [
                copy.deepcopy(element) if isinstance(element, _PackedSequence) else astutils.copy_tree(element)
                for element in source.elements
            ]

        return [
            ast.Subscript(value=astutils.copy_tree(source), slice=ast.Constant(value=index), ctx=ast.Load())
            for index in range(expected_length)
        ]

    @staticmethod
    def _source_expr(source: _SourceValue) -> ast.AST:
        """Return a fresh expression reading *source* (the container name for packed sequences)."""
        if isinstance(source, _PackedSequence):
            return astutils.copy_tree(source.container)
        return astutils.copy_tree(source)

    @staticmethod
    def _element_assignment(target: ast.AST, value: ast.AST, template: ast.AST) -> ast.Assign:
        """Build ``target = value`` located at *template* and tagged as a tuple element assignment."""
        assignment = ast.copy_location(ast.Assign(targets=[target], value=value), template)
        setattr(assignment, _ELEMENT_ASSIGNMENT_ATTR, True)
        return assignment

    # ------------------------------------------------------------------
    # Expression rewriting
    # ------------------------------------------------------------------

    def _rewrite_expression(self, node: ast.AST) -> ast.AST:
        """Return a copy of *node* in which ``packed[<constant int>]`` reads use the element sources."""
        outer = self

        class _ExpressionRewriter(ast.NodeTransformer):

            def visit_Subscript(self, subscript: ast.Subscript) -> ast.AST:
                subscript = self.generic_visit(subscript)
                replacement = outer._packed_subscript_replacement(subscript)
                return replacement if replacement is not None else subscript

        return _ExpressionRewriter().visit(astutils.copy_tree(node))

    def _packed_subscript_replacement(self, node: ast.Subscript) -> Optional[ast.AST]:
        """Return the element expression for ``name[i]`` when *name* is packed and ``i`` is a valid constant."""
        if not isinstance(node.value, ast.Name):
            return None
        packed = self._packed_sequences.get(node.value.id)
        if packed is None:
            return None
        index = self._constant_int(node.slice)
        # Negative and out-of-range indices are left for the regular lowering.
        if index is None or index < 0 or index >= len(packed.elements):
            return None
        element = packed.elements[index]
        return self._source_expr(element)

    @staticmethod
    def _constant_int(node: ast.AST) -> Optional[int]:
        """Evaluate integer constants combined with ``+``/``-``; return ``None`` for anything else."""
        if isinstance(node, ast.Constant) and isinstance(node.value, int) and not isinstance(node.value, bool):
            return node.value
        if isinstance(node, ast.BinOp):
            left = ScheduleTreeTupleAssignmentLowerer._constant_int(node.left)
            right = ScheduleTreeTupleAssignmentLowerer._constant_int(node.right)
            if left is None or right is None:
                return None
            if isinstance(node.op, ast.Add):
                return left + right
            if isinstance(node.op, ast.Sub):
                return left - right
        return None

    @staticmethod
    def _is_sequence_literal(node: ast.AST) -> bool:
        """Return True for tuple/list displays that contain no starred element."""
        return isinstance(node,
                          (ast.Tuple, ast.List)) and not any(isinstance(element, ast.Starred) for element in node.elts)

    # ------------------------------------------------------------------
    # Bookkeeping
    # ------------------------------------------------------------------

    def _invalidate_targets(self, targets: Sequence[ast.AST]) -> None:
        for target in targets:
            self._invalidate_target(target)

    def _invalidate_target(self, target: ast.AST) -> None:
        """Forget the packed-sequence record of every name stored anywhere inside *target*."""
        for child in ast.walk(target):
            if isinstance(child, ast.Name) and isinstance(child.ctx, ast.Store):
                self._packed_sequences.pop(child.id, None)

    def _invalidate_assigned_names(self, statements: Sequence[ast.stmt]) -> None:
        """Forget the packed-sequence record of every name stored anywhere inside *statements*."""
        for statement in statements:
            self._invalidate_target(statement)

    def _fresh_name(self, prefix: str) -> str:
        """Return an identifier based on *prefix* that is neither used in the input nor generated before."""
        candidate = prefix
        while candidate in self._used_names or candidate in self._packed_sequences:
            self._temp_counter += 1
            candidate = f'{prefix}_{self._temp_counter}'
        self._used_names.add(candidate)
        return candidate

    def _seed_used_names(self, node: ast.AST) -> None:
        """Record every identifier that appears as a name below *node*."""
        for child in ast.walk(node):
            if isinstance(child, ast.Name):
                self._used_names.add(child.id)
def _normalize_inferred_structure(result: Any) -> Optional[Any]:
    """Convert an inference result into a structure of transient descriptor clones.

    A descriptor becomes a transient clone. A tuple or list is normalized element by element and keeps its
    kind (tuple results yield a plain tuple, list results a list); ``None`` elements are preserved as ``None``.
    Anything else, or a sequence containing a non-``None`` element that cannot be normalized, yields ``None``.
    """
    if isinstance(result, data.Data):
        clone = _clone_descriptor(result)
        clone.transient = True
        return clone

    if not isinstance(result, (tuple, list)):
        return None

    normalized_items = []
    for item in result:
        normalized_item = _normalize_inferred_structure(item)
        # ``None`` is only acceptable as the normalization of a ``None`` element.
        if normalized_item is None and item is not None:
            return None
        normalized_items.append(normalized_item)

    if isinstance(result, tuple):
        return tuple(normalized_items)
    return normalized_items
def _infer_static_subscript_shape(array_shape: Tuple[Any, ...], index_value: Any) -> Optional[Tuple[Any, ...]]:
    """Compute the result shape of ``array[index_value]`` for a statically known index.

    The index may mix integers, slices, ``None`` (new axis), ``Ellipsis`` and
    integer array-like (advanced) indices.

    Following NumPy's rules for combining advanced and basic indexing, an
    integer scalar counts as an advanced index whenever at least one
    array-like index is present. If all advanced indices are adjacent, their
    broadcast shape replaces them in place; if they are separated by a slice
    or a new axis, the broadcast dimensions come first.

    :param array_shape: Shape of the indexed array (entries may be symbolic).
    :param index_value: The evaluated index object.
    :return: The resulting shape, or ``None`` if it cannot be determined.
    """
    expanded = _expand_static_indices(index_value, len(array_shape))
    if expanded is None:
        return None

    # Classify every index once so we know up front whether integer scalars
    # have to take part in advanced indexing.
    index_shapes = [_static_advanced_index_shape(index) for index in expanded]
    has_array_index = any(shape is not None for shape in index_shapes)

    adv_marker = object()  # Placeholder for a run of adjacent advanced indices
    chunks: List[Any] = []
    advanced_shapes: List[Tuple[int, ...]] = []
    advanced_groups = 0
    in_advanced_group = False
    array_dim = 0

    for index, advanced_shape in zip(expanded, index_shapes):
        if index is None:
            chunks.append((1, ))
            in_advanced_group = False
            continue

        if array_dim >= len(array_shape):
            return None

        if _is_static_integer_index(index):
            if not has_array_index:
                # Pure basic indexing: the integer simply drops the dimension.
                array_dim += 1
                in_advanced_group = False
                continue
            # With array indices present, the integer broadcasts as a
            # zero-dimensional advanced index.
            advanced_shape = ()

        if advanced_shape is not None:
            advanced_shapes.append(advanced_shape)
            if not in_advanced_group:
                chunks.append(adv_marker)
                advanced_groups += 1
            in_advanced_group = True
            array_dim += 1
            continue

        if not isinstance(index, slice):
            return None

        slice_dim = _static_slice_result_dim(array_shape[array_dim], index)
        if slice_dim is None:
            return None
        chunks.append((slice_dim, ))
        array_dim += 1
        in_advanced_group = False

    # Dimensions not consumed by any index are kept whole.
    while array_dim < len(array_shape):
        chunks.append((array_shape[array_dim], ))
        array_dim += 1

    if not advanced_shapes:
        return tuple(dim for chunk in chunks for dim in chunk)

    broadcast_shape = _broadcast_static_shapes(advanced_shapes)
    if broadcast_shape is None:
        return None

    if advanced_groups == 1:
        # Adjacent advanced indices: broadcast dimensions replace them in place.
        output_shape: List[Any] = []
        for chunk in chunks:
            if chunk is adv_marker:
                output_shape.extend(broadcast_shape)
            else:
                output_shape.extend(chunk)
        return tuple(output_shape)

    # Separated advanced indices: broadcast dimensions lead the result.
    output_shape = list(broadcast_shape)
    for chunk in chunks:
        if chunk is not adv_marker:
            output_shape.extend(chunk)
    return tuple(output_shape)
sum(1 for index in indices if index is not None and index is not Ellipsis) + expanded: List[Any] = [] + ellipsis_seen = False + for index in indices: + if index is Ellipsis: + ellipsis_seen = True + expanded.extend([slice(None)] * max(rank - consumed, 0)) + continue + expanded.append(index) + + if not ellipsis_seen: + expanded.extend([slice(None)] * max(rank - consumed, 0)) + + return expanded + + +def _is_static_integer_index(index: Any) -> bool: + return isinstance(index, numbers.Integral) and not isinstance(index, bool) + + +def _static_advanced_index_shape(index: Any) -> Optional[Tuple[int, ...]]: + if isinstance(index, np.ndarray): + if index.ndim == 0 or index.dtype == bool: + return None + return tuple(index.shape) + + if isinstance(index, list): + return _static_nested_sequence_shape(index) + + if isinstance(index, tuple): + nested_shape = _static_nested_sequence_shape(list(index)) + if nested_shape is None: + return None + return nested_shape + + return None + + +def _static_nested_sequence_shape(value: List[Any]) -> Optional[Tuple[int, ...]]: + if not value: + return (0, ) + first = value[0] + if isinstance(first, (list, tuple)): + inner_shape = _static_nested_sequence_shape(list(first)) + if inner_shape is None: + return None + for element in value[1:]: + if not isinstance(element, (list, tuple)): + return None + if _static_nested_sequence_shape(list(element)) != inner_shape: + return None + return (len(value), ) + inner_shape + + if any(isinstance(element, (list, tuple)) for element in value[1:]): + return None + if any(not _is_static_integer_index(element) for element in value): + return None + return (len(value), ) + + +def _static_slice_result_dim(dim_size: Any, index: slice) -> Optional[Any]: + if index == slice(None): + return dim_size + + step = 1 if index.step is None else index.step + try: + if step == 0: + return None + except TypeError: + pass + + step_is_negative = (step < 0) == True + step_is_positive = (step > 0) == True + if not 
step_is_negative and not step_is_positive: + return None + + if index.start is None: + start = dim_size - 1 if step_is_negative else 0 + else: + start = index.start + + if index.stop is None: + stop = -1 if step_is_negative else dim_size + else: + stop = index.stop + + try: + if (start < 0) == True: + start += dim_size + except TypeError: + pass + try: + if (stop < 0) == True: + stop += dim_size + except TypeError: + pass + + end = stop + 1 if step_is_negative else stop - 1 + return subsets.Range([(start, end, step)]).size()[0] + + +def _broadcast_static_shapes(shapes: Sequence[Tuple[int, ...]]) -> Optional[Tuple[int, ...]]: + result: List[int] = [] + max_rank = max(len(shape) for shape in shapes) + for axis in range(max_rank): + axis_sizes = [] + for shape in shapes: + offset = axis - (max_rank - len(shape)) + axis_sizes.append(1 if offset < 0 else shape[offset]) + size = max(axis_sizes) + if any(axis_size not in {1, size} for axis_size in axis_sizes): + return None + result.append(size) + return tuple(result) + + +def _should_fallback_to_pyobject_scalar(node: ast.AST, value: Any = UNRESOLVED) -> bool: + if value is None or isinstance(value, (str, bytes, numbers.Number, bool, type(Ellipsis))): + return False + return isinstance(node, (ast.Await, ast.Attribute, ast.BinOp, ast.BoolOp, ast.Call, ast.Compare, ast.FormattedValue, + ast.IfExp, ast.JoinedStr, ast.Name, ast.NamedExpr, ast.UnaryOp, ast.Yield, ast.YieldFrom)) + + +class ScheduleTreeTypeInference(ast.NodeVisitor): + """Conservative binding inference for the direct schedule-tree frontend.""" + + def __init__(self, + globals_env: Dict[str, Any], + argtypes: Dict[str, data.Data], + seed_bindings: Optional[Dict[str, _Binding]] = None) -> None: + self.globals = copy.copy(globals_env) + self.dict_support = DictSupportLibrary() + self.bindings: Dict[str, _Binding] = { + name: _Binding(descriptor=_clone_descriptor(descriptor), kind='container') + for name, descriptor in argtypes.items() + } + for name, binding in 
def infer(self, program: ast.AST) -> Dict[str, _Binding]:
    """Run inference over a function and return copies of the recorded results.

    :param program: A function definition, or a module whose first statement
                    is one.
    :return: A mapping from name to inferred binding; empty when ``program``
             is not (and does not start with) a function definition.
    """
    function_node = program
    if isinstance(function_node, ast.Module):
        function_node = function_node.body[0] if function_node.body else None
    if not isinstance(function_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        return {}

    self._initialize_direct_class_annotations(function_node)
    for statement in function_node.body:
        self.visit(statement)

    # Hand out clones so callers cannot mutate the internal results.
    snapshot: Dict[str, _Binding] = {}
    for name, binding in self.results.items():
        snapshot[name] = _clone_binding(binding)
    return snapshot
def _merge_branch_bindings(self, before: Dict[str, _Binding], then_bindings: Dict[str, _Binding],
                           else_bindings: Dict[str, _Binding]) -> None:
    """Join the bindings of two conditional branches into ``self.bindings``.

    Names bound before the conditional keep their prior binding. A name
    introduced inside the conditional survives only if both branches bind it
    compatibly; the ``then`` binding is then kept and recorded in the results.

    :param before: Bindings that were active before the conditional.
    :param then_bindings: Bindings at the end of the ``then`` branch.
    :param else_bindings: Bindings at the end of the ``else`` branch.
    """
    merged: Dict[str, _Binding] = {}
    for name, binding in before.items():
        merged[name] = _clone_binding(binding)

    new_names = (set(then_bindings.keys()) | set(else_bindings.keys())) - set(before.keys())
    for name in new_names:
        then_binding = then_bindings.get(name)
        else_binding = else_bindings.get(name)
        bound_on_both_paths = then_binding is not None and else_binding is not None
        if bound_on_both_paths and self._compatible_bindings(then_binding, else_binding):
            merged[name] = _clone_binding(then_binding)
            self.results[name] = _clone_binding(then_binding)

    self.bindings = merged
def _compatible_structures(self, left: Any, right: Any) -> bool:
    """Check whether two binding structures describe compatible values.

    :param left: First structure (``None``, descriptor, static dictionary
                 binding, or a list/tuple of structures).
    :param right: Second structure.
    :return: True if both structures have the same form and all their parts
             are pairwise compatible.
    """
    if left is None or right is None:
        return left is right

    if isinstance(left, StaticDictBinding) and isinstance(right, StaticDictBinding):
        if set(left.entries.keys()) != set(right.entries.keys()):
            return False
        for key in left.entries:
            if not self._compatible_descriptors(left.entries[key], right.entries[key]):
                return False
        return True

    if isinstance(left, data.Data) and isinstance(right, data.Data):
        return self._compatible_descriptors(left, right)

    # Sequences must be of the same kind and length and match element-wise.
    for sequence_type in (list, tuple):
        if isinstance(left, sequence_type) and isinstance(right, sequence_type) and len(left) == len(right):
            return all(self._compatible_structures(lval, rval) for lval, rval in zip(left, right))

    return False
def _update_dict_subscript_binding(self, target: ast.AST, value: ast.AST) -> None:
    """Refresh the binding of ``name`` after an assignment of the form ``name[key] = value``.

    Does nothing unless the target is a subscript of a plain name that has a
    binding with a descriptor, and the dictionary support library reports an
    updated binding for the assignment.

    :param target: The assignment target.
    :param value: The assigned expression.
    """
    if not isinstance(target, ast.Subscript):
        return
    container = target.value
    if not isinstance(container, ast.Name):
        return

    current = self.bindings.get(container.id)
    if current is None or current.descriptor is None:
        return

    static_dict = current.structure if isinstance(current.structure, StaticDictBinding) else None
    update = self.dict_support.infer_assignment_binding(self._dict_support_context(), current.descriptor,
                                                        static_dict, target.slice, value)
    if update is None:
        return

    new_descriptor, new_structure = update
    self._store_binding(container.id, new_descriptor, kind=current.kind, structure=new_structure)
def _resolve_binding(self, value: ast.AST) -> Optional[_Binding]:
    """Resolve an expression to a binding using already-known information.

    Handles names (local bindings first, then globals), attribute access,
    subscripts, ``.reshape(...)`` calls and tuple/list literals.

    :param value: The expression to resolve.
    :return: A binding for the expression, or ``None`` if it cannot be
             resolved from the known bindings.
    """
    if isinstance(value, ast.Name):
        if value.id in self.bindings:
            return _clone_binding(self.bindings[value.id])

        external_value = self.globals.get(value.id, UNRESOLVED)
        if external_value is not UNRESOLVED:
            if symbolic.issymbolic(external_value):
                return None
            try:
                descriptor = _clone_descriptor(data.create_datadescriptor(external_value))
            except Exception:
                # Best effort: a global that cannot be described is simply not bound.
                descriptor = None
            if descriptor is not None:
                descriptor.transient = False
                kind = 'scalar' if isinstance(descriptor, data.Scalar) else 'container'
                structure = descriptor if isinstance(descriptor, data.Scalar) else None
                binding = _Binding(descriptor=descriptor, kind=kind, structure=structure)
                # Remember the global so later lookups hit ``self.bindings`` directly.
                self.bindings[value.id] = _clone_binding(binding)
                return binding

    if isinstance(value, ast.Attribute):
        base_binding = self._resolve_binding(value.value)
        if base_binding is None or base_binding.descriptor is None:
            return None
        descriptor = member_descriptor(base_binding.descriptor, value.attr)
        if descriptor is None:
            return None
        kind = 'scalar' if isinstance(descriptor, data.Scalar) else 'container'
        structure = descriptor if isinstance(descriptor, data.Scalar) else None
        return _Binding(descriptor=descriptor, kind=kind, structure=structure)

    if isinstance(value, ast.Subscript):
        binding = self._resolve_binding(value.value)
        if binding is None or binding.descriptor is None:
            return None
        if isinstance(binding.descriptor, PythonDict):
            descriptor = self.dict_support.infer_subscript_descriptor(
                self._dict_support_context(), binding.descriptor, value.slice,
                binding.structure if isinstance(binding.structure, StaticDictBinding) else None)
            if descriptor is None:
                return None
            kind = 'scalar' if isinstance(descriptor, data.Scalar) else binding.kind
            structure = descriptor if isinstance(descriptor, data.Scalar) else None
            return _Binding(descriptor=descriptor, kind=kind, structure=structure)
        # Prefer the element structure when it is known; otherwise fall back
        # to descriptor-based subscript inference.
        structure = self._subscript_structure(binding, value.slice)
        descriptor = descriptor_from_structure(structure) if structure is not None else None
        if descriptor is None:
            descriptor = self._subscript_descriptor(binding.descriptor, value)
        if descriptor is None:
            return None
        kind = 'scalar' if isinstance(descriptor, data.Scalar) else binding.kind
        return _Binding(descriptor=descriptor, kind=kind, structure=structure)

    if isinstance(value, ast.Call) and isinstance(value.func, ast.Attribute) and value.func.attr == 'reshape':
        base_binding = self._resolve_binding(value.func.value)
        if base_binding is None or base_binding.descriptor is None:
            return None
        if len(value.args) > 1:
            # ``a.reshape(2, 3)`` is equivalent to ``a.reshape((2, 3))``:
            # gather the separate dimension arguments into one tuple node.
            # NOTE(review): assumes _parse_shape accepts an ast.Tuple, as it
            # must for the ``reshape((2, 3))`` form -- confirm.
            shape_node = ast.copy_location(ast.Tuple(elts=list(value.args), ctx=ast.Load()), value.args[0])
            shape = self._parse_shape(shape_node)
        elif value.args:
            shape = self._parse_shape(value.args[0])
        else:
            shape = list(base_binding.descriptor.shape)
        return _Binding(descriptor=self._make_view_descriptor(base_binding.descriptor, shape), kind='container')

    if isinstance(value, (ast.Tuple, ast.List)):
        structure = self._structure_from_expression(value)
        if structure is None:
            return None
        descriptor = descriptor_from_structure(structure)
        if descriptor is None:
            return None
        kind = 'scalar' if isinstance(descriptor, data.Scalar) else 'container'
        return _Binding(descriptor=descriptor, kind=kind, structure=structure)

    return None
def _structure_from_iterator_annotation(self, value: Any, method_name: str, *,
                                        returns_iterator: bool) -> Optional[Any]:
    """Derive an element structure from the return annotation of an iterator-protocol method.

    :param value: The object whose type is inspected.
    :param method_name: Name of the method to look up on ``type(value)``.
    :param returns_iterator: Forwarded to ``_structure_from_annotation``.
    :return: The inferred structure, or ``None`` if the method is missing,
             has no inspectable signature, or has no return annotation.
    """
    candidate = getattr(type(value), method_name, None)
    if candidate is None:
        return None

    try:
        signature = inspect.signature(candidate)
    except (TypeError, ValueError):
        return None

    return_annotation = signature.return_annotation
    if return_annotation is inspect.Signature.empty:
        return None
    return self._structure_from_annotation(return_annotation, returns_iterator=returns_iterator)
def _structure_from_annotation(self, annotation: Any, *, returns_iterator: bool) -> Optional[Any]:
    """Translate a type annotation into a binding structure.

    :param annotation: The annotation object to translate.
    :param returns_iterator: If True, a parameterized iterator/iterable
                             annotation is unwrapped to its element type.
    :return: A descriptor, a tuple/list of structures, or ``None`` if the
             annotation cannot be translated.
    """
    origin = get_origin(annotation)
    args = get_args(annotation)

    if returns_iterator and origin in {TypingIterator, TypingIterable, cabc.Iterator, cabc.Iterable} and args:
        return self._structure_from_annotation(args[0], returns_iterator=False)

    if origin in {tuple, Tuple} and args:
        is_variadic = len(args) == 2 and args[1] is Ellipsis
        if is_variadic:
            # ``Tuple[T, ...]`` is represented by a single element structure.
            item = self._structure_from_annotation(args[0], returns_iterator=False)
            if item is None:
                return None
            return (item, )
        items = [self._structure_from_annotation(arg, returns_iterator=False) for arg in args]
        for item in items:
            if item is None:
                return None
        return tuple(items)

    if origin in {list, List} and args:
        item = self._structure_from_annotation(args[0], returns_iterator=False)
        if item is None:
            return None
        return [item]

    # Anything else must be directly convertible to a data descriptor.
    try:
        converted = _clone_descriptor(data.create_datadescriptor(annotation))
    except Exception:
        converted = None
    if converted is None:
        return None
    converted.transient = True
    return converted
def _structure_from_value(self, value: Any) -> Optional[Any]:
    """Build a binding structure from a concrete Python value.

    Tuples and lists are converted element-wise (keeping the container kind);
    any other value is converted to a transient data descriptor.

    :param value: The value to describe.
    :return: The structure, or ``None`` if the value, or any nested element,
             cannot be described.
    """
    if isinstance(value, (tuple, list)):
        converted = [self._structure_from_value(element) for element in value]
        for element in converted:
            if element is None:
                return None
        return tuple(converted) if isinstance(value, tuple) else converted

    try:
        leaf = _clone_descriptor(data.create_datadescriptor(value))
    except Exception:
        return None
    leaf.transient = True
    return leaf
def _subscript_structure(self, binding: _Binding, slice_node: ast.AST) -> Optional[Any]:
    """Return the structure of one element of a list/tuple-structured binding.

    The subscript must evaluate to a constant integer. Negative indices
    count from the end of the sequence, as in Python.

    :param binding: The binding being subscripted.
    :param slice_node: AST of the subscript expression.
    :return: A deep copy of the selected element structure, or ``None`` if
             the binding has no list/tuple structure, the index is not a
             constant integer, or it is out of range.
    """
    elements = binding.structure
    if elements is None or not isinstance(elements, (list, tuple)):
        return None

    index_value = self._safe_eval(slice_node, self._evaluation_context())
    if not isinstance(index_value, int):
        return None

    # Normalize Python-style negative indices (e.g. ``t[-1]`` is the last element).
    if index_value < 0:
        index_value += len(elements)
    if index_value < 0 or index_value >= len(elements):
        return None
    return copy.deepcopy(elements[index_value])
def _infer_binop_descriptor(self, node: ast.BinOp) -> Optional[data.Data]:
    """Infer the descriptor produced by a binary operation.

    :param node: The binary-operation node.
    :return: The descriptor reported by the registered operator inference
             function, or ``None`` if an operand cannot be inferred, no
             inference function is registered, or the function raises.
    """
    # Both operands are always inferred, even if the first one fails.
    operands = [self._infer_operator_operand(side) for side in (node.left, node.right)]
    if any(operand is None for operand in operands):
        return None
    lhs, rhs = operands

    operator_name = type(node.op).__name__
    inference = oprepo.Replacements.get_operator_descriptor_inference(operator_name, lhs, rhs)
    if inference is None:
        return None

    try:
        outcome = inference(lhs, rhs)
    except Exception:
        outcome = None
    return outcome
def _infer_compare_descriptor(self, node: ast.Compare) -> Optional[data.Data]:
    """Infer the descriptor produced by a single (non-chained) comparison.

    :param node: The comparison node.
    :return: The descriptor reported by the registered operator inference
             function, or ``None`` for chained comparisons, when an operand
             cannot be inferred, when no inference function is registered,
             or when the function raises.
    """
    is_simple = len(node.ops) == 1 and len(node.comparators) == 1
    if not is_simple:
        return None

    # Both operands are always inferred, even if the first one fails.
    lhs = self._infer_operator_operand(node.left)
    rhs = self._infer_operator_operand(node.comparators[0])
    if lhs is None or rhs is None:
        return None

    operator_name = type(node.ops[0]).__name__
    inference = oprepo.Replacements.get_operator_descriptor_inference(operator_name, lhs, rhs)
    if inference is None:
        return None

    try:
        outcome = inference(lhs, rhs)
    except Exception:
        outcome = None
    return outcome
self._infer_compare_descriptor(node) + if inferred is not None: + return inferred + + if isinstance(node, ast.Call): + # Try the method descriptor-inference registry first (a.sum(), etc.) + if isinstance(node.func, ast.Attribute): + inferred = self._try_method_descriptor_inference(node) + if inferred is not None: + return inferred.descriptor + + inferred = self._try_ufunc_descriptor_inference(node) + if inferred is not None: + return inferred.descriptor + + # Try the free-function descriptor-inference registry (numpy.sum(), etc.) + inferred = self._try_descriptor_inference(node) + if inferred is not None: + return inferred.descriptor + + # Attribute inference (a.T, a.flat, a.real, a.imag, etc.) + if isinstance(node, ast.Attribute): + inferred = self._try_attribute_descriptor_inference(node) + if inferred is not None: + return inferred.descriptor + + return None + + def _infer_scalar_descriptor(self, node: ast.AST, annotated_descriptor: Optional[data.Data]) -> Optional[data.Data]: + if annotated_descriptor is not None and isinstance(annotated_descriptor, data.Scalar): + return _clone_descriptor(annotated_descriptor) + + if isinstance(node, (ast.JoinedStr, ast.FormattedValue)): + return _string_scalar_descriptor() + + scalar_types = { + name: binding.descriptor.dtype + for name, binding in self.bindings.items() + if binding.descriptor is not None and isinstance(binding.descriptor, data.Scalar) + } + try: + inferred_type = infer_expr_type(_unparse(node), scalar_types) + except Exception: + inferred_type = None + if inferred_type is not None: + return data.Scalar(inferred_type, transient=True) + + value = try_resolve_static_value(node, self._evaluation_context()) + if value is not UNRESOLVED and value is not None: + try: + descriptor = _clone_descriptor(data.create_datadescriptor(value)) + except Exception: + descriptor = None + if isinstance(descriptor, data.Scalar): + descriptor.transient = True + return descriptor + + if isinstance(value, numbers.Number) or 
def _try_descriptor_inference(self, node: ast.Call) -> Optional[_Binding]:
    """
    Query the free-function descriptor-inference registry for a call node.

    The registry is consulted first under the resolved callable name and, if that yields
    nothing and the name as written in the source differs, under the written name.

    :param node: Call expression.
    :return: The binding built from the inference result, or ``None`` if no inference
             function is registered or the inference function raises.
    """
    resolved_name = self._resolved_callable_name(node.func)
    lookup = oprepo.Replacements.get_descriptor_inference

    handler = lookup(resolved_name)
    if handler is None:
        written_name = astutils.rname(node.func)
        if written_name != resolved_name:
            handler = lookup(written_name)
    if handler is None:
        return None

    descriptors, positional, keyword = self._resolve_call_inputs_for_inference(node)
    try:
        outcome = handler(descriptors, *positional, **keyword)
    except Exception:
        # Best-effort: a failing inference function means "unknown".
        return None
    # Converting the result is deliberately outside the ``try``; errors there propagate.
    return _binding_from_inference_result(outcome)
def _try_ufunc_descriptor_inference(self, node: ast.Call) -> Optional[_Binding]:
    """
    Query the descriptor-inference registry for a NumPy ufunc call or ufunc method.

    :param node: Call expression.
    :return: The binding built from the inference result, or ``None`` if the call does not
             resolve to a ufunc target, no inference function is registered for the
             resolved method, or the inference function raises.
    """
    resolved_target = _resolve_ufunc_inference_target(node, self._evaluation_context())
    if resolved_target is None:
        return None
    method, ufunc = resolved_target

    handler = oprepo.Replacements.get_ufunc_descriptor_inference(method)
    if handler is None:
        return None

    descriptors, positional, keyword = self._resolve_call_inputs_for_inference(node)
    try:
        outcome = handler(descriptors, ufunc, *positional, **keyword)
    except Exception:
        # Best-effort: a failing inference function means "unknown".
        return None
    # Converting the result is deliberately outside the ``try``; errors there propagate.
    return _binding_from_inference_result(outcome)
def _apply_method_self_descriptor_side_effect(self, node: ast.AST) -> None:
    """
    Update the binding of ``obj`` after a call of the form ``obj.method(...)``.

    Only calls whose receiver is a plain name with a known descriptor are considered.
    If a "self descriptor" inference function is registered for the receiver's
    descriptor class name and the method name, it is invoked with the current
    descriptor and the resolved call arguments; a returned data descriptor replaces
    the stored binding (keeping the binding kind). Every other outcome leaves the
    bindings untouched.

    :param node: Candidate expression, typically a call.
    """
    if not isinstance(node, ast.Call):
        return
    callee = node.func
    if not (isinstance(callee, ast.Attribute) and isinstance(callee.value, ast.Name)):
        return

    receiver = callee.value.id
    current = self.bindings.get(receiver)
    if current is None or current.descriptor is None:
        return

    updater = oprepo.Replacements.get_method_self_descriptor_inference(
        type(current.descriptor).__name__, callee.attr)
    if updater is None:
        return

    _unused_descs, positional, keyword = self._resolve_call_inputs_for_inference(node)
    try:
        replacement = updater(current.descriptor, *positional, **keyword)
    except Exception:
        # Best-effort: a failing inference function leaves the binding as it was.
        return
    if isinstance(replacement, data.Data):
        self._store_binding(receiver, replacement, kind=current.kind)
def _resolve_call_inputs_for_inference(self, call_node: ast.Call) -> tuple:
    """
    Resolve call arguments to ``(input_descriptors, args, kwargs)``.

    Positional arguments that resolve to a binding with a descriptor are passed by name:
    the descriptor is recorded in ``input_descriptors`` and the name is placed in ``args``.
    Names and attribute chains keep their source spelling; any other expression is
    labelled ``__arg<position>``. All remaining positional arguments and all named keyword
    arguments are replaced by their statically evaluated value (``None`` if unresolved).
    ``**mapping`` keyword entries are skipped.

    :param call_node: Call expression whose arguments are resolved.
    :return: Tuple ``(input_descriptors, args, kwargs)``.
    """
    descriptors = {}
    positional = []
    for position, argument in enumerate(call_node.args):
        bound = self._resolve_binding(argument)
        if bound is None or bound.descriptor is None:
            positional.append(self._safe_eval(argument, self._evaluation_context()))
            continue
        if isinstance(argument, (ast.Name, ast.Attribute)):
            label = astutils.rname(argument)
        else:
            # One entry is appended per argument, so the position doubles as the index.
            label = f'__arg{position}'
        descriptors[label] = bound.descriptor
        positional.append(label)

    keyword = {
        entry.arg: self._safe_eval(entry.value, self._evaluation_context())
        for entry in call_node.keywords if entry.arg is not None
    }
    return descriptors, positional, keyword
def _evaluate_descriptor(self, node: Optional[ast.AST]) -> Optional[data.Data]:
    """
    Turn an annotation expression into a transient data descriptor.

    :param node: Annotation expression, or ``None`` if there is no annotation.
    :return: A transient ``Structure`` for direct class annotations, a transient clone for
             annotations that evaluate to a data descriptor, a transient ``Scalar`` for
             annotations that normalize to a dtype, and ``None`` otherwise.
    """
    if node is None:
        return None

    annotated_class = self._evaluate_annotation_class_type(node)
    if annotated_class is not None:
        result = data.Structure.from_class(annotated_class)
    else:
        evaluated = self._safe_eval(node, self._evaluation_context())
        if not isinstance(evaluated, data.Data):
            element_type = _normalize_dtype(evaluated)
            return None if element_type is None else data.Scalar(element_type, transient=True)
        # Clone so that marking the result transient does not touch the evaluated object.
        result = _clone_descriptor(evaluated)

    result.transient = True
    return result
def _attribute_root_and_members(self, node: ast.AST) -> Optional[Tuple[str, List[str]]]:
    """
    Split an attribute chain such as ``a.b.c`` into its root name and member path.

    :param node: Expression to decompose.
    :return: ``(root_name, [member, ...])`` with members in source order (an empty list
             for a bare name), or ``None`` if the chain is not rooted at a plain name.
    """
    trail: List[str] = []
    cursor = node
    # Attribute nodes nest outermost-first, so members are collected right to left.
    while isinstance(cursor, ast.Attribute):
        trail.append(cursor.attr)
        cursor = cursor.value
    if not isinstance(cursor, ast.Name):
        return None
    return cursor.id, trail[::-1]
_make_view_descriptor(self, + descriptor: data.Data, + shape: Optional[Sequence[Any]] = None, + new_axes: Optional[Sequence[int]] = None) -> data.Data: + view_desc = data.View.view(descriptor) + if shape is None: + shape = descriptor.shape + shape_list = list(shape) + if new_axes: + for axis in sorted(new_axes): + shape_list.insert(axis, 1) + if hasattr(view_desc, 'set_shape'): + view_desc.set_shape(shape_list) + return view_desc diff --git a/dace/frontend/python/schedule_tree_frontend.py b/dace/frontend/python/schedule_tree_frontend.py new file mode 100644 index 0000000000..9a6e23c8fe --- /dev/null +++ b/dace/frontend/python/schedule_tree_frontend.py @@ -0,0 +1,2766 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +"""Python frontend entry point for building schedule trees directly from AST.""" + +import ast +import builtins as pybuiltins +import copy +import inspect +import numbers +import re +import warnings +from typing import Any, Dict, List, Optional, Sequence, Tuple + +from dace import data, dtypes, symbolic, subsets +from dace.config import Config +from dace.data.pydata import PythonClass, PythonDict, PythonList, PythonTuple +from dace.frontend.common import op_repository as oprepo +from dace.frontend.python.common import DaceSyntaxError +from dace.frontend.python import astutils, memlet_parser, preprocessing +from dace.frontend.python.schedule_tree.array_literal_support import ArrayLiteralContext, ArrayLiteralSupportLibrary +from dace.frontend.python.schedule_tree.dict_support import DictSupportContext, DictSupportLibrary, StaticDictBinding +from dace.frontend.python.schedule_tree.lambda_support import LambdaResolver +from dace.frontend.python.schedule_tree.structure_support import ( + descriptor_from_structure, direct_class_annotation_type, ensure_nested_member_descriptor, nested_direct_class_owner, + python_class_requirement_for_member_assignment, resolve_member_access) +from 
def _normalize_raise_behavior(value: Any) -> str:
    """
    Map a configuration value onto one of the supported raise-statement behaviors.

    The value is stringified, stripped, lower-cased, and has dashes and spaces replaced
    by underscores before it is looked up in ``_SUPPORTED_RAISE_BEHAVIORS``.

    :param value: Raw configuration value.
    :return: The normalized behavior name, or ``'support'`` if it is not a supported one.
    """
    canonical = str(value).strip().lower().translate(str.maketrans('- ', '__'))
    return canonical if canonical in _SUPPORTED_RAISE_BEHAVIORS else 'support'
def _normalize_dtype(dtype: Any) -> Optional[dtypes.typeclass]:
    """
    Coerce ``dtype`` into a DaCe typeclass.

    :param dtype: A typeclass, a data descriptor, a Python type, or anything else that
                  ``dtypes.typeclass`` may accept.
    :return: The typeclass itself, the descriptor's ``dtype``, ``dtypes.string`` for
             ``str``, the converted typeclass, or ``None`` if the generic conversion
             raises ``KeyError``, ``TypeError`` or ``ValueError``.
    """
    if isinstance(dtype, dtypes.typeclass):
        return dtype
    if isinstance(dtype, data.Data):
        return dtype.dtype
    if dtype is str:
        return dtypes.string
    if dtype in (int, float, complex, bool):
        # Conversion of the builtin numeric types happens outside the guarded fallback.
        return dtypes.typeclass(dtype)

    try:
        converted = dtypes.typeclass(dtype)
    except (KeyError, TypeError, ValueError):
        converted = None
    return converted
def _binding_kind_for_descriptor(descriptor: data.Data) -> str:
    """
    Classify a descriptor into the binding kind string used by the frontend.

    :param descriptor: Data descriptor to classify.
    :return: ``'reference'`` for references, ``'callback'`` for scalars with a callback
             dtype, ``'scalar'`` for all other scalars, and ``'container'`` otherwise.
    """
    # References are checked first, before the scalar test.
    if isinstance(descriptor, data.Reference):
        return 'reference'
    if not isinstance(descriptor, data.Scalar):
        return 'container'
    return 'callback' if isinstance(descriptor.dtype, dtypes.callback) else 'scalar'
def _function_signature_from_ast(node: ast.FunctionDef) -> inspect.Signature:
    """
    Build an ``inspect.Signature`` that mirrors the parameter list of a function AST.

    Default expressions are not evaluated. A parameter that has a default in the source
    receives a fresh placeholder ``object()`` as its default, so only the *presence* of a
    default is meaningful in the returned signature.

    :param node: Function definition whose ``args`` are converted.
    :return: Signature with the same parameter names, kinds, and order as ``node``.
    """
    empty = inspect.Parameter.empty  # Public spelling of the "no default" marker.
    arguments = node.args

    # Positional-only and regular parameters share a single, right-aligned list of
    # defaults, so they are indexed as one sequence.
    positional = [(arg, inspect.Parameter.POSITIONAL_ONLY) for arg in arguments.posonlyargs]
    positional += [(arg, inspect.Parameter.POSITIONAL_OR_KEYWORD) for arg in arguments.args]
    first_defaulted = len(positional) - len(arguments.defaults)

    parameters: List[inspect.Parameter] = [
        inspect.Parameter(arg.arg, kind, default=empty if index < first_defaulted else object())
        for index, (arg, kind) in enumerate(positional)
    ]

    if arguments.vararg is not None:
        parameters.append(inspect.Parameter(arguments.vararg.arg, inspect.Parameter.VAR_POSITIONAL))

    # ``kw_defaults`` is aligned with ``kwonlyargs``; ``None`` marks a missing default.
    for arg, default_node in zip(arguments.kwonlyargs, arguments.kw_defaults):
        default = empty if default_node is None else object()
        parameters.append(inspect.Parameter(arg.arg, inspect.Parameter.KEYWORD_ONLY, default=default))

    if arguments.kwarg is not None:
        parameters.append(inspect.Parameter(arguments.kwarg.arg, inspect.Parameter.VAR_KEYWORD))

    return inspect.Signature(parameters)
def build_schedule_tree(name: str,
                        parsed_ast: preprocessing.PreprocessedAST,
                        argtypes: Dict[str, data.Data],
                        *,
                        constants: Optional[Dict[str, Tuple[data.Data, Any]]] = None,
                        callback_mapping: Optional[Dict[str, str]] = None,
                        arg_names: Optional[Sequence[str]] = None,
                        lambda_bindings: Optional[Dict[str, ast.Lambda]] = None,
                        callable_bindings: Optional[Dict[str, Any]] = None,
                        seed_bindings: Optional[Dict[str, _Binding]] = None,
                        external_globals: Optional[Dict[str, Any]] = None,
                        inline_calls: bool = True) -> tn.ScheduleTreeRoot:
    """
    Build a schedule tree directly from a preprocessed Python AST.

    The program AST is first run through ``desugar_schedule_tree_expansions``; the result
    is handed to a ``PythonScheduleTreeBuilder``, and the built tree is post-processed by
    ``promote_dynamic_scope_copies`` and, optionally, ``resolve_function_calls``.

    :param name: Program name.
    :param parsed_ast: Preprocessed program AST and metadata.
    :param argtypes: Mapping from visible argument names to DaCe descriptors.
    :param constants: Forwarded unchanged to the builder.
    :param callback_mapping: Forwarded unchanged to the builder.
    :param arg_names: Forwarded unchanged to the builder.
    :param lambda_bindings: Forwarded unchanged to the builder.
    :param callable_bindings: Forwarded to both the desugaring pass and the builder.
    :param seed_bindings: Forwarded to both the desugaring pass and the builder.
    :param external_globals: Forwarded unchanged to the builder.
    :param inline_calls: If True (default), resolve and inline nested
                         ``@dace.program`` calls after building the tree.
    :return: A schedule tree rooted at a top-level scope.
    """
    expanded_tree = desugar_schedule_tree_expansions(parsed_ast.preprocessed_ast,
                                                     filename=parsed_ast.filename,
                                                     global_vars=parsed_ast.program_globals,
                                                     known_descriptors=argtypes,
                                                     seed_bindings=seed_bindings,
                                                     callable_bindings=callable_bindings)
    # Re-wrap the desugared tree with the metadata of the incoming AST.
    desugared_program = preprocessing.PreprocessedAST(
        parsed_ast.filename,
        parsed_ast.src_line,
        parsed_ast.src,
        expanded_tree,
        parsed_ast.program_globals,
        resolved_arg_annotations=copy.deepcopy(parsed_ast.resolved_arg_annotations))

    builder_options = dict(constants=constants,
                           callback_mapping=callback_mapping,
                           arg_names=arg_names,
                           lambda_bindings=lambda_bindings,
                           callable_bindings=callable_bindings,
                           seed_bindings=seed_bindings,
                           external_globals=external_globals)
    tree_root = PythonScheduleTreeBuilder(name, desugared_program, argtypes, **builder_options).build()

    promote_dynamic_scope_copies(tree_root)
    if inline_calls:
        resolve_function_calls(tree_root)
    return tree_root
ast.Lambda]] = None, + callable_bindings: Optional[Dict[str, Any]] = None, + seed_bindings: Optional[Dict[str, _Binding]] = None, + external_globals: Optional[Dict[str, Any]] = None) -> None: + self.name = name + self.filename = parsed_ast.filename + self.parsed_ast = parsed_ast + self.argtypes = {k: _clone_descriptor(v) for k, v in argtypes.items()} + self.globals = copy.copy(parsed_ast.program_globals) + self.external_globals = copy.copy(parsed_ast.program_globals if external_globals is None else external_globals) + self.root = tn.ScheduleTreeRoot(name=name, + children=[], + containers={}, + symbols={}, + constants=self._clone_constants(constants), + callback_mapping=dict(callback_mapping or {}), + arg_names=list(arg_names or [])) + self.scope_stack: List[tn.ScheduleTreeScope] = [self.root] + self.bindings: Dict[str, _Binding] = {} + self.annotated_descriptors: Dict[str, data.Data] = {} + self.annotated_class_types: Dict[str, type[Any]] = {} + self.explicit_structure_argument_names: set[str] = set() + self.lambda_bindings: Dict[str, ast.Lambda] = { + key: astutils.copy_tree(value) + for key, value in (lambda_bindings or {}).items() + } + self.callable_bindings: Dict[str, Any] = dict(callable_bindings or {}) + self.seed_bindings = {key: _clone_binding(binding) for key, binding in (seed_bindings or {}).items()} + self._declared_global_names: set[str] = set() + self._declared_nonlocal_names: set[str] = set() + self._callback_mutated_global_names: set[str] = set() + self._raise_behavior = _normalize_raise_behavior(Config.get('frontend', 'raise_statements')) + self._emit_external_reassign_nodes = isinstance(parsed_ast.preprocessed_ast, ast.Module) + self._global_lambda_cache: Dict[str, Optional[ast.Lambda]] = {} + self.expression_support = GenericExpressionSupportLibrary() + self.array_literal_support = ArrayLiteralSupportLibrary() + self.dict_support = DictSupportLibrary() + self.numpy_support = NumpySupportLibrary() + self.attribute_rewriter = 
AttributeRewriter(self._evaluation_context) + self.lambda_resolver = LambdaResolver(self.globals, + self.lambda_bindings, + self.callable_bindings, + cache=self._global_lambda_cache) + self.callable_resolver = CallableResolver(callable_bindings=self.callable_bindings, + evaluation_context=self._evaluation_context) + self.callable_specializer = CallableArgumentSpecializer( + lambda_resolver=self.lambda_resolver, + callable_resolver=self.callable_resolver, + bindings=self.bindings, + infer_descriptor=self._infer_plannable_expression_descriptor, + resolve_data_access=self._resolve_data_access, + is_callback_descriptor=self._is_callback_descriptor, + callback_specialization_value=self._callback_specialization_value) + + def _raise_callback_syntax_error(node: ast.AST, message: str) -> None: + raise DaceSyntaxError(self, node, message) + + self.callback_handler = CallbackHandler(bindings=self.bindings, + callback_mutated_global_names=self._callback_mutated_global_names, + callable_resolver=self.callable_resolver, + evaluation_context=self._evaluation_context, + append_node=self._append_node, + register_binding=self._register_binding, + fresh_callback_name=self._fresh_callback_name, + fresh_transient_name=self._fresh_transient_name, + render_callback_code=self._callback_code_text, + collect_scope_declarations=_collect_scope_declarations, + raise_syntax_error=_raise_callback_syntax_error, + binding_kind_for_descriptor=_binding_kind_for_descriptor, + pyobject_scalar_descriptor=_pyobject_scalar_descriptor, + is_pyobject_scalar_descriptor=_is_pyobject_scalar_descriptor, + is_iterator_protocol_call=_is_iterator_protocol_call, + is_iterator_next_call=_is_iterator_next_call) + self._terminate_body_stack: List[bool] = [] + + self._initialize_root_scope() + self._initialize_seed_bindings() + self._initialize_direct_class_annotations() + type_inference_globals = copy.copy(self.external_globals) + type_inference_globals.update(self.globals) + self.inferred_bindings = 
ScheduleTreeTypeInference(type_inference_globals, + self.argtypes, + seed_bindings=self.seed_bindings).infer(self._program_node()) + for name, binding in self.inferred_bindings.items(): + if binding.descriptor is not None: + self.root.containers.setdefault(name, _clone_descriptor(binding.descriptor)) + + def build(self) -> tn.ScheduleTreeRoot: + """Build the schedule tree for the program AST.""" + program = self._program_node() + for stmt in program.body: + self.visit(stmt) + return self.root + + def visit_Assign(self, node: ast.Assign) -> None: + reason = callback_reason(node) + if reason is not None: + self.callback_handler.wrap_node(node, reason) + return + if is_tuple_element_assignment(node) and len(node.targets) == 1: + self._handle_tuple_element_assignment(node.targets[0], node.value) + return + if is_container_initialization(node): + for target in node.targets: + if isinstance(target, ast.Name): + self._handle_container_initialization(target.id, node.value) + else: + self._handle_assignment(target, node.value) + return + for target in node.targets: + self._handle_assignment(target, node.value) + + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: + reason = callback_reason(node) + if reason is not None: + self.callback_handler.wrap_node(node, reason) + return + descriptor = self._evaluate_descriptor(node.annotation) + class_type = self._evaluate_annotation_class_type(node.annotation) + if descriptor is not None and isinstance(node.target, ast.Name): + self.annotated_descriptors[node.target.id] = descriptor + if class_type is not None: + self.annotated_class_types[node.target.id] = class_type + if node.value is None: + existing = self.bindings.get(node.target.id) + if existing is not None and isinstance(existing.descriptor, data.Reference): + return + if not isinstance(descriptor, data.Reference): + self._register_binding(node.target.id, descriptor, kind=_binding_kind_for_descriptor(descriptor)) + return + if node.value is not None: + 
self._handle_assignment(node.target, node.value, annotated_descriptor=descriptor) + + def visit_AugAssign(self, node: ast.AugAssign) -> None: + reason = callback_reason(node) + if reason is not None: + self.callback_handler.wrap_node(node, reason) + return + value = ast.BinOp(left=astutils.copy_tree(node.target), op=node.op, right=astutils.copy_tree(node.value)) + self._handle_assignment(node.target, value) + + def visit_Expr(self, node: ast.Expr) -> None: + reason = callback_reason(node) + if reason is not None: + self.callback_handler.wrap_node(node, reason) + return + if isinstance(node.value, ast.Constant) and isinstance(node.value.value, str): + return + self.callback_handler.reject_mutated_global_uses(node.value) + value = self.lambda_resolver.inline_known_lambda_calls(node.value) + if self.callable_resolver.is_dace_program_call(value): + self._materialize_call_args(value) + self._emit_function_call(value) + return + if self.callable_resolver.is_sdfg_call(value): + self._materialize_call_args(value) + if self._emit_sdfg_call(value): + return + planned_value = self.expression_support.plan_expression(self._expression_planning_context(), + value, + materialize_root=False) + if self._handle_expression(planned_value): + self._apply_method_self_descriptor_side_effect(planned_value) + return + if _requires_fstring_callback(planned_value): + callback_expr = ast.copy_location(ast.Expr(value=astutils.copy_tree(planned_value)), planned_value) + self.callback_handler.wrap_node(callback_expr, 'f-string') + return + self._append_node(tn.StatementNode(code=CodeBlock(self._format_runtime_expression(planned_value)))) + + def visit_Return(self, node: ast.Return) -> None: + if node.value is None: + values: List[str] = [] + else: + self.callback_handler.reject_mutated_global_uses(node.value) + return_value = self.lambda_resolver.inline_known_lambda_calls(node.value) + if self.callable_resolver.is_dace_program_call(return_value): + # Materialize array-valued arguments, emit the 
function call; + # the inlining pass will propagate the callee's return value. + self._materialize_call_args(return_value) + tmp = self._fresh_transient_name('__stree_retval') + self._emit_function_call(return_value, return_targets=[tmp]) + self._append_node(tn.ReturnNode(values=[tmp])) + return + if self.callable_resolver.is_sdfg_call(return_value): + self._materialize_call_args(return_value) + tmp = self._fresh_transient_name('__stree_retval') + self._register_binding(tmp, _pyobject_scalar_descriptor(), kind='scalar') + if self._emit_sdfg_call(return_value, return_targets=[tmp]): + self._append_node(tn.ReturnNode(values=[tmp])) + return + if isinstance(return_value, ast.Tuple): + planned_values = [ + self.expression_support.plan_expression(self._expression_planning_context(), + value, + materialize_root=True) for value in return_value.elts + ] + values = [self._materialize_return_value(v) for v in planned_values] + else: + planned_value = self.expression_support.plan_expression(self._expression_planning_context(), + return_value, + materialize_root=True) + values = [self._materialize_return_value(planned_value)] + self._append_node(tn.ReturnNode(values=values)) + + def _materialize_return_value(self, value: ast.AST) -> str: + """Return the descriptor name backing a return-value expression. + + Non-descriptor expressions are materialized into fresh temporaries + before returning so :class:`ReturnNode` only refers to descriptor names. 
+ """ + if _requires_fstring_callback(value): + materialized = self.callback_handler.materialize_expression(value, + 'f-string', + _string_scalar_descriptor(), + prefix='__stree_retval') + return materialized.id + + if isinstance(value, ast.Name) and self._resolve_data_access(value) is not None: + return value.id + + descriptor = self._infer_plannable_expression_descriptor(value) + if descriptor is None: + descriptor = self._infer_scalar_descriptor(value, None) + if descriptor is not None: + name = self._fresh_transient_name('__stree_retval') + kind = 'scalar' if isinstance(descriptor, data.Scalar) else 'container' + self._register_binding(name, descriptor, kind=kind) + target = ast.Name(id=name, ctx=ast.Store()) + if self._emit_computed_assignment(target, value, descriptor): + return name + + output = self._resolve_output_target(target, value, descriptor) + if output is not None: + _, out_memlet, _ = output + tasklet = tn.FrontendTasklet(name=self._tasklet_name(target), + code=CodeBlock(f'{_unparse(target)} = {_unparse(value)}')) + self._append_node( + tn.TaskletNode(node=tasklet, + in_memlets=self._collect_input_memlets(value), + out_memlets={'out': out_memlet})) + return name + + materialized = self.callback_handler.materialize_expression(value, + 'return expression', + _pyobject_scalar_descriptor(), + prefix='__stree_retval') + return materialized.id + + def visit_Pass(self, node: ast.Pass) -> None: + del node + + def visit_Break(self, node: ast.Break) -> None: + del node + self._append_node(tn.BreakNode()) + + def visit_Continue(self, node: ast.Continue) -> None: + del node + self._append_node(tn.ContinueNode()) + + def visit_If(self, node: ast.If) -> None: + reason = callback_reason(node) + if reason is not None: + self.callback_handler.wrap_node(node, reason) + return + self.callback_handler.reject_mutated_global_uses(node.test) + self._emit_if_chain(node) + + def visit_For(self, node: ast.For) -> None: + reason = callback_reason(node) + if reason is not 
None: + self.callback_handler.wrap_node(node, reason) + return + self.callback_handler.reject_mutated_global_uses(node.iter) + loop_indices = self._parse_for_indices(node.target) + iterator_kind, iterator_ranges = self._parse_for_iterator(node.iter) + + if iterator_kind == 'dace.map': + map_scope = tn.MapScope(node=tn.FrontendMap(params=loop_indices, ranges=iterator_ranges), children=[]) + for index_name in loop_indices: + self._register_symbol(index_name) + self._append_node(map_scope) + self._visit_body(map_scope, node.body) + elif iterator_kind == 'range': + index_name = loop_indices[0] + start, stop, step = iterator_ranges[0] + comparator = '>' if stop.startswith('-') or step.startswith('-') else '<' + loop_scope = tn.LoopScope(loop=tn.FrontendLoop( + loop_condition=CodeBlock(f'{index_name} {comparator} {stop}'), + init_statement=CodeBlock(f'{index_name} = {start}'), + update_statement=CodeBlock(f'{index_name} = {index_name} + {step}'), + loop_variable=index_name), + children=[]) + self._register_symbol(index_name) + self._append_node(loop_scope) + self._visit_body(loop_scope, node.body) + else: + self._append_node(tn.StatementNode(code=CodeBlock(_unparse(node)))) + + if node.orelse: + else_scope = tn.ElseScope(children=[]) + self._append_node(else_scope) + self._visit_body(else_scope, node.orelse) + + def visit_While(self, node: ast.While) -> None: + reason = callback_reason(node) + if reason is not None: + self.callback_handler.wrap_node(node, reason) + return + self.callback_handler.reject_mutated_global_uses(node.test) + loop_scope = tn.LoopScope( + loop=tn.FrontendLoop(loop_condition=CodeBlock(self._format_runtime_expression(node.test))), children=[]) + self._append_node(loop_scope) + self._visit_body(loop_scope, node.body) + if node.orelse: + else_scope = tn.ElseScope(children=[]) + self._append_node(else_scope) + self._visit_body(else_scope, node.orelse) + + def generic_visit(self, node: ast.AST) -> None: + import warnings + warnings.warn( + f'Schedule 
tree frontend: unhandled AST node {type(node).__name__} ' + f'at line {getattr(node, "lineno", "?")} — wrapping as callback', + stacklevel=2) + self.callback_handler.wrap_node(node, f'unhandled {type(node).__name__}') + + # ------------------------------------------------------------------ # + # PythonCallbackNode helpers # + # ------------------------------------------------------------------ # + + def _callback_code_text(self, node: ast.AST) -> str: + """Return parseable source code for a callback-wrapped AST node.""" + if isinstance(node, ast.stmt): + try: + return ast.unparse(_sanitize_ast_for_unparse(astutils.copy_tree(node))) + except Exception: + try: + return astutils.unparse(_sanitize_ast_for_unparse(astutils.copy_tree(node))) + except Exception: + return 'pass' + + try: + return self._format_runtime_expression(node) + except Exception: + try: + return _unparse(node) + except Exception: + try: + return ast.unparse(node) + except Exception: + return 'None' + + # ------------------------------------------------------------------ # + # Category C visitors — always callback # + # ------------------------------------------------------------------ # + + def visit_Try(self, node: ast.Try) -> None: + self.callback_handler.wrap_node(node, 'try/except') + + # Python 3.11+ except* (TryStar) + if hasattr(ast, 'TryStar'): + visit_TryStar = visit_Try + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + raise DaceSyntaxError( + self, node, + 'Nested class definitions are unsupported in @dace.program schedule-tree lowering because they cannot ' + 'be outlined safely from compiled code') + + def visit_Import(self, node: ast.Import) -> None: + self.callback_handler.wrap_node(node, 'import') + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + self.callback_handler.wrap_node(node, 'import') + + def visit_Yield(self, node: ast.AST) -> None: + self.callback_handler.wrap_node(node, 'yield') + + def visit_YieldFrom(self, node: ast.AST) -> None: + 
self.callback_handler.wrap_node(node, 'yield from') + + def visit_Await(self, node: ast.AST) -> None: + self.callback_handler.wrap_node(node, 'await') + + def visit_Match(self, node: ast.AST) -> None: + subject = self._match_subject_expression(node.subject) + try: + lowered = lower_match_to_statements(node, subject) + except UnsupportedMatchPatternError: + self.callback_handler.wrap_node(node, 'match/case') + return + + for stmt in lowered: + self.visit(stmt) + + def visit_With(self, node: ast.With) -> None: + self.callback_handler.wrap_node(node, 'context manager') + + def visit_AsyncWith(self, node: ast.AsyncWith) -> None: + self.callback_handler.wrap_node(node, 'context manager') + + # ------------------------------------------------------------------ # + # Category B visitors — try to lower, fall back to callback # + # ------------------------------------------------------------------ # + + def visit_Global(self, node: ast.Global) -> None: + self._declared_global_names.update(node.names) + for name in node.names: + value = self._resolve_external_scope_value(name) + if value is UNRESOLVED: + self.callback_handler.wrap_node(node, 'global scope') + return + self._bind_external_scope_value(name, value) + + def visit_Nonlocal(self, node: ast.Nonlocal) -> None: + self._declared_nonlocal_names.update(node.names) + for name in node.names: + if name in self.bindings: + continue + + value = self._resolve_external_scope_value(name) + if value is UNRESOLVED: + raise DaceSyntaxError(self, node, f'Could not resolve nonlocal name "{name}" in schedule-tree lowering') + + self._bind_external_scope_value(name, value) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + """Handle nested function definitions. + + Known nested function definitions are lowered as inline call regions. + When the target cannot be modeled safely, keep explicit callback + fallback. 
+ """ + global_names, nonlocal_names = _collect_scope_declarations(node) + inline_function = self._make_nested_function_program(node) + if inline_function is not None: + self.callable_bindings[node.name] = inline_function + self.lambda_bindings.pop(node.name, None) + self._register_binding(node.name, data.Scalar(dtypes.callback(None), transient=True), kind='callback') + return + conflicting_globals = self._enclosing_load_uses_outside(node, global_names) + if conflicting_globals: + conflicts = ', '.join(sorted(conflicting_globals)) + raise DaceSyntaxError( + self, node, + f'Nested callback functions cannot reassign global names that are used in the enclosing program: ' + f'{conflicts}') + if nonlocal_names: + raise DaceSyntaxError( + self, node, 'Nested functions that use nonlocal declarations cannot fall back to callbacks during ' + 'schedule-tree lowering') + self.callback_handler.wrap_node(node, 'nested function') + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + global_names, nonlocal_names = _collect_scope_declarations(node) + conflicting_globals = self._enclosing_load_uses_outside(node, global_names) + if conflicting_globals: + conflicts = ', '.join(sorted(conflicting_globals)) + raise DaceSyntaxError( + self, node, + f'Nested callback functions cannot reassign global names that are used in the enclosing program: ' + f'{conflicts}') + if nonlocal_names: + raise DaceSyntaxError( + self, node, 'Nested functions that use nonlocal declarations cannot fall back to callbacks during ' + 'schedule-tree lowering') + self.callback_handler.wrap_node(node, 'async function') + + def visit_Delete(self, node: ast.Delete) -> None: + # del of DaCe arrays is a no-op (runtime manages memory) + for target in node.targets: + if isinstance(target, ast.Name) and target.id in self.bindings: + continue # No-op for known containers + else: + self.callback_handler.wrap_node(node, 'delete') + return + + def visit_Raise(self, node: ast.Raise) -> None: + if 
node.cause is not None: + raise DaceSyntaxError( + self, node, + 'raise from is unsupported in @dace.program schedule-tree lowering because exceptional control flow ' + 'cannot be represented safely') + + if self._raise_behavior == 'ignore_all': + return + + raise_node = self._build_direct_raise_node(node) + if raise_node is not None: + self._append_node(raise_node) + self._terminate_current_body() + return + + if self._raise_behavior == 'ignore_dynamic': + return + + self.callback_handler.wrap_node(node, 'raise') + self._terminate_current_body() + + def _append_node(self, node: tn.ScheduleTreeNode) -> None: + scope = self.scope_stack[-1] + node.parent = scope + scope.children.append(node) + + def _visit_body(self, scope: tn.ScheduleTreeScope, body: Sequence[ast.AST]) -> None: + self.scope_stack.append(scope) + self._terminate_body_stack.append(False) + try: + for stmt in body: + self.visit(stmt) + if self._terminate_body_stack[-1]: + break + finally: + self._terminate_body_stack.pop() + self.scope_stack.pop() + + def _program_node(self) -> ast.FunctionDef: + program_ast = self.parsed_ast.preprocessed_ast + if isinstance(program_ast, ast.Module): + node = program_ast.body[0] + else: + node = program_ast + if not isinstance(node, ast.FunctionDef): + raise TypeError('Expected a preprocessed FunctionDef as schedule-tree frontend input') + return node + + def _enclosing_load_uses_outside(self, excluded_node: ast.AST, names: set[str]) -> set[str]: + if not names or not self.parsed_ast.src: + return set() + + source_ast = ast.parse(astutils._remove_outer_indentation(self.parsed_ast.src)) + ast.increment_lineno(source_ast, self.parsed_ast.src_line) + + source_program = source_ast.body[0] if source_ast.body else None + if not isinstance(source_program, ast.FunctionDef): + return set() + + excluded_end = (getattr(excluded_node, 'end_lineno', None) + or excluded_node.lineno, getattr(excluded_node, 'end_col_offset', None) or 0) + + class _LoadUseFinder(ast.NodeVisitor): + + 
def __init__(self, candidates: set[str], excluded_end_location: Tuple[int, int]) -> None: + self.candidates = candidates + self.excluded_end_location = excluded_end_location + self.used: set[str] = set() + + def visit_Name(self, name_node: ast.Name) -> None: + location = (getattr(name_node, 'lineno', 0), getattr(name_node, 'col_offset', 0)) + if (isinstance(name_node.ctx, ast.Load) and name_node.id in self.candidates + and location > self.excluded_end_location): + self.used.add(name_node.id) + + finder = _LoadUseFinder(names, excluded_end) + finder.visit(source_program) + return finder.used + + def _terminate_current_body(self) -> None: + if self._terminate_body_stack: + self._terminate_body_stack[-1] = True + + def _is_direct_exception_type(self, node: ast.AST) -> bool: + value = try_resolve_static_value(node, self._evaluation_context()) + if value is UNRESOLVED: + builtins_env = { + **pybuiltins.__dict__, + 'builtins': pybuiltins, + '__builtins__': pybuiltins.__dict__, + } + value = try_resolve_static_value(node, builtins_env) + if value is UNRESOLVED: + try: + value = astutils.evalnode(node, builtins_env) + except Exception: + return False + if not inspect.isclass(value): + return False + try: + return issubclass(value, BaseException) + except TypeError: + return False + + def _build_direct_raise_node(self, node: ast.Raise) -> Optional[tn.RaiseNode]: + if node.exc is None: + return None + + exc_type = node.exc + args: List[ast.AST] = [] + kwargs: Dict[str, ast.AST] = {} + + if isinstance(node.exc, ast.Call): + if any(keyword.arg is None for keyword in node.exc.keywords): + return None + exc_type = node.exc.func + args = list(node.exc.args) + kwargs = {keyword.arg: keyword.value for keyword in node.exc.keywords if keyword.arg is not None} + + if not self._is_direct_exception_type(exc_type): + return None + + return tn.RaiseNode(exception_type=CodeBlock(self._format_runtime_expression(exc_type)), + args=[CodeBlock(self._format_runtime_expression(argument)) for 
argument in args], + kwargs={ + name: CodeBlock(self._format_runtime_expression(value)) + for name, value in kwargs.items() + }) + + def _match_subject_expression(self, subject: ast.AST) -> ast.AST: + planned = self.expression_support.plan_expression(self._expression_planning_context(), + subject, + materialize_root=False) + if isinstance(planned, (ast.Name, ast.Constant)): + return planned + + descriptor = self._infer_compute_descriptor(planned) + if descriptor is None: + descriptor = self._infer_scalar_descriptor(planned, None) + if descriptor is None: + descriptor = _pyobject_scalar_descriptor() + return self._materialize_temporary_expression(planned, descriptor) + + def _initialize_root_scope(self) -> None: + for name, descriptor in self.argtypes.items(): + self.root.containers[name] = _clone_descriptor(descriptor) + kind = 'callback' if self._is_callback_descriptor(descriptor) else 'container' + self.bindings[name] = _Binding(descriptor=_clone_descriptor(descriptor), kind=kind) + self.globals[name] = descriptor + for free_symbol in descriptor.free_symbols: + self.root.symbols[free_symbol.name] = free_symbol + + for name, value in self.globals.items(): + if isinstance(value, symbolic.symbol): + self.root.symbols[name] = value + + def _initialize_seed_bindings(self) -> None: + for name, binding in self.seed_bindings.items(): + self.bindings[name] = _clone_binding(binding) + if binding.descriptor is not None: + self.root.containers[name] = _clone_descriptor(binding.descriptor) + self.globals.setdefault(name, _clone_descriptor(binding.descriptor)) + + def _external_scope_kind(self, name: str) -> Optional[str]: + if name in self._declared_nonlocal_names: + return 'nonlocal' + if name in self._declared_global_names: + return 'global' + return None + + def _should_emit_external_reassign(self, name: str) -> bool: + return self._emit_external_reassign_nodes and self._external_scope_kind(name) is not None + + def _handle_assignment(self, + target: ast.AST, + value: 
ast.AST, + annotated_descriptor: Optional[data.Data] = None) -> None: + self.callback_handler.reject_mutated_global_uses(value) + if isinstance(target, ast.Name): + self._update_callable_binding(target.id, value) + self.lambda_resolver.update_binding(target.id, value) + + value = self.lambda_resolver.inline_known_lambda_calls(value) + + # Intercept nested @dace.program calls — materialize array-valued + # arguments into temporaries first, then emit FunctionCallScope. + if self.callable_resolver.is_dace_program_call(value): + self._materialize_call_args(value) + targets = [target.id] if isinstance(target, ast.Name) else [_unparse(target)] + self._emit_function_call(value, return_targets=targets) + return + + if self.callable_resolver.is_sdfg_call(value) and not isinstance(target, (ast.Tuple, ast.List)): + self._materialize_call_args(value) + if self._emit_sdfg_call_assignment(target, value, annotated_descriptor): + return + + if isinstance(target, (ast.Tuple, ast.List)): + self._seed_inferred_target_bindings(target) + if self._emit_computed_assignment(target, value, annotated_descriptor): + return + self._append_node(tn.StatementNode(code=CodeBlock(self._format_assignment_statement(target, value)))) + return + + value = self.expression_support.plan_expression(self._expression_planning_context(), + value, + materialize_root=False) + if not self._ensure_pythonclass_for_direct_class_annotation(target): + self._raise_or_warn_if_member_assignment_requires_python_class(target) + self._update_dict_subscript_binding(target, value) + + source_access = self._resolve_data_access(value) + self._ensure_pythonclass_member_target(target, value, source_access) + if isinstance(target, ast.Name): + self._handle_name_assignment(target.id, value, source_access, annotated_descriptor) + return + + target_access = self._resolve_data_access(target) + if source_access is not None and target_access is not None: + _, source_memlet, source_desc, _ = source_access + target_name, target_memlet, 
target_desc, _ = target_access + memlet = copy.deepcopy(source_memlet) + memlet.other_subset = copy.deepcopy(target_memlet.subset) + if isinstance(target_desc, data.Reference): + self._append_node( + tn.RefSetNode(target=target_name, + memlet=memlet, + src_desc=_clone_descriptor(source_desc), + ref_desc=_clone_descriptor(target_desc))) + return + self._append_node(tn.CopyNode(target=target_name, memlet=memlet)) + return + + if self._emit_computed_assignment(target, value, annotated_descriptor): + return + + self._append_node(tn.StatementNode(code=CodeBlock(self._format_assignment_statement(target, value)))) + + def _raise_or_warn_if_member_assignment_requires_python_class(self, target: ast.AST) -> None: + if not isinstance(target, ast.Attribute): + return + + owner_access = self._resolve_data_access(target.value) + if owner_access is None: + return + + _, _, owner_descriptor, _ = owner_access + message = python_class_requirement_for_member_assignment(owner_descriptor, target.attr) + if message is None: + return + + root_name = self._attribute_root_name(target.value) + if root_name in self.explicit_structure_argument_names: + raise DaceSyntaxError(self, target, message) + + warnings.warn(message, UserWarning, stacklevel=2) + + def _ensure_pythonclass_for_direct_class_annotation(self, target: ast.AST) -> bool: + if not isinstance(target, ast.Attribute): + return False + + owner_access = self._resolve_data_access(target.value) + if owner_access is None: + return False + + _, _, owner_descriptor, _ = owner_access + if python_class_requirement_for_member_assignment(owner_descriptor, target.attr) is None: + return False + + root = self._attribute_root_and_members(target.value) + if root is None: + return False + root_name, member_names = root + class_type = self.annotated_class_types.get(root_name) + if class_type is None: + return False + if nested_direct_class_owner(class_type, member_names) is None: + return False + + binding = self.bindings.get(root_name) + if 
binding is None or binding.descriptor is None or isinstance(binding.descriptor, PythonClass): + return False + + try: + python_class_descriptor = PythonClass.from_class(class_type) + except (TypeError, ValueError): + return False + + self._store_binding(root_name, python_class_descriptor, kind=binding.kind) + return True + + def _attribute_root_name(self, node: ast.AST) -> Optional[str]: + root = self._attribute_root_and_members(node) + return root[0] if root is not None else None + + def _attribute_root_and_members(self, node: ast.AST) -> Optional[Tuple[str, List[str]]]: + current = node + members: List[str] = [] + while isinstance(current, ast.Attribute): + members.append(current.attr) + current = current.value + if isinstance(current, ast.Name): + members.reverse() + return current.id, members + return None + + def _update_dict_subscript_binding(self, target: ast.AST, value: ast.AST) -> None: + if not isinstance(target, ast.Subscript) or not isinstance(target.value, ast.Name): + return + binding = self.bindings.get(target.value.id) + if binding is None or binding.descriptor is None: + return + dict_binding = binding.structure if isinstance(binding.structure, StaticDictBinding) else None + updated = self.dict_support.infer_assignment_binding(self._dict_support_context(target.value.id), + binding.descriptor, dict_binding, target.slice, value) + if updated is None: + return + updated_descriptor, updated_binding = updated + self._store_binding(target.value.id, + updated_descriptor, + kind=binding.kind, + structure=updated_binding if updated_binding is not None else None) + + def _handle_name_assignment(self, name: str, value: ast.AST, source_access: Optional[Tuple[str, Memlet, data.Data, + Optional[data.Data]]], + annotated_descriptor: Optional[data.Data]) -> None: + if self._is_internal_iterator_binding_name(name) or self._is_internal_iterator_helper_call(value): + self._infer_internal_iterator_binding(name, value, annotated_descriptor) + 
self._append_node(tn.AssignNode(name=name, value=CodeBlock(self._format_runtime_expression(value)))) + return + + if self._should_emit_external_reassign(name): + self._handle_external_name_reassignment(name, value, source_access, annotated_descriptor) + return + + if _requires_fstring_callback(value): + self.callback_handler.emit_assignment(name, value, 'f-string', _string_scalar_descriptor()) + return + + existing = self.bindings.get(name) + target_descriptor = annotated_descriptor or self.annotated_descriptors.get(name) + + if source_access is not None and isinstance(value, ast.Subscript) and _is_singleton_scalar_memlet( + source_access[1]): + if target_descriptor is None and existing is None: + self._register_binding(name, data.Scalar(source_access[2].dtype, transient=True), kind='scalar') + existing = self.bindings.get(name) + source_access = None + + if source_access is not None: + source_name, memlet, source_desc, view_desc = source_access + + if target_descriptor is not None and isinstance(target_descriptor, data.Reference): + ref_desc = self._ensure_reference_binding(name, target_descriptor) + self._append_node(tn.RefSetNode(target=name, memlet=memlet, src_desc=source_desc, ref_desc=ref_desc)) + return + + if existing is not None and isinstance(existing.descriptor, data.Reference): + ref_desc = self._ensure_reference_binding(name, existing.descriptor) + self._append_node(tn.RefSetNode(target=name, memlet=memlet, src_desc=source_desc, ref_desc=ref_desc)) + return + + if existing is None and self._should_bind_as_reference(value, source_desc): + ref_desc = self._ensure_reference_binding(name, source_desc) + self._append_node(tn.RefSetNode(target=name, memlet=memlet, src_desc=source_desc, ref_desc=ref_desc)) + return + + if existing is None and self._is_aliasable_descriptor(source_desc): + new_view_desc = view_desc or self._make_view_descriptor(source_desc) + self._register_binding(name, new_view_desc, kind='view') + self._append_node( + 
tn.ViewNode(target=name, + source=source_name, + memlet=memlet, + src_desc=source_desc, + view_desc=new_view_desc)) + return + + if existing is not None and existing.descriptor is not None and self._can_promote_to_reference( + existing.descriptor, source_desc): + ref_desc = self._ensure_reference_binding(name, existing.descriptor) + self._append_node(tn.RefSetNode(target=name, memlet=memlet, src_desc=source_desc, ref_desc=ref_desc)) + return + + inferred_descriptor = self._infer_descriptor(value, name) + if inferred_descriptor is not None: + if existing is None and self._should_bind_expression_as_reference(value, inferred_descriptor): + ref_desc = self._ensure_reference_binding(name, inferred_descriptor) + self._append_node( + tn.RefSetNode(target=name, + memlet=None, + src_desc=inferred_descriptor, + ref_desc=ref_desc, + source_expr=self._format_runtime_expression(value))) + return + kind = 'reference' if isinstance(inferred_descriptor, data.Reference) else 'container' + if self._is_callback_descriptor(inferred_descriptor): + kind = 'callback' + structure = self._runtime_container_structure(name, value, inferred_descriptor) + self._store_binding(name, inferred_descriptor, kind=kind, structure=structure) + else: + scalar_descriptor = self._infer_scalar_descriptor(value, annotated_descriptor) + if scalar_descriptor is not None: + kind = 'callback' if self._is_callback_descriptor(scalar_descriptor) else 'scalar' + self._register_binding(name, scalar_descriptor, kind=kind) + + scalar_descriptor = self._infer_scalar_descriptor(value, annotated_descriptor) + if (isinstance(value, ast.Call) and _is_pyobject_scalar_descriptor(scalar_descriptor) + and self.callback_handler.should_emit_pyobject_call_callback(value)): + self.callback_handler.emit_assignment(name, value, 'pyobject call', scalar_descriptor) + return + + if self.callable_specializer.is_callback_expression(value): + self._append_node(tn.AssignNode(name=name, 
value=CodeBlock(self._format_runtime_expression(value)))) + return + + if self._should_emit_runtime_container_assignment(name, value): + self._append_node(tn.StatementNode(code=CodeBlock(f'{name} = {self._format_runtime_expression(value)}'))) + return + + if self._emit_computed_assignment(ast.Name(id=name, ctx=ast.Store()), value, annotated_descriptor): + return + + self._append_node(tn.AssignNode(name=name, value=CodeBlock(self._format_runtime_expression(value)))) + + def _runtime_container_structure(self, name: str, value: ast.AST, inferred_descriptor: data.Data) -> Optional[Any]: + if isinstance(inferred_descriptor, PythonDict) and isinstance(value, ast.Dict): + return self.dict_support.infer_literal_binding(self._dict_support_context(name), value) + return None + + def _should_emit_runtime_container_assignment(self, name: str, value: ast.AST) -> bool: + binding_descriptor = self.bindings.get(name).descriptor if name in self.bindings else None + return self._is_runtime_container_descriptor(binding_descriptor) or self._is_runtime_container_literal(value) + + def _is_runtime_container_descriptor(self, descriptor: Optional[data.Data]) -> bool: + return isinstance(descriptor, (PythonDict, PythonList, PythonTuple)) + + def _is_runtime_container_literal(self, value: ast.AST) -> bool: + return isinstance(value, (ast.Dict, ast.List, ast.Tuple)) + + def _handle_external_name_reassignment(self, name: str, value: ast.AST, + source_access: Optional[Tuple[str, Memlet, data.Data, Optional[data.Data]]], + annotated_descriptor: Optional[data.Data]) -> None: + existing = self.bindings.get(name) + target_descriptor = annotated_descriptor or self.annotated_descriptors.get(name) + + if source_access is not None and isinstance(value, ast.Subscript) and _is_singleton_scalar_memlet( + source_access[1]): + if target_descriptor is None and existing is None: + self._register_binding(name, data.Scalar(source_access[2].dtype, transient=True), kind='scalar') + existing = 
self.bindings.get(name) + source_access = None + + if source_access is not None: + _, _, source_desc, view_desc = source_access + + if target_descriptor is not None and isinstance(target_descriptor, data.Reference): + self._ensure_reference_binding(name, target_descriptor) + elif existing is not None and isinstance(existing.descriptor, data.Reference): + self._ensure_reference_binding(name, existing.descriptor) + elif existing is None and self._should_bind_as_reference(value, source_desc): + self._ensure_reference_binding(name, source_desc) + elif existing is None and self._is_aliasable_descriptor(source_desc): + self._register_binding(name, view_desc or self._make_view_descriptor(source_desc), kind='view') + elif existing is not None and existing.descriptor is not None and self._can_promote_to_reference( + existing.descriptor, source_desc): + self._ensure_reference_binding(name, existing.descriptor) + + inferred_descriptor = self._infer_descriptor(value, name) + if inferred_descriptor is not None: + self._register_binding(name, inferred_descriptor, kind=_binding_kind_for_descriptor(inferred_descriptor)) + else: + scalar_descriptor = self._infer_scalar_descriptor(value, annotated_descriptor) + if scalar_descriptor is not None: + self._register_binding(name, scalar_descriptor, kind=_binding_kind_for_descriptor(scalar_descriptor)) + + scope_kind = self._external_scope_kind(name) + if scope_kind is None: + raise DaceSyntaxError(self, self._program_node(), f'Could not determine external scope kind for "{name}"') + self._append_node( + tn.ReassignExternalNode(name=name, + value=CodeBlock(self._format_runtime_expression(value)), + scope=scope_kind)) + + def _handle_expression(self, value: ast.AST) -> bool: + if not isinstance(value, ast.Call): + return False + + library_info = self._library_info_for_call(value) + if library_info is None: + return False + + in_memlets = self._collect_input_memlets(value) + if not in_memlets: + return False + + library_name, 
def _emit_sdfg_call_assignment(self, target: ast.AST, value: ast.Call,
                               annotated_descriptor: Optional[data.Data]) -> bool:
    """Emit an SDFG call whose result is assigned to ``target``.

    For a plain name target, a binding is registered first. The descriptor is
    chosen in this order: the annotation, the descriptor of an existing binding
    for that name, and finally a pyobject scalar descriptor. Any other target is
    passed to ``_emit_sdfg_call`` as its unparsed source text.

    :return: Whatever ``_emit_sdfg_call`` returns.
    """
    if not isinstance(target, ast.Name):
        return self._emit_sdfg_call(value, return_targets=[_unparse(target)])

    result_name = target.id
    descriptor = annotated_descriptor
    if descriptor is None:
        known_binding = self.bindings.get(result_name)
        if known_binding is not None:
            descriptor = known_binding.descriptor
    if descriptor is None:
        # Nothing known about the target: fall back to an opaque Python object.
        descriptor = _pyobject_scalar_descriptor()

    self._register_binding(result_name, descriptor, kind=_binding_kind_for_descriptor(descriptor))
    return self._emit_sdfg_call(value, return_targets=[result_name])
tn.LibraryCall(node=tn.FrontendLibrary(name=library_name, properties=library_properties), + in_memlets=in_memlets, + out_memlets={'out': out_memlet})) + return True + + if output is not None and isinstance(value, ast.Attribute): + library_info = self._library_info_for_attribute(value) + if library_info is not None: + library_name, library_properties = library_info + in_memlets = self._collect_input_memlets(value) + if not in_memlets: + return False + self._append_node( + tn.LibraryCall(node=tn.FrontendLibrary(name=library_name, properties=library_properties), + in_memlets=in_memlets, + out_memlets={'out': out_memlet})) + return True + + lowered = self.expression_support.lower_assignment(self._expression_planning_context(), target, value, + annotated_descriptor) + if lowered is not None: + self._append_node(lowered) + return True + + if output is None: + return False + + _, out_memlet, _ = output + + if isinstance(value, ast.Call) and self._should_lower_as_library_call(value): + in_memlets = self._collect_input_memlets(value) + if not in_memlets: + return False + self._append_node( + tn.LibraryCall(node=tn.FrontendLibrary(name=astutils.rname(value.func), + properties=self._library_properties(value)), + in_memlets=in_memlets, + out_memlets={'out': out_memlet})) + return True + + in_memlets = self._collect_input_memlets(value) + if not in_memlets and not isinstance(target, ast.Attribute): + return False + + tasklet_code = f'out = {_unparse(value)}' if isinstance( + target, ast.Attribute) else f'{_unparse(target)} = {_unparse(value)}' + tasklet = tn.FrontendTasklet(name=self._tasklet_name(target), code=CodeBlock(tasklet_code)) + self._append_node(tn.TaskletNode(node=tasklet, in_memlets=in_memlets, out_memlets={'out': out_memlet})) + return True + + def _emit_structured_library_call_assignment(self, target: ast.AST, value: ast.Call) -> bool: + library_info = self._library_info_for_call(value) + if library_info is None or not isinstance(target, (ast.Tuple, ast.List)): + 
return False + + target_structure, _ = self._structure_from_ast(target) + if not isinstance(target_structure, (tuple, list)) or len(target_structure) != len(target.elts): + return False + + out_memlets: Dict[str, Memlet] = {} + for index, (element, element_structure) in enumerate(zip(target.elts, target_structure)): + descriptor = descriptor_from_structure(element_structure) + if descriptor is None: + return False + output = self._resolve_output_target(element, value, descriptor) + if output is None: + return False + _, output_memlet, _ = output + out_memlets[f'out{index}'] = output_memlet + + library_name, library_properties = library_info + self._append_node( + tn.LibraryCall(node=tn.FrontendLibrary(name=library_name, properties=library_properties), + in_memlets=self._collect_input_memlets(value), + out_memlets=out_memlets)) + return True + + def _ensure_pythonclass_member_target( + self, target: ast.AST, value: ast.AST, source_access: Optional[Tuple[str, Memlet, data.Data, + Optional[data.Data]]]) -> None: + if not isinstance(target, ast.Attribute): + return + + root = self._attribute_root_and_members(target) + if root is None: + return + root_name, member_names = root + + binding = self.bindings.get(root_name) + if binding is None or binding.descriptor is None or not isinstance(binding.descriptor, PythonClass): + return + + member_descriptor = self._infer_pythonclass_member_descriptor(value, source_access) + if member_descriptor is None: + return + + for descriptor in (binding.descriptor, self.root.containers.get(root_name), self.globals.get(root_name)): + if isinstance(descriptor, data.Data): + ensure_nested_member_descriptor(descriptor, member_names, member_descriptor) + + def _infer_pythonclass_member_descriptor( + self, value: ast.AST, source_access: Optional[Tuple[str, Memlet, data.Data, + Optional[data.Data]]]) -> Optional[data.Data]: + if source_access is not None: + _, memlet, source_desc, view_desc = source_access + if 
def _format_runtime_expression(self, node: ast.AST) -> str:
    """Return source text for ``node`` after passing it through the attribute rewriter."""
    rewritten = self.attribute_rewriter.rewrite_expression(node)
    return _unparse(rewritten)


def _format_assignment_statement(self, target: ast.AST, value: ast.AST) -> str:
    """Return source text for the assignment ``target = value``.

    If the attribute rewriter returns a replacement node for the assignment,
    that node is unparsed and returned. Otherwise the statement is assembled
    from the unparsed target and the rewritten value expression.
    """
    replacement = self.attribute_rewriter.rewrite_assignment(target, value)
    if replacement is None:
        return f'{_unparse(target)} = {self._format_runtime_expression(value)}'
    return _unparse(replacement)
member_access is None or isinstance(member_access.descriptor, PythonDict): + return None + descriptor = member_access.descriptor + data_name = member_access.data_name + return (data_name, Memlet.from_array(data_name, descriptor), descriptor, _clone_descriptor(descriptor)) + + if isinstance(node, ast.Subscript): + base_access = self._resolve_data_access(node.value) + if base_access is None: + return None + base_name, _, descriptor, _ = base_access + descriptor = _clone_descriptor(descriptor) + if isinstance(descriptor, PythonDict): + return None + try: + subset, new_axes, arrdims = memlet_parser.parse_memlet_subset(descriptor, node, + self._evaluation_context()) + except Exception: + return None + if arrdims: + return None + memlet = Memlet(data=base_name, subset=subset) + return (base_name, memlet, descriptor, self._make_view_descriptor(descriptor, subset.size(), new_axes)) + + return None + + def _resolve_output_target(self, target: ast.AST, value: ast.AST, + annotated_descriptor: Optional[data.Data]) -> Optional[Tuple[str, Memlet, data.Data]]: + if isinstance(target, ast.Name): + if annotated_descriptor is not None and isinstance(annotated_descriptor, data.Reference): + return None + + existing = self.bindings.get(target.id) + if existing is not None and existing.descriptor is not None and not isinstance( + existing.descriptor, data.Reference): + descriptor = _clone_descriptor(existing.descriptor) + return (target.id, Memlet.from_array(target.id, descriptor), descriptor) + + if annotated_descriptor is not None: + kind = 'scalar' if isinstance(annotated_descriptor, data.Scalar) else 'container' + self._register_binding(target.id, annotated_descriptor, kind=kind) + descriptor = _clone_descriptor(self.bindings[target.id].descriptor) + return (target.id, Memlet.from_array(target.id, descriptor), descriptor) + + inferred_descriptor = self._infer_compute_descriptor(value) + if inferred_descriptor is None: + return None + kind = 'scalar' if 
isinstance(inferred_descriptor, data.Scalar) else 'container' + self._register_binding(target.id, inferred_descriptor, kind=kind) + descriptor = _clone_descriptor(self.bindings[target.id].descriptor) + return (target.id, Memlet.from_array(target.id, descriptor), descriptor) + + target_access = self._resolve_data_access(target) + if target_access is None: + return None + target_name, target_memlet, descriptor, _ = target_access + return (target_name, target_memlet, descriptor) + + def _infer_compute_descriptor(self, node: ast.AST) -> Optional[data.Data]: + inferred_descriptor = self._infer_plannable_expression_descriptor(node) + if inferred_descriptor is not None: + return inferred_descriptor + + access = self._resolve_data_access(node) + if access is not None: + _, _, descriptor, view_descriptor = access + result = _clone_descriptor(view_descriptor or descriptor) + result.transient = True + return result + + for _, _, descriptor, _ in self._collect_expression_accesses(node): + result = _clone_descriptor(descriptor) + result.transient = True + return result + return None + + def _collect_expression_accesses(self, node: ast.AST) -> List[Tuple[str, Memlet, data.Data, Optional[data.Data]]]: + accesses: List[Tuple[str, Memlet, data.Data, Optional[data.Data]]] = [] + seen = set() + + def _visit(current: ast.AST) -> None: + access = self._resolve_data_access(current) + if access is not None: + name, memlet, descriptor, view_descriptor = access + key = (name, str(memlet.subset), str(memlet.other_subset) if memlet.other_subset is not None else '') + if key not in seen: + seen.add(key) + accesses.append((name, memlet, descriptor, view_descriptor)) + return + + for child in ast.iter_child_nodes(current): + _visit(child) + + _visit(node) + return accesses + + def _collect_input_memlets(self, node: ast.AST) -> Dict[str, Memlet]: + result: Dict[str, Memlet] = {} + for index, (_, memlet, _, _) in enumerate(self._collect_expression_accesses(node)): + result[f'in{index}'] = memlet 
+ return result + + def _infer_descriptor(self, node: ast.AST, target_name: str) -> Optional[data.Data]: + if isinstance(node, ast.Dict): + return self.dict_support.infer_literal_descriptor(self._dict_support_context(target_name), node) + + if isinstance(node, ast.Call): + inferred = self.array_literal_support.infer_expression_descriptor(self._array_literal_context(), node) + if inferred is not None: + return inferred + + if isinstance(node, (ast.List, ast.Tuple)): + structure, _ = self._structure_from_ast(node) + if structure is not None: + return descriptor_from_structure(structure) + + if isinstance(node, ast.Attribute): + access = self._resolve_data_access(node) + if access is not None: + _, _, descriptor, view_descriptor = access + result = _clone_descriptor(view_descriptor or descriptor) + result.transient = True + return result + + def _infer_operator_operand(operand: ast.AST) -> Optional[Any]: + descriptor = self._infer_descriptor(operand, target_name) + if descriptor is not None: + return descriptor + + value = try_resolve_static_value(operand, self._evaluation_context()) + if value is not UNRESOLVED: + return value + + return self._infer_scalar_descriptor(operand, None) + + if isinstance(node, ast.BinOp): + + left_operand = _infer_operator_operand(node.left) + right_operand = _infer_operator_operand(node.right) + if left_operand is not None and right_operand is not None: + infer_fn = oprepo.Replacements.get_operator_descriptor_inference( + type(node.op).__name__, left_operand, right_operand) + if infer_fn is not None: + try: + inferred = infer_fn(left_operand, right_operand) + except Exception: + inferred = None + if inferred is not None: + return inferred + + if isinstance(node, ast.UnaryOp): + + operand = _infer_operator_operand(node.operand) + if operand is not None: + infer_fn = oprepo.Replacements.get_operator_descriptor_inference(type(node.op).__name__, operand) + if infer_fn is not None: + try: + inferred = infer_fn(operand) + except Exception: + 
inferred = None + if inferred is not None: + return inferred + + if isinstance(node, ast.BoolOp) and len(node.values) > 0: + + infer_fn = oprepo.Replacements.get_operator_descriptor_inference(type(node.op).__name__) + current = _infer_operator_operand(node.values[0]) + if current is not None: + for value in node.values[1:]: + next_operand = _infer_operator_operand(value) + if current is None or next_operand is None: + current = None + break + infer_fn = oprepo.Replacements.get_operator_descriptor_inference( + type(node.op).__name__, current, next_operand) + if infer_fn is None: + current = None + break + try: + current = infer_fn(current, next_operand) + except Exception: + current = None + break + if current is not None: + return current + + if isinstance(node, ast.Compare) and len(node.ops) == 1 and len(node.comparators) == 1: + + left_operand = _infer_operator_operand(node.left) + right_operand = _infer_operator_operand(node.comparators[0]) + if left_operand is not None and right_operand is not None: + infer_fn = oprepo.Replacements.get_operator_descriptor_inference( + type(node.ops[0]).__name__, left_operand, right_operand) + if infer_fn is not None: + try: + inferred = infer_fn(left_operand, right_operand) + except Exception: + inferred = None + if inferred is not None: + return inferred + + if isinstance(node, ast.Subscript): + base_descriptor: Optional[data.Data] = None + base_access = self._resolve_data_access(node.value) + if base_access is not None: + _, _, base_descriptor, _ = base_access + elif isinstance(node.value, ast.Name): + binding = self.bindings.get(node.value.id) + if binding is not None and binding.descriptor is not None: + base_descriptor = binding.descriptor + elif isinstance(node.value, ast.Attribute): + base_descriptor = self._infer_descriptor(node.value, target_name) + + if base_descriptor is not None: + dict_binding = None + if isinstance(node.value, ast.Name): + binding = self.bindings.get(node.value.id) + if binding is not None and 
isinstance(binding.structure, StaticDictBinding): + dict_binding = binding.structure + dict_descriptor = self.dict_support.infer_subscript_descriptor(self._dict_support_context(target_name), + base_descriptor, node.slice, + dict_binding) + if dict_descriptor is not None: + return dict_descriptor + inferred = _infer_static_subscript_descriptor(base_descriptor, node, self._evaluation_context()) + if inferred is not None: + return inferred + + if isinstance(node, ast.Lambda): + return data.Scalar(dtypes.callback(None), transient=True) + + if isinstance(node, ast.Call): + # Try the method descriptor-inference registry first (a.sum(), a.reshape(), etc.) + if isinstance(node.func, ast.Attribute): + inferred = self._try_method_descriptor_inference(node) + if inferred is not None: + return inferred + + inferred = self._try_ufunc_descriptor_inference(node) + if inferred is not None: + return inferred + + # Try the free-function descriptor-inference registry (numpy.sum(), etc.) + inferred = self._try_descriptor_inference(node) + if inferred is not None: + return inferred + + # Attribute inference (a.T, a.flat, a.real, a.imag, etc.) 
def _try_descriptor_inference(self, node: ast.Call) -> Optional[data.Data]:
    """Query the free-function descriptor-inference registry for a call node.

    The registry is looked up under the resolved callable name first, then
    under the textual name of ``node.func`` if that differs. Returns ``None``
    when no inference function is registered, when it raises, or when its
    result cannot be turned into a binding.
    """
    resolved_name = self._resolved_callable_name(node.func)
    infer_fn = oprepo.Replacements.get_descriptor_inference(resolved_name)
    if infer_fn is None:
        spelled_name = astutils.rname(node.func)
        if spelled_name != resolved_name:
            infer_fn = oprepo.Replacements.get_descriptor_inference(spelled_name)
        if infer_fn is None:
            return None

    input_descs, args, kwargs = self._resolve_call_inputs(node)
    try:
        outcome = infer_fn(input_descs, *args, **kwargs)
    except Exception:
        # Inference is best-effort: a failing inference function means "unknown".
        return None

    binding = _binding_from_inference_result(outcome)
    if binding is None:
        return None
    return binding.descriptor
``a`` in ``a.sum()``) + obj_access = self._resolve_data_access(node.func.value) + if obj_access is not None: + _, _, obj_desc, _ = obj_access + else: + obj_desc = try_resolve_static_value(node.func.value, self._evaluation_context()) + if obj_desc is UNRESOLVED: + return None + method_name = node.func.attr + infer_fn = oprepo.Replacements.get_method_descriptor_inference(type(obj_desc), method_name) + if infer_fn is None: + return None + # Resolve the remaining arguments (skip 'self') + _input_descs, args, kwargs = self._resolve_call_inputs(node) + try: + result = infer_fn(obj_desc, *args, **kwargs) + except Exception: + return None + binding = _binding_from_inference_result(result) + return None if binding is None else binding.descriptor + + def _apply_method_self_descriptor_side_effect(self, node: ast.AST) -> None: + + if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Attribute): + return + if not isinstance(node.func.value, ast.Name): + return + + obj_name = node.func.value.id + obj_access = self._resolve_data_access(node.func.value) + if obj_access is None: + return + _, _, obj_desc, _ = obj_access + + infer_fn = oprepo.Replacements.get_method_self_descriptor_inference(type(obj_desc).__name__, node.func.attr) + if infer_fn is None: + return + + _input_descs, args, kwargs = self._resolve_call_inputs(node) + try: + updated_self = infer_fn(obj_desc, *args, **kwargs) + except Exception: + return + if not isinstance(updated_self, data.Data): + return + + kind = self.bindings.get(obj_name).kind if obj_name in self.bindings else _binding_kind_for_descriptor( + updated_self) + self._store_binding(obj_name, updated_self, kind=kind) + + def _try_attribute_descriptor_inference(self, node: ast.Attribute) -> Optional[data.Data]: + """Query the attribute descriptor-inference registry for ``obj.attr`` accesses.""" + + obj_access = self._resolve_data_access(node.value) + if obj_access is None: + return None + _, _, obj_desc, _ = obj_access + classname = 
def _try_ufunc_descriptor_inference(self, node: ast.Call) -> Optional[data.Data]:
    """Query the descriptor-inference registry for a NumPy ufunc call or ufunc method.

    Returns ``None`` when ``node`` is not recognised as a ufunc target, when no
    inference function is registered for the method name, when that function
    raises, or when its result cannot be turned into a binding.
    """
    ufunc_target = _resolve_ufunc_inference_target(node, self._evaluation_context())
    if ufunc_target is None:
        return None
    method_name, ufunc_name = ufunc_target

    infer_fn = oprepo.Replacements.get_ufunc_descriptor_inference(method_name)
    if infer_fn is None:
        return None

    input_descs, args, kwargs = self._resolve_call_inputs(node)
    try:
        outcome = infer_fn(input_descs, ufunc_name, *args, **kwargs)
    except Exception:
        # Inference is best-effort: a failing inference function means "unknown".
        return None

    binding = _binding_from_inference_result(outcome)
    if binding is None:
        return None
    return binding.descriptor
def _infer_scalar_descriptor(self, node: ast.AST, annotated_descriptor: Optional[data.Data]) -> Optional[data.Data]:
    """Infer a scalar descriptor for the expression ``node``.

    Sources are tried in this order, and the first that succeeds wins:

    1. a ``data.Scalar`` annotation (returned as a clone);
    2. f-string nodes, which map to the string scalar descriptor;
    3. expression type inference over the currently bound scalar names;
    4. a statically resolved value that ``data.create_datadescriptor`` turns
       into a ``data.Scalar``;
    5. a statically resolved number, via ``_normalize_dtype`` of its type;
    6. the pyobject scalar fallback, if ``_should_fallback_to_pyobject_scalar``
       allows it.

    Descriptors created in steps 3 to 5 are marked ``transient``. Returns
    ``None`` when nothing applies.
    """
    if isinstance(annotated_descriptor, data.Scalar):
        return _clone_descriptor(annotated_descriptor)

    if isinstance(node, (ast.JoinedStr, ast.FormattedValue)):
        return _string_scalar_descriptor()

    # dtype of every name currently bound to a scalar, used as typing context.
    bound_scalar_dtypes = {}
    for bound_name, bound in self.bindings.items():
        if isinstance(bound.descriptor, data.Scalar):
            bound_scalar_dtypes[bound_name] = bound.descriptor.dtype
    try:
        expression_dtype = infer_expr_type(_unparse(node), bound_scalar_dtypes)
    except Exception:
        expression_dtype = None
    if expression_dtype is not None:
        return data.Scalar(expression_dtype, transient=True)

    static_value = try_resolve_static_value(node, self._evaluation_context())
    if static_value is not UNRESOLVED and static_value is not None:
        try:
            candidate = _clone_descriptor(data.create_datadescriptor(static_value))
        except Exception:
            candidate = None
        if isinstance(candidate, data.Scalar):
            candidate.transient = True
            return candidate

    if isinstance(static_value, (numbers.Number, bool)):
        dtype = _normalize_dtype(type(static_value))
        if dtype is not None:
            return data.Scalar(dtype, transient=True)

    if _should_fallback_to_pyobject_scalar(node, static_value):
        return _pyobject_scalar_descriptor()

    return None
_clone_descriptor(value) + dtype = _normalize_dtype(value) + if dtype is not None: + return data.Scalar(dtype, transient=True) + return None + + def _evaluate_annotation_class_type(self, node: Optional[ast.AST]) -> Optional[type[Any]]: + if node is None: + return None + try: + value = astutils.evalnode(node, self._evaluation_context()) + except Exception: + return None + return direct_class_annotation_type(value) + + def _initialize_direct_class_annotations(self) -> None: + program = self._program_node() + arguments = list(program.args.posonlyargs) + list(program.args.args) + list(program.args.kwonlyargs) + if program.args.vararg is not None: + arguments.append(program.args.vararg) + if program.args.kwarg is not None: + arguments.append(program.args.kwarg) + + resolved_arg_annotations = self.parsed_ast.resolved_arg_annotations or {} + + for argument in arguments: + resolved_annotation = resolved_arg_annotations.get(argument.arg, None) + descriptor = None + class_type = None + if resolved_annotation is not None: + class_type = direct_class_annotation_type(resolved_annotation) + if isinstance(resolved_annotation, data.Data): + descriptor = _clone_descriptor(resolved_annotation) + + if descriptor is None: + descriptor = self._evaluate_descriptor(argument.annotation) + if class_type is None: + class_type = self._evaluate_annotation_class_type(argument.annotation) + + if class_type is not None: + self.annotated_class_types[argument.arg] = class_type + elif isinstance(descriptor, data.Structure): + self.explicit_structure_argument_names.add(argument.arg) + + def _parse_shape(self, node: ast.AST) -> List[Any]: + value = try_resolve_static_value(node, self._evaluation_context()) + if value is UNRESOLVED: + value = None + + if isinstance(value, (list, tuple)): + return [self._shape_dim(dim) for dim in value] + if value is not None: + return [self._shape_dim(value)] + + if isinstance(node, (ast.List, ast.Tuple)): + return 
[self._shape_dim(symbolic.pystr_to_symbolic(_unparse(elem))) for elem in node.elts] + return [self._shape_dim(symbolic.pystr_to_symbolic(_unparse(node)))] + + def _parse_dtype(self, node: Optional[ast.AST]) -> Optional[dtypes.typeclass]: + if node is None: + return None + value = try_resolve_static_value(node, self._evaluation_context()) + if value is UNRESOLVED: + return None + return _normalize_dtype(value) + + def _shape_dim(self, value: Any) -> Any: + if isinstance(value, (int, symbolic.SymExpr, symbolic.symbol, symbolic.sympy.Basic)): + return value + if isinstance(value, str): + return symbolic.pystr_to_symbolic(value) + return value + + def _call_argument(self, node: ast.Call, position: int, keyword: str) -> Optional[ast.AST]: + if len(node.args) > position: + return node.args[position] + for kw in node.keywords: + if kw.arg == keyword: + return kw.value + return None + + def _library_properties(self, node: ast.Call) -> Dict[str, Any]: + return {kw.arg: _unparse(kw.value) for kw in node.keywords if kw.arg is not None} + + def _tasklet_name(self, target: ast.AST) -> str: + if isinstance(target, ast.Name): + return f'{target.id}_tasklet' + if isinstance(target, ast.Subscript): + return f'{_unparse(target.value)}_tasklet' + return 'tasklet' + + def _fresh_transient_name(self, prefix: str = '__stree_tmp') -> str: + index = 0 + candidate = prefix + while candidate in self.bindings or candidate in self.root.containers or candidate in self.globals: + index += 1 + candidate = f'{prefix}{index}' + return candidate + + def _fresh_callback_name(self, prefix: str = '__stree_callback') -> str: + index = 0 + candidate = prefix + while (candidate in self.bindings or candidate in self.root.containers or candidate in self.globals + or candidate in self.callable_bindings): + index += 1 + candidate = f'{prefix}{index}' + return candidate + + def _array_constructor_name(self) -> str: + return 'numpy.array' + + def _materialize_temporary_expression(self, value: ast.AST, 
def _register_binding(self, name: str, descriptor: data.Data, kind: str) -> None:
    """Bind ``name`` to ``descriptor`` with the given binding ``kind``.

    Thin forwarding wrapper around ``_store_binding``; no ``structure``
    argument is passed along.
    """
    self._store_binding(name, descriptor, kind=kind)
self.globals[free_symbol.name] = free_symbol + + def _register_symbol(self, name: str, dtype: dtypes.typeclass = dtypes.int64) -> symbolic.symbol: + existing = self.root.symbols.get(name) + if isinstance(existing, symbolic.symbol): + return existing + + symbol_value = symbolic.symbol(name, dtype) + self.root.symbols[name] = symbol_value + self.globals[name] = symbol_value + return symbol_value + + def _clone_constants(self, constants: Optional[Dict[str, Tuple[data.Data, + Any]]]) -> Dict[str, Tuple[data.Data, Any]]: + if not constants: + return {} + return {name: (_clone_descriptor(descriptor), value) for name, (descriptor, value) in constants.items()} + + def _infer_internal_iterator_binding(self, name: str, value: ast.AST, + annotated_descriptor: Optional[data.Data]) -> None: + inferred_binding = self.inferred_bindings.get(name) + if inferred_binding is not None: + self._store_binding(name, + inferred_binding.descriptor, + kind=inferred_binding.kind, + structure=inferred_binding.structure) + return + + scalar_descriptor = self._infer_scalar_descriptor(value, annotated_descriptor) + if scalar_descriptor is not None: + self._store_binding(name, scalar_descriptor, kind='iterator-index', structure=scalar_descriptor) + + def _handle_container_initialization(self, name: str, value: ast.AST) -> None: + inferred_binding = self.inferred_bindings.get(name) + if inferred_binding is not None: + self._store_binding(name, + inferred_binding.descriptor, + kind=inferred_binding.kind, + structure=inferred_binding.structure) + return + + descriptor = self._infer_descriptor(value, name) + if descriptor is None: + return + kind = 'reference' if isinstance(descriptor, data.Reference) else 'container' + structure, _ = self._structure_from_ast(value) + self._store_binding(name, descriptor, kind=kind, structure=structure) + + def _handle_tuple_element_assignment(self, target: ast.AST, value: ast.AST) -> None: + if not isinstance(target, ast.Name): + self._handle_assignment(target, 
value) + return + + self.callback_handler.reject_mutated_global_uses(value) + value = self.lambda_resolver.inline_known_lambda_calls(value) + existing = self.bindings.get(target.id) + if existing is not None and isinstance(existing.descriptor, data.Reference): + self._handle_assignment(target, value) + return + + source_access = self._resolve_data_access(value) + if source_access is not None: + _, memlet, source_desc, view_desc = source_access + inferred_binding = self.inferred_bindings.get(target.id) + if inferred_binding is not None and inferred_binding.descriptor is not None and not isinstance( + inferred_binding.descriptor, data.Scalar): + self._store_binding(target.id, + _copy_target_descriptor(inferred_binding.descriptor), + kind='container', + structure=inferred_binding.structure) + self._append_node(tn.CopyNode(target=target.id, memlet=copy.deepcopy(memlet))) + return + + if _is_singleton_scalar_memlet(memlet): + inferred_binding = self.inferred_bindings.get(target.id) + if inferred_binding is not None and isinstance(inferred_binding.descriptor, data.Scalar): + self._store_binding(target.id, + inferred_binding.descriptor, + kind='scalar', + structure=inferred_binding.structure) + self._append_node( + tn.AssignNode(name=target.id, value=CodeBlock(self._format_runtime_expression(value)))) + return + if existing is None: + self._register_binding(target.id, data.Scalar(source_desc.dtype, transient=True), kind='scalar') + self._append_node(tn.AssignNode(name=target.id, + value=CodeBlock(self._format_runtime_expression(value)))) + return + + descriptor = _copy_target_descriptor(view_desc or source_desc) + if existing is None: + self._register_binding(target.id, descriptor, kind='container') + self._append_node(tn.CopyNode(target=target.id, memlet=copy.deepcopy(memlet))) + return + + self._handle_assignment(target, value) + + def _seed_inferred_target_bindings(self, target: ast.AST) -> None: + for child in ast.walk(target): + if not isinstance(child, ast.Name) or 
not isinstance(child.ctx, ast.Store): + continue + inferred_binding = self.inferred_bindings.get(child.id) + if inferred_binding is None or inferred_binding.descriptor is None: + continue + self._store_binding(child.id, + inferred_binding.descriptor, + kind=inferred_binding.kind, + structure=inferred_binding.structure) + + def _structure_from_ast(self, node: ast.AST) -> Tuple[Optional[Any], bool]: + if isinstance(node, ast.Name): + binding = self.bindings.get(node.id) + if binding is None: + return (None, False) + structure = binding.structure if binding.structure is not None else binding.descriptor + uses_internal = binding.kind.startswith('iterator') or self._is_internal_iterator_binding_name(node.id) + return (copy.deepcopy(structure), uses_internal) + + if isinstance(node, (ast.Tuple, ast.List)): + elements: List[Any] = [] + uses_internal = False + for element in node.elts: + substructure, sub_internal = self._structure_from_ast(element) + if substructure is None: + return (None, False) + elements.append(substructure) + uses_internal = uses_internal or sub_internal + if isinstance(node, ast.List): + return (elements, uses_internal) + return (tuple(elements), uses_internal) + + return (None, False) + + def _ensure_reference_binding(self, name: str, descriptor: data.Data) -> data.Data: + existing = self.bindings.get(name) + if existing is not None and existing.descriptor is not None and isinstance(existing.descriptor, data.Reference): + return _clone_descriptor(existing.descriptor) + ref_desc = data.Reference.view(descriptor) + self._register_binding(name, ref_desc, kind='reference') + return _clone_descriptor(ref_desc) + + def _make_view_descriptor(self, + descriptor: data.Data, + shape: Optional[Sequence[Any]] = None, + new_axes: Optional[Sequence[int]] = None) -> data.Data: + view_desc = data.View.view(descriptor) + if shape is None: + shape = descriptor.shape + shape_list = list(shape) + if new_axes: + for axis in sorted(new_axes): + shape_list.insert(axis, 
1) + if hasattr(view_desc, 'set_shape'): + view_desc.set_shape(shape_list) + return view_desc + + def _is_aliasable_descriptor(self, descriptor: data.Data) -> bool: + return not isinstance(descriptor, data.Scalar) + + def _should_bind_expression_as_reference(self, value: ast.AST, descriptor: data.Data) -> bool: + if not self._is_aliasable_descriptor(descriptor): + return False + if isinstance(value, ast.Attribute) and self._library_info_for_attribute(value) is not None: + return False + return isinstance(value, ast.Attribute) + + # ------------------------------------------------------------------ # + # Function-call detection for nested @dace.program calls # + # ------------------------------------------------------------------ # + + def _emit_function_call(self, call_node: ast.Call, return_targets: Optional[List[str]] = None) -> None: + """Create a :class:`FunctionCallScope` placeholder and append it.""" + callee = self.callable_resolver.resolve_callable_value(call_node.func) + callee_name = self.callable_resolver.callable_name(callee) + arguments = self.callable_resolver.extract_argument_mapping(call_node, self._format_runtime_expression) + specialization_args, specialization_kwargs, lambda_bindings, callable_bindings = self.callable_specializer.extract_call_specialization( + call_node, _unparse) + + scope = tn.FunctionCallScope( + children=[], + call=tn.FrontendFunctionCall(callee_name=callee_name, arguments=arguments), + ) + # Transient metadata consumed by the inlining pass. 
+ scope._callee_program = callee + scope._call_node = call_node + scope._call_args = specialization_args + scope._call_kwargs = specialization_kwargs + scope._lambda_bindings = lambda_bindings + scope._callable_bindings = callable_bindings + scope._captured_names = set(getattr(callee, 'captured_names', set())) + scope._return_targets = return_targets + self._append_node(scope) + + def _resolve_sdfg_call(self, call_node: ast.Call) -> Optional[Any]: + callee = self.callable_resolver.resolve_callable_value(call_node.func) + from dace import SDFG + + if isinstance(callee, SDFG): + return callee + + if not hasattr(callee, '__sdfg__') or hasattr(callee, '__schedule_tree__'): + return None + + specialization_args, specialization_kwargs, _, _ = self.callable_specializer.extract_call_specialization( + call_node, _unparse) + try: + sdfg = callee.__sdfg__(*specialization_args, **specialization_kwargs) + except Exception: + return None + + return sdfg if isinstance(sdfg, SDFG) else None + + def _emit_sdfg_call(self, call_node: ast.Call, return_targets: Optional[List[str]] = None) -> bool: + sdfg = self._resolve_sdfg_call(call_node) + if sdfg is None: + return False + + callee = self.callable_resolver.resolve_callable_value(call_node.func) + callee_name = self.callable_resolver.callable_name(callee) + arguments = self.callable_resolver.extract_argument_mapping(call_node, self._format_runtime_expression) + self._append_node( + tn.SDFGCallNode(sdfg=sdfg, + call=tn.FrontendFunctionCall(callee_name=callee_name, arguments=arguments), + return_targets=list(return_targets or []))) + return True + + def _materialize_call_args(self, call_node: ast.Call) -> None: + """Materialize array-valued call arguments into temporaries in-place.""" + ctx = self._expression_planning_context() + for i, arg in enumerate(call_node.args): + call_node.args[i] = self.expression_support.plan_expression( + ctx, self.lambda_resolver.inline_known_lambda_calls(arg), materialize_root=True) + for kw in 
call_node.keywords: + kw.value = self.expression_support.plan_expression(ctx, + self.lambda_resolver.inline_known_lambda_calls(kw.value), + materialize_root=True) + + def _should_lower_as_library_call(self, node: ast.Call) -> bool: + return self._library_info_for_call(node) is not None + + def _fresh_symbol(self, prefix: str = '__stree_sym') -> symbolic.symbol: + index = 0 + candidate = prefix + while (candidate in self.bindings or candidate in self.root.containers or candidate in self.globals + or candidate in self.root.symbols): + index += 1 + candidate = f'{prefix}{index}' + symbol_value = symbolic.symbol(candidate, dtypes.int64) + self.root.symbols[candidate] = symbol_value + self.globals[candidate] = symbol_value + return symbol_value + + def _library_info_for_call(self, node: ast.Call) -> Optional[Tuple[str, Dict[str, Any]]]: + + call_name = self._resolved_callable_name(node.func) + if call_name in _INTERNAL_ITERATOR_HELPERS or call_name in {'range', 'prange', 'parrange'}: + return None + + if isinstance(node.func, ast.Attribute): + obj_access = self._resolve_data_access(node.func.value) + if obj_access is not None: + _, _, obj_desc, _ = obj_access + classname = type(obj_desc).__name__ + if (oprepo.Replacements.get_method(classname, node.func.attr) is not None + or oprepo.Replacements.get_method_descriptor_inference(classname, node.func.attr) is not None): + properties = self._library_properties(node) + properties['receiver_class'] = classname + properties['access_kind'] = 'method' + return (node.func.attr, properties) + receiver_value = try_resolve_static_value(node.func.value, self._evaluation_context()) + if receiver_value is not UNRESOLVED: + receiver_class = type(receiver_value) + if (oprepo.Replacements.get_method(receiver_class, node.func.attr) is not None + or oprepo.Replacements.get_method_descriptor_inference(receiver_class, + node.func.attr) is not None): + properties = self._library_properties(node) + properties['receiver_class'] = 
receiver_class.__name__ + properties['access_kind'] = 'method' + properties['receiver'] = astutils.rname(node.func.value) + return (node.func.attr, properties) + + if oprepo.Replacements.get(call_name) is None and oprepo.Replacements.get_descriptor_inference( + call_name) is None: + return None + return (call_name, self._library_properties(node)) + + def _library_info_for_attribute(self, node: ast.Attribute) -> Optional[Tuple[str, Dict[str, Any]]]: + + obj_access = self._resolve_data_access(node.value) + if obj_access is None: + return None + _, _, obj_desc, _ = obj_access + classname = type(obj_desc).__name__ + if (oprepo.Replacements.get_attribute(classname, node.attr) is None + and oprepo.Replacements.get_attribute_descriptor_inference(classname, node.attr) is None): + return None + return (node.attr, {'receiver_class': classname, 'access_kind': 'attribute'}) + + def _resolved_callable_name(self, node: ast.AST) -> str: + textual_name = astutils.rname(node) + callee = self.callable_resolver.resolve_callable_value(node) + if callee is None: + resolved = try_resolve_static_value(node, self._evaluation_context()) + if resolved is not UNRESOLVED: + callee = resolved + if callee is not None: + module_name = getattr(callee, '__module__', None) + callable_name = getattr(callee, '__name__', None) + if module_name and callable_name and module_name != 'builtins': + if module_name.startswith('numpy.'): + module_name = 'numpy' + return f'{module_name}.{callable_name}' + resolved_name = self.callable_resolver.callable_name(callee) + if resolved_name: + return resolved_name + if '.' 
in textual_name: + root_name, suffix = textual_name.split('.', 1) + root_value = try_resolve_static_value(ast.Name(id=root_name, ctx=ast.Load()), self._evaluation_context()) + module_name = getattr(root_value, '__name__', None) if root_value is not UNRESOLVED else None + if module_name is not None: + if module_name.startswith('numpy.'): + module_name = 'numpy' + return f'{module_name}.{suffix}' + return textual_name + + def _is_internal_iterator_helper_call(self, node: ast.AST) -> bool: + return isinstance(node, ast.Call) and astutils.rname(node.func) in _INTERNAL_ITERATOR_HELPERS + + def _is_internal_iterator_binding_name(self, name: str) -> bool: + return name.startswith('__dace_iter_') + + def _should_bind_as_reference(self, value: ast.AST, source: data.Data) -> bool: + if isinstance(source, data.Scalar): + return False + return isinstance(value, ast.Name) + + def _can_promote_to_reference(self, existing: data.Data, source: data.Data) -> bool: + if isinstance(existing, data.Scalar) or isinstance(source, data.Scalar): + return False + if hasattr(existing, 'is_equivalent'): + return existing.is_equivalent(source) + return type(existing) is type(source) + + def _evaluation_context(self) -> Dict[str, Any]: + context = copy.copy(self.external_globals) + context.update(self.globals) + context.update({ + name: binding.descriptor + for name, binding in self.bindings.items() if binding.descriptor is not None + }) + context.update(self.root.symbols) + return context + + def _numpy_lowering_context(self) -> NumpyLoweringContext: + return NumpyLoweringContext(bindings=self.bindings, + evaluation_context=self._evaluation_context, + resolve_output_target=self._resolve_output_target, + tasklet_name=self._tasklet_name, + fresh_symbol=self._fresh_symbol, + register_symbol=self._register_symbol, + fresh_name=self._fresh_transient_name, + append_node=self._append_node, + register_binding=self._register_binding) + + def _array_literal_context(self) -> ArrayLiteralContext: + return 
ArrayLiteralContext(infer_descriptor=lambda node: self._infer_descriptor(node, '__probe'), + infer_scalar_descriptor=self._infer_scalar_descriptor, + evaluation_context=self._evaluation_context, + resolve_output_target=self._resolve_output_target, + resolve_data_access=self._resolve_data_access, + resolve_callable_name=self._resolved_callable_name, + tasklet_name=self._tasklet_name, + array_constructor_name=self._array_constructor_name) + + def _dict_support_context(self, target_name: str = '__probe') -> DictSupportContext: + return DictSupportContext(infer_descriptor=lambda node: self._infer_descriptor(node, target_name), + infer_scalar_descriptor=self._infer_scalar_descriptor, + evaluation_context=self._evaluation_context) + + def _expression_planning_context(self) -> ExpressionPlanningContext: + return ExpressionPlanningContext(infer_descriptor=self._infer_plannable_expression_descriptor, + materialize_expression=self._materialize_temporary_expression, + resolve_data_access=self._resolve_data_access, + collect_input_memlets=self._collect_input_memlets, + resolve_output_target=self._resolve_output_target, + resolve_callable_name=self._resolved_callable_name, + should_materialize_call=lambda node: + (self.callable_resolver.is_dace_program_call(node) or self.callable_resolver. 
+ is_sdfg_call(node) or self._should_lower_as_library_call(node))) + + def _infer_plannable_expression_descriptor(self, node: ast.AST) -> Optional[data.Data]: + node = self.lambda_resolver.inline_known_lambda_calls(node) + generic_descriptor = self.expression_support.infer_expression_descriptor(self._expression_planning_context(), + node) + if generic_descriptor is not None: + return generic_descriptor + + if isinstance(node, ast.Call): + array_literal_descriptor = self.array_literal_support.infer_expression_descriptor( + self._array_literal_context(), node) + if array_literal_descriptor is not None: + return array_literal_descriptor + + numpy_descriptor = self.numpy_support.infer_expression_descriptor(self._numpy_lowering_context(), node) + if numpy_descriptor is not None: + return numpy_descriptor + + inferred_descriptor = self._infer_descriptor(node, '__probe') + if inferred_descriptor is not None: + return inferred_descriptor + + scalar_descriptor = self._infer_scalar_descriptor(node, None) + if scalar_descriptor is not None: + return scalar_descriptor + + access = self._resolve_data_access(node) + if access is not None: + _, _, descriptor, view_descriptor = access + result = _clone_descriptor(view_descriptor or descriptor) + result.transient = True + return result + + return None + + def _is_callback_descriptor(self, descriptor: Optional[data.Data]) -> bool: + return isinstance(descriptor, data.Scalar) and isinstance(descriptor.dtype, dtypes.callback) + + def _callback_specialization_value(self) -> data.Scalar: + return data.Scalar(dtypes.callback(None), transient=False) + + def _make_nested_function_program(self, node: ast.FunctionDef) -> Optional[_NestedFunctionProgram]: + if node.decorator_list: + return None + + global_names, nonlocal_names = _collect_scope_declarations(node) + + class _SelfCallDetector(ast.NodeVisitor): + + def __init__(self, name: str) -> None: + self.name = name + self.recursive = False + + def visit_Call(self, call_node: ast.Call) -> 
None: + if astutils.rname(call_node.func) == self.name: + self.recursive = True + return + self.generic_visit(call_node) + + detector = _SelfCallDetector(node.name) + detector.visit(node) + if detector.recursive: + return None + + return _NestedFunctionProgram(node.name, + node, + program_globals=self.globals, + external_globals=self.external_globals, + captured_names=set(global_names) | set(nonlocal_names), + constants=self.root.constants, + callback_mapping=self.root.callback_mapping, + seed_bindings=self.bindings, + lambda_bindings=self.lambda_bindings, + callable_bindings=self.callable_bindings) + + def _resolve_external_scope_value(self, name: str) -> Any: + if name in self.external_globals: + return self.external_globals[name] + return UNRESOLVED + + def _bind_external_scope_value(self, name: str, value: Any) -> None: + try: + descriptor = _binding_to_descriptor(value) + except Exception: + descriptor = _pyobject_scalar_descriptor() + + self._store_binding(name, descriptor, kind=_binding_kind_for_descriptor(descriptor)) + self.globals[name] = value + + if callable(value) and self.lambda_resolver.resolve_global_lambda_node(value) is None: + self.callable_bindings[name] = value + + self.lambda_resolver.bind_value(name, value) + + def _update_callable_binding(self, name: str, value: ast.AST) -> None: + if self.lambda_resolver.resolve_known_lambda_node(value) is not None: + self.callable_bindings.pop(name, None) + return + resolved = self.callable_resolver.resolve_known_callable(value) + if resolved is None: + self.callable_bindings.pop(name, None) + return + self.callable_bindings[name] = resolved + + def _emit_if_chain(self, node: ast.If) -> None: + parent = self.scope_stack[-1] + current = node + if_scope = tn.IfScope(condition=CodeBlock(_unparse(current.test)), children=[]) + if_scope.parent = parent + parent.children.append(if_scope) + self._visit_body(if_scope, current.body) + + orelse = current.orelse + while len(orelse) == 1 and isinstance(orelse[0], 
ast.If): + current = orelse[0] + elif_scope = tn.ElifScope(condition=CodeBlock(_unparse(current.test)), children=[]) + elif_scope.parent = parent + parent.children.append(elif_scope) + self._visit_body(elif_scope, current.body) + orelse = current.orelse + + if orelse: + else_scope = tn.ElseScope(children=[]) + else_scope.parent = parent + parent.children.append(else_scope) + self._visit_body(else_scope, orelse) + + def _parse_for_indices(self, node: ast.AST) -> List[str]: + if isinstance(node, ast.Name): + return [node.id] + if isinstance(node, (ast.Tuple, ast.List)): + names = [] + for elt in node.elts: + if not isinstance(elt, ast.Name): + raise TypeError('Only identifier loop targets are supported in the schedule-tree frontend') + names.append(elt.id) + return names + raise TypeError('Only identifier loop targets are supported in the schedule-tree frontend') + + def _parse_for_iterator(self, node: ast.AST) -> Tuple[str, List[Tuple[str, str, str]]]: + schedule_target = node + if isinstance(node, ast.BinOp) and isinstance(node.op, ast.MatMult): + schedule_target = node.left + + if isinstance(schedule_target, ast.Call): + iterator = astutils.rname(schedule_target.func) + if iterator not in {'range', 'prange', 'parrange'}: + raise TypeError(f'Unsupported for-loop iterator {iterator!r}') + + args = schedule_target.args + if len(args) == 1: + return 'range', [('0', _unparse(args[0]), '1')] + if len(args) == 2: + return 'range', [(_unparse(args[0]), _unparse(args[1]), '1')] + if len(args) == 3: + return 'range', [(_unparse(args[0]), _unparse(args[1]), _unparse(args[2]))] + raise TypeError(f'Invalid number of arguments for {iterator!r}') + + if isinstance(schedule_target, ast.Subscript): + iterator = astutils.rname(schedule_target.value) + if iterator != 'dace.map': + raise TypeError(f'Unsupported for-loop iterator {iterator!r}') + return 'dace.map', self._parse_map_ranges(schedule_target) + + raise TypeError('Unsupported for-loop iterator expression in schedule-tree 
frontend') + + def _parse_map_ranges(self, node: ast.Subscript) -> List[Tuple[str, str, str]]: + slice_node = node.slice + if isinstance(slice_node, ast.Tuple): + dims = list(slice_node.elts) + else: + dims = [slice_node] + + ranges: List[Tuple[str, str, str]] = [] + for dim in dims: + if isinstance(dim, ast.Slice): + start = '0' if dim.lower is None else _unparse(dim.lower) + stop = _unparse(dim.upper) if dim.upper is not None else '' + step = '1' if dim.step is None else _unparse(dim.step) + ranges.append((start, stop, step)) + else: + expr = _unparse(dim) + ranges.append((expr, expr, '1')) + return ranges diff --git a/dace/frontend/python/tasklet_runner.py b/dace/frontend/python/tasklet_runner.py index 01210463d3..d48037dbcd 100644 --- a/dace/frontend/python/tasklet_runner.py +++ b/dace/frontend/python/tasklet_runner.py @@ -153,7 +153,7 @@ def visit_TopLevelExpr(self, node): # Replace "a << A[i]" with "a = A[i]" at the beginning if not dynamic: - storenode = copy.deepcopy(node.value.left) + storenode = astutils.copy_tree(node.value.left) storenode.ctx = ast.Store() self.pre_statements.append( _copy_location(ast.Assign(targets=[storenode], value=cleaned_right), node)) @@ -173,7 +173,7 @@ def visit_TopLevelExpr(self, node): # lambda: "A[i] = (lambda a,b: a+b)(A[i], a)" rhs = _copy_location(ast.Call(func=wcr, args=[cleaned_right, rhs], keywords=[]), rhs) - lhs = copy.deepcopy(cleaned_right) + lhs = astutils.copy_tree(cleaned_right) lhs.ctx = ast.Store() self.post_statements.append(_copy_location(ast.Assign(targets=[lhs], value=rhs), node)) else: @@ -209,7 +209,7 @@ def visit_Assign(self, node: ast.Assign): elif (isinstance(target, ast.Name) and target.id in self.wcr_replacements): # Replace WCR assignment newtarget, wcr = copy.deepcopy(self.wcr_replacements[target.id]) - new_old_rhs = copy.deepcopy(newtarget) + new_old_rhs = astutils.copy_tree(newtarget) newtarget.ctx = ast.Store() rhs = _copy_location(ast.Call(func=wcr, args=[new_old_rhs, rhs], keywords=[]), rhs) 
result.append(_copy_location(ast.Assign(targets=[newtarget], value=rhs), node)) diff --git a/dace/libraries/blas/nodes/axpy.py b/dace/libraries/blas/nodes/axpy.py index c82eabf9f7..141f1c57e5 100644 --- a/dace/libraries/blas/nodes/axpy.py +++ b/dace/libraries/blas/nodes/axpy.py @@ -157,3 +157,9 @@ def axpy_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, a, x, y, re state.add_edge(libnode, '_res', res, None, mm.Memlet(result)) return [] + + +@oprepo.infers_descriptor('dace.libraries.blas.axpy') +@oprepo.infers_descriptor('dace.libraries.blas.Axpy') +def _infer_axpy_libnode(input_descs, a, x, y, result, **_kw): + return () diff --git a/dace/libraries/blas/nodes/batched_matmul.py b/dace/libraries/blas/nodes/batched_matmul.py index 0abc94c1fc..7998347d29 100644 --- a/dace/libraries/blas/nodes/batched_matmul.py +++ b/dace/libraries/blas/nodes/batched_matmul.py @@ -519,3 +519,8 @@ def bmmnode(pv, sdfg: dace.SDFG, state: dace.SDFGState, A, B, C, alpha=1, beta=0 state.add_edge(libnode, '_c', C_out, None, mm.Memlet(C)) return [] + + +@oprepo.infers_descriptor('dace.libraries.blas.bmm') +def _infer_bmmnode(input_descs, A, B, C, alpha=1, beta=0, trans_a=False, trans_b=False, **_kw): + return () diff --git a/dace/libraries/blas/nodes/dot.py b/dace/libraries/blas/nodes/dot.py index 42ce0c0fa8..ad177e2d92 100644 --- a/dace/libraries/blas/nodes/dot.py +++ b/dace/libraries/blas/nodes/dot.py @@ -243,3 +243,9 @@ def dot_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, x, y, result state.add_edge(libnode, '_result', res, None, mm.Memlet(result)) return [] + + +@oprepo.infers_descriptor('dace.libraries.blas.dot') +@oprepo.infers_descriptor('dace.libraries.blas.Dot') +def _infer_dot_libnode(input_descs, x, y, result, acctype=None, **_kw): + return () diff --git a/dace/libraries/blas/nodes/gemm.py b/dace/libraries/blas/nodes/gemm.py index 90f4c6d2b1..cee30906be 100644 --- a/dace/libraries/blas/nodes/gemm.py +++ b/dace/libraries/blas/nodes/gemm.py @@ -626,3 +626,9 
@@ def gemm_libnode(pv: 'ProgramVisitor', state.add_edge(C_in, None, libnode, '_c', mm.Memlet(C)) return [] + + +@oprepo.infers_descriptor('dace.libraries.blas.gemm') +@oprepo.infers_descriptor('dace.libraries.blas.Gemm') +def _infer_gemm_libnode(input_descs, A, B, C, alpha, beta, trans_a=False, trans_b=False, **_kw): + return () diff --git a/dace/libraries/blas/nodes/gemv.py b/dace/libraries/blas/nodes/gemv.py index 9ca6368b45..85a8ef563c 100644 --- a/dace/libraries/blas/nodes/gemv.py +++ b/dace/libraries/blas/nodes/gemv.py @@ -464,3 +464,9 @@ def gemv_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, A, x, y, al state.add_edge(y_in, None, libnode, '_y', mm.Memlet(y)) return [] + + +@oprepo.infers_descriptor('dace.libraries.blas.gemv') +@oprepo.infers_descriptor('dace.libraries.blas.Gemv') +def _infer_gemv_libnode(input_descs, A, x, y, alpha, beta, trans=None, **_kw): + return () diff --git a/dace/libraries/blas/nodes/ger.py b/dace/libraries/blas/nodes/ger.py index a91c5e3b10..56f91cc1f4 100644 --- a/dace/libraries/blas/nodes/ger.py +++ b/dace/libraries/blas/nodes/ger.py @@ -195,3 +195,9 @@ def ger_libnode(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, A, x, y, out state.add_edge(libnode, '_res', out, None, mm.Memlet(output)) return [] + + +@oprepo.infers_descriptor('dace.libraries.blas.ger') +@oprepo.infers_descriptor('dace.libraries.blas.Ger') +def _infer_ger_libnode(input_descs, A, x, y, output, alpha, **_kw): + return () diff --git a/dace/libraries/onnx/nodes/onnx_op_registry.py b/dace/libraries/onnx/nodes/onnx_op_registry.py index 91d7ab8982..8d835e007d 100644 --- a/dace/libraries/onnx/nodes/onnx_op_registry.py +++ b/dace/libraries/onnx/nodes/onnx_op_registry.py @@ -70,7 +70,11 @@ def _get_all_schemas(): def register_op_repo_replacement(cls: Type[onnx_op.ONNXOp], cls_name: str, dace_schema: ONNXSchema): """Register an op repository replacement for the given ONNX operation class.""" - 
@dace_op_repo.replaces("dace.libraries.onnx.{}".format(cls_name)) + @dace_op_repo.infers_descriptor(f"dace.libraries.onnx.{cls_name}") + def op_repo_descriptor_inference(input_descs, *args, **kwargs): + return () + + @dace_op_repo.replaces(f"dace.libraries.onnx.{cls_name}") def op_repo_replacement(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, **kwargs): attrs = {name: value for name, value in kwargs.items() if name in dace_schema.attributes} # Remove used attrs diff --git a/dace/runtime/include/dace/math.h b/dace/runtime/include/dace/math.h index 533056c9e4..ca21d115ae 100644 --- a/dace/runtime/include/dace/math.h +++ b/dace/runtime/include/dace/math.h @@ -255,10 +255,11 @@ static DACE_CONSTEXPR DACE_HDFI std::complex np_float_pow(const std::com // Formula: num - (num // den) * den // NOTE: This is different than Python math.remainder and C remainder, // which are equaivalent to the IEEE remainder: num - round(num / den) * den -template -static DACE_CONSTEXPR DACE_HDFI T py_mod(const T& numerator, const T& denominator) { - T quotient = py_floor(numerator, denominator); - return (T)(numerator - quotient * denominator); +template +static DACE_CONSTEXPR DACE_HDFI std::common_type_t py_mod(const T1& numerator, const T2& denominator) { + using CT = std::common_type_t; + CT quotient = py_floor((CT)numerator, (CT)denominator); + return (CT)(numerator - quotient * denominator); } // Computes C/C++ modulus (operator % and fmod) diff --git a/dace/runtime/include/dace/pyinterop.h b/dace/runtime/include/dace/pyinterop.h index 0c59633ac1..6dca872f1c 100644 --- a/dace/runtime/include/dace/pyinterop.h +++ b/dace/runtime/include/dace/pyinterop.h @@ -2,6 +2,12 @@ #ifndef __DACE_INTEROP_H #define __DACE_INTEROP_H +#include + +#include +#include +#include + #include "types.h" // Various classes to simplify interoperability with python in code converted to C++ @@ -39,6 +45,359 @@ class range typedef void *pyobject; +template +inline const char* dace_numpy_dtype_name() { + if 
constexpr (std::is_same_v) { + return "float64"; + } else if constexpr (std::is_same_v) { + return "float32"; + } else if constexpr (std::is_same_v) { + return "int8"; + } else if constexpr (std::is_same_v) { + return "int16"; + } else if constexpr (std::is_same_v) { + return "int32"; + } else if constexpr (std::is_same_v) { + return "int64"; + } else if constexpr (std::is_same_v) { + return "uint8"; + } else if constexpr (std::is_same_v) { + return "uint16"; + } else if constexpr (std::is_same_v) { + return "uint32"; + } else if constexpr (std::is_same_v) { + return "uint64"; + } else if constexpr (std::is_same_v) { + return "bool_"; + } else { + throw std::runtime_error("Unsupported NumPy dtype conversion"); + } +} + +template +inline const char* dace_ctypes_scalar_name() { + if constexpr (std::is_same_v) { + return "c_double"; + } else if constexpr (std::is_same_v) { + return "c_float"; + } else if constexpr (std::is_same_v) { + return "c_int8"; + } else if constexpr (std::is_same_v) { + return "c_int16"; + } else if constexpr (std::is_same_v) { + return "c_int32"; + } else if constexpr (std::is_same_v) { + return "c_int64"; + } else if constexpr (std::is_same_v) { + return "c_uint8"; + } else if constexpr (std::is_same_v) { + return "c_uint16"; + } else if constexpr (std::is_same_v) { + return "c_uint32"; + } else if constexpr (std::is_same_v) { + return "c_uint64"; + } else if constexpr (std::is_same_v) { + return "c_bool"; + } else { + throw std::runtime_error("Unsupported ctypes scalar conversion"); + } +} + +template +inline PyObject* dace_make_pyarray(T* ptr, const Py_ssize_t* shape, + const Py_ssize_t* strides, size_t ndim) { + PyObject* numpy_module = PyImport_ImportModule("numpy"); + if (numpy_module == nullptr) { + return nullptr; + } + PyObject* ctypes_module = PyImport_ImportModule("ctypes"); + if (ctypes_module == nullptr) { + Py_DecRef(numpy_module); + return nullptr; + } + + PyObject* numpy_dtype = + PyObject_GetAttrString(numpy_module, 
dace_numpy_dtype_name()); + PyObject* ctypes_scalar = + PyObject_GetAttrString(ctypes_module, dace_ctypes_scalar_name()); + if (numpy_dtype == nullptr || ctypes_scalar == nullptr) { + Py_DecRef(numpy_dtype); + Py_DecRef(ctypes_scalar); + Py_DecRef(ctypes_module); + Py_DecRef(numpy_module); + return nullptr; + } + + Py_ssize_t total_size = 1; + for (size_t i = 0; i < ndim; ++i) { + total_size *= shape[i]; + } + + PyObject* array_len = PyLong_FromSsize_t(total_size); + PyObject* array_type = PyNumber_Multiply(ctypes_scalar, array_len); + PyObject* pointer_fn = PyObject_GetAttrString(ctypes_module, "POINTER"); + PyObject* pointer_type = + pointer_fn ? PyObject_CallFunctionObjArgs(pointer_fn, array_type, nullptr) + : nullptr; + PyObject* voidp_type = PyObject_GetAttrString(ctypes_module, "c_void_p"); + PyObject* cast_fn = PyObject_GetAttrString(ctypes_module, "cast"); + PyObject* address = + PyLong_FromUnsignedLongLong(reinterpret_cast(ptr)); + PyObject* voidp = + voidp_type ? PyObject_CallFunctionObjArgs(voidp_type, address, nullptr) + : nullptr; + PyObject* casted = + (cast_fn && voidp && pointer_type) + ? PyObject_CallFunctionObjArgs(cast_fn, voidp, pointer_type, nullptr) + : nullptr; + PyObject* contents = + casted ? 
PyObject_GetAttrString(casted, "contents") : nullptr; + + PyObject* shape_tuple = PyTuple_New(ndim); + PyObject* strides_tuple = PyTuple_New(ndim); + if (shape_tuple == nullptr || strides_tuple == nullptr) { + Py_DecRef(strides_tuple); + Py_DecRef(shape_tuple); + Py_DecRef(contents); + Py_DecRef(casted); + Py_DecRef(voidp); + Py_DecRef(address); + Py_DecRef(cast_fn); + Py_DecRef(voidp_type); + Py_DecRef(pointer_type); + Py_DecRef(pointer_fn); + Py_DecRef(array_type); + Py_DecRef(array_len); + Py_DecRef(ctypes_scalar); + Py_DecRef(numpy_dtype); + Py_DecRef(ctypes_module); + Py_DecRef(numpy_module); + return nullptr; + } + for (size_t i = 0; i < ndim; ++i) { + PyObject* shape_value = PyLong_FromSsize_t(shape[i]); + PyObject* stride_value = PyLong_FromSsize_t(strides[i]); + if (shape_value == nullptr || stride_value == nullptr || + PyTuple_SetItem(shape_tuple, i, shape_value) != 0 || + PyTuple_SetItem(strides_tuple, i, stride_value) != 0) { + Py_DecRef(shape_value); + Py_DecRef(stride_value); + Py_DecRef(strides_tuple); + Py_DecRef(shape_tuple); + Py_DecRef(contents); + Py_DecRef(casted); + Py_DecRef(voidp); + Py_DecRef(address); + Py_DecRef(cast_fn); + Py_DecRef(voidp_type); + Py_DecRef(pointer_type); + Py_DecRef(pointer_fn); + Py_DecRef(array_type); + Py_DecRef(array_len); + Py_DecRef(ctypes_scalar); + Py_DecRef(numpy_dtype); + Py_DecRef(ctypes_module); + Py_DecRef(numpy_module); + return nullptr; + } + } + + PyObject* ndarray_ctor = PyObject_GetAttrString(numpy_module, "ndarray"); + PyObject* args = PyTuple_New(0); + PyObject* kwargs = PyDict_New(); + if (shape_tuple != nullptr) { + PyDict_SetItemString(kwargs, "shape", shape_tuple); + } + if (numpy_dtype != nullptr) { + PyDict_SetItemString(kwargs, "dtype", numpy_dtype); + } + if (contents != nullptr) { + PyDict_SetItemString(kwargs, "buffer", contents); + } + if (strides_tuple != nullptr) { + PyDict_SetItemString(kwargs, "strides", strides_tuple); + } + + PyObject* result = + ndarray_ctor ? 
PyObject_Call(ndarray_ctor, args, kwargs) : nullptr; + + Py_DecRef(kwargs); + Py_DecRef(args); + Py_DecRef(ndarray_ctor); + Py_DecRef(strides_tuple); + Py_DecRef(shape_tuple); + Py_DecRef(contents); + Py_DecRef(casted); + Py_DecRef(voidp); + Py_DecRef(address); + Py_DecRef(cast_fn); + Py_DecRef(voidp_type); + Py_DecRef(pointer_type); + Py_DecRef(pointer_fn); + Py_DecRef(array_type); + Py_DecRef(array_len); + Py_DecRef(ctypes_scalar); + Py_DecRef(numpy_dtype); + Py_DecRef(ctypes_module); + Py_DecRef(numpy_module); + return result; +} + +inline PyObject* dace_make_pyobject(pyobject value) { + PyObject* result = reinterpret_cast(value); + Py_IncRef(result); + return result; +} + +inline PyObject* dace_make_pyobject(bool value) { + return PyBool_FromLong(value ? 1 : 0); +} + +template +inline std::enable_if_t && !std::is_same_v, + PyObject*> +dace_make_pyobject(T value) { + return PyLong_FromLongLong(static_cast(value)); +} + +template +inline std::enable_if_t, PyObject*> +dace_make_pyobject(T value) { + return PyFloat_FromDouble(static_cast(value)); +} + +template +inline void dace_set_pyobject_attr(pyobject obj, const char* attr, T value) { + PyGILState_STATE gil_state = PyGILState_Ensure(); + PyObject* pyobj = reinterpret_cast(obj); + PyObject* pyvalue = dace_make_pyobject(value); + if (PyObject_SetAttrString(pyobj, attr, pyvalue) != 0) { + Py_DecRef(pyvalue); + PyGILState_Release(gil_state); + throw std::runtime_error("Failed to set Python attribute"); + } + Py_DecRef(pyvalue); + PyGILState_Release(gil_state); +} + +template +inline void dace_set_pyobject_attr_array(pyobject obj, const char* attr, T* ptr, + const Py_ssize_t* shape, + const Py_ssize_t* strides, + size_t ndim) { + PyGILState_STATE gil_state = PyGILState_Ensure(); + PyObject* pyobj = reinterpret_cast(obj); + PyObject* pyvalue = dace_make_pyarray(ptr, shape, strides, ndim); + if (pyvalue == nullptr) { + PyGILState_Release(gil_state); + throw std::runtime_error("Failed to materialize Python array 
attribute"); + } + if (PyObject_SetAttrString(pyobj, attr, pyvalue) != 0) { + Py_DecRef(pyvalue); + PyGILState_Release(gil_state); + throw std::runtime_error("Failed to set Python array attribute"); + } + Py_DecRef(pyvalue); + PyGILState_Release(gil_state); +} + +inline PyObject* dace_resolve_pyobject_attr_path(pyobject obj, + const char* attr_path) { + PyObject* current = reinterpret_cast(obj); + Py_IncRef(current); + const char* cursor = attr_path; + + while (current != nullptr) { + const char* separator = std::strchr(cursor, '.'); + PyObject* next = nullptr; + if (separator == nullptr) { + next = PyObject_GetAttrString(current, cursor); + Py_DecRef(current); + return next; + } + + const Py_ssize_t token_size = separator - cursor; + PyObject* token = PyUnicode_FromStringAndSize(cursor, token_size); + if (token == nullptr) { + Py_DecRef(current); + return nullptr; + } + next = PyObject_GetAttr(current, token); + Py_DecRef(token); + Py_DecRef(current); + if (next == nullptr) { + return nullptr; + } + current = next; + cursor = separator + 1; + } + + return nullptr; +} + +template +inline T dace_get_pyobject_attr(pyobject obj, const char* attr) { + PyGILState_STATE gil_state = PyGILState_Ensure(); + PyObject* pyvalue = dace_resolve_pyobject_attr_path(obj, attr); + if (pyvalue == nullptr) { + PyGILState_Release(gil_state); + throw std::runtime_error("Failed to read Python attribute"); + } + + T result; + if constexpr (std::is_same_v) { + result = PyObject_IsTrue(pyvalue) != 0; + } else if constexpr (std::is_integral_v) { + result = static_cast(PyLong_AsLongLong(pyvalue)); + } else if constexpr (std::is_floating_point_v) { + result = static_cast(PyFloat_AsDouble(pyvalue)); + } else { + Py_DecRef(pyvalue); + PyGILState_Release(gil_state); + throw std::runtime_error("Unsupported Python attribute conversion"); + } + + if (PyErr_Occurred()) { + Py_DecRef(pyvalue); + PyGILState_Release(gil_state); + throw std::runtime_error("Failed to convert Python attribute"); + } + + 
Py_DecRef(pyvalue); + PyGILState_Release(gil_state); + return result; +} + +template +inline T* dace_get_pyobject_attr_ptr(pyobject obj, const char* attr) { + PyGILState_STATE gil_state = PyGILState_Ensure(); + PyObject* pyvalue = dace_resolve_pyobject_attr_path(obj, attr); + if (pyvalue == nullptr) { + PyGILState_Release(gil_state); + throw std::runtime_error("Failed to read Python attribute"); + } + + Py_buffer view; + if (PyObject_GetBuffer(pyvalue, &view, PyBUF_STRIDES) != 0) { + Py_DecRef(pyvalue); + PyGILState_Release(gil_state); + throw std::runtime_error("Python attribute does not expose a buffer"); + } + + if (view.itemsize != sizeof(T)) { + PyBuffer_Release(&view); + Py_DecRef(pyvalue); + PyGILState_Release(gil_state); + throw std::runtime_error("Python attribute buffer itemsize mismatch"); + } + + T* result = reinterpret_cast(view.buf); + PyBuffer_Release(&view); + Py_DecRef(pyvalue); + PyGILState_Release(gil_state); + return result; +} + // Sympy functions template static DACE_HDFI U Min(U val, T... vals) { diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 70297e6db9..7b1799bdc0 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -1,8 +1,8 @@ # Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. 
from collections.abc import Mapping from dataclasses import dataclass, field +import copy import sympy - from dace import nodes, data, subsets, dtypes, symbolic from dace.properties import CodeBlock from dace.sdfg import InterstateEdge @@ -11,7 +11,7 @@ from dace.sdfg.sdfg import InterstateEdge, SDFG, memlets_in_ast from dace.sdfg.state import LoopRegion, SDFGState from dace.memlet import Memlet -from typing import TYPE_CHECKING, Any, Iterable, Iterator, Literal, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Literal, Optional, Sequence, Set, Tuple, Union if TYPE_CHECKING: from dace import SDFG @@ -23,6 +23,76 @@ class UnsupportedScopeException(Exception): pass +def _format_frontend_range(start: str, stop: str, step: str) -> str: + if step == '1': + return f'{start}:{stop}' + return f'{start}:{stop}:{step}' + + +@dataclass(frozen=True) +class FrontendLoop: + """ + Lightweight loop metadata used by frontends that construct schedule trees + without first materializing an SDFG control-flow region. + """ + loop_condition: CodeBlock + init_statement: Optional[CodeBlock] = None + update_statement: Optional[CodeBlock] = None + loop_variable: Optional[str] = None + inverted: bool = False + update_before_condition: bool = False + + +@dataclass(frozen=True) +class FrontendMap: + """ + Lightweight map metadata used by frontends that construct schedule trees + without first materializing SDFG map nodes. + """ + params: Sequence[str] + ranges: Sequence[Tuple[str, str, str]] + schedule: Optional[str] = None + + +@dataclass(frozen=True) +class FrontendConsume: + """ + Lightweight consume-scope metadata used by frontend-produced schedule + trees. + """ + pe_index: str + num_pes: str + condition: Optional[CodeBlock] = None + + +@dataclass(frozen=True) +class FrontendTasklet: + """ + Lightweight tasklet metadata used by frontend-produced schedule trees. 
+ """ + name: str + code: CodeBlock = field(default_factory=lambda: CodeBlock('')) + + +@dataclass(frozen=True) +class FrontendLibrary: + """ + Lightweight library call metadata used by frontend-produced schedule trees. + """ + name: str + properties: Dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class FrontendFunctionCall: + """ + Lightweight function-call metadata used by frontend-produced schedule trees + to represent a call to another ``@dace.program``. + """ + callee_name: str + arguments: Dict[str, str] = field(default_factory=dict) # callee_param -> caller_expression + + @dataclass class ScheduleTreeNode: """Base class for nodes in the schedule tree.""" @@ -199,10 +269,10 @@ def output_memlets(self, @dataclass class ScheduleTreeRoot(ScheduleTreeScope): """ - The root of a schedule tree. This is a `ScheduleTreeScope` with additional information on + The root of a schedule tree. This is a ``ScheduleTreeScope`` with additional information on the available descriptors, symbol types, and constants of the tree, aka the descriptor repository. - Each schedule tree has only one `ScheduleTreeRoot`. The `ScheduleTreeRoot` is the only `ScheduleTreeScope` + Each schedule tree has only one ``ScheduleTreeRoot``. The ``ScheduleTreeRoot`` is the only ``ScheduleTreeScope`` without a parent (because it is the root node of the tree). """ name: str @@ -274,10 +344,43 @@ def __init__(self, *, children: list[ScheduleTreeNode], parent: ScheduleTreeScop super().__init__(children=children, parent=parent) +@dataclass +class FunctionCallScope(ControlFlowScope): + """ + Represents a call to another ``@dace.program`` whose schedule tree body + is inlined as children of this scope. 
+ """ + call: FrontendFunctionCall = field(default_factory=lambda: FrontendFunctionCall('')) + + def as_string(self, indent: int = 0): + args = ', '.join(f'{k}={v}' for k, v in self.call.arguments.items()) + result = indent * INDENTATION + f'call {self.call.callee_name}({args}):\n' + return result + super().as_string(indent) + + +@dataclass +class SDFGCallNode(ScheduleTreeNode): + """ + Represents a call to an SDFG-valued callee that remains explicit in the + schedule tree instead of being inlined structurally. + """ + sdfg: 'SDFG' + call: FrontendFunctionCall = field(default_factory=lambda: FrontendFunctionCall('')) + return_targets: List[str] = field(default_factory=list) + + def as_string(self, indent: int = 0): + args = ', '.join(f'{k}={v}' for k, v in self.call.arguments.items()) + call = f'sdfg_call {self.call.callee_name}({args})' + if not self.return_targets: + return indent * INDENTATION + call + targets = ', '.join(self.return_targets) + return indent * INDENTATION + f'{targets} = {call}' + + @dataclass class DataflowScope(ScheduleTreeScope): - node: nodes.EntryNode - state: SDFGState | None = None + node: Union[nodes.EntryNode, FrontendMap, FrontendConsume] + state: Optional[SDFGState] = None def __init__(self, *, @@ -309,10 +412,14 @@ def as_string(self, indent: int = 0): @dataclass class StateLabel(ScheduleTreeNode): - state: SDFGState + state: Union[SDFGState, str] def as_string(self, indent: int = 0): - return indent * INDENTATION + f'label {self.state.name}:' + if isinstance(self.state, str): + name = self.state + else: + name = self.state.name + return indent * INDENTATION + f'label {name}:' def input_memlets(self, root: ScheduleTreeRoot | None = None, **kwargs) -> MemletSet: return MemletSet() @@ -343,7 +450,7 @@ class AssignNode(ScheduleTreeNode): """ name: str value: CodeBlock - edge: InterstateEdge + edge: Optional[InterstateEdge] = None def as_string(self, indent: int = 0): return indent * INDENTATION + f'assign {self.name} = 
{self.value.as_string}' @@ -356,12 +463,97 @@ def output_memlets(self, root: ScheduleTreeRoot | None = None, **kwargs) -> Meml return MemletSet() +@dataclass +class ReassignExternalNode(ScheduleTreeNode): + """ + Explicit reassignment of an external Python binding captured via + ``global`` or ``nonlocal``. + """ + name: str + value: CodeBlock + scope: Literal['global', 'nonlocal'] + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'reassign_external {self.scope} {self.name} = {self.value.as_string}' + + +@dataclass +class StatementNode(ScheduleTreeNode): + """ + Opaque statement node used by source frontends when a statement has not yet + been lowered into a more structured dataflow node. + """ + code: CodeBlock + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'stmt {self.code.as_string}' + + +@dataclass +class PythonCallbackNode(ScheduleTreeNode): + """ + Python code that cannot be represented in the dataflow model and must be + executed via native Python callback at runtime. Distinct from StatementNode + in that it explicitly marks code as never lowerable. + """ + code: CodeBlock + reason: str + input_names: List[str] = field(default_factory=list) + output_names: List[str] = field(default_factory=list) + outlined_function_name: Optional[str] = None + outlined_function_code: Optional[CodeBlock] = None + outlined_call_code: Optional[CodeBlock] = None + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'python_callback "{self.reason}" {{ {self.code.as_string} }}' + + +@dataclass +class RaiseNode(ScheduleTreeNode): + """ + Explicit raise statement emitted by source frontends when the exception + shape is known well enough to remain compilable. 
+ """ + exception_type: Optional[CodeBlock] = None + args: List[CodeBlock] = field(default_factory=list) + kwargs: Dict[str, CodeBlock] = field(default_factory=dict) + + def as_string(self, indent: int = 0): + if self.exception_type is None: + return indent * INDENTATION + 'raise' + + call_args = [argument.as_string for argument in self.args] + call_args.extend(f'{name}={value.as_string}' for name, value in self.kwargs.items()) + rendered = self.exception_type.as_string + if call_args: + rendered = f'{rendered}({", ".join(call_args)})' + return indent * INDENTATION + f'raise {rendered}' + + +@dataclass +class ReturnNode(ScheduleTreeNode): + """ + Explicit return node used by source frontends before lowering returns to a + backend-specific representation. + """ + values: List[str] = field(default_factory=list) + """ + If non-empty, represents the return value(s) of this return statement as a list of data descriptor names. + """ + + def as_string(self, indent: int = 0): + if not self.values: + return indent * INDENTATION + 'return' + joined = ', '.join(self.values) + return indent * INDENTATION + f'return {joined}' + + @dataclass class LoopScope(ControlFlowScope): """ General loop scope (representing a loop region). 
""" - loop: LoopRegion + loop: Union[LoopRegion, FrontendLoop] def __init__(self, *, @@ -658,8 +850,13 @@ def __init__(self, super().__init__(node=node, state=state, children=children, parent=parent) def as_string(self, indent: int = 0): - rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) - result = indent * INDENTATION + f'map {", ".join(self.node.map.params)} in [{rangestr}]:\n' + if isinstance(self.node, FrontendMap): + rangestr = ', '.join(_format_frontend_range(start, stop, step) for start, stop, step in self.node.ranges) + params = ', '.join(self.node.params) + else: + rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) + params = ', '.join(self.node.map.params) + result = indent * INDENTATION + f'map {params} in [{rangestr}]:\n' return result + super().as_string(indent) def input_memlets(self, @@ -705,17 +902,24 @@ def __init__(self, super().__init__(node=node, state=state, children=children, parent=parent) def as_string(self, indent: int = 0): - node: nodes.ConsumeEntry = self.node - cond = 'stream not empty' if node.consume.condition is None else node.consume.condition.as_string - result = indent * INDENTATION + f'consume (PE {node.consume.pe_index} out of {node.consume.num_pes}) while {cond}:\n' + if isinstance(self.node, FrontendConsume): + cond = 'stream not empty' if self.node.condition is None else self.node.condition.as_string + pe_index = self.node.pe_index + num_pes = self.node.num_pes + else: + node: nodes.ConsumeEntry = self.node + cond = 'stream not empty' if node.consume.condition is None else node.consume.condition.as_string + pe_index = node.consume.pe_index + num_pes = node.consume.num_pes + result = indent * INDENTATION + f'consume (PE {pe_index} out of {num_pes}) while {cond}:\n' return result + super().as_string(indent) @dataclass class TaskletNode(ScheduleTreeNode): - node: nodes.Tasklet - in_memlets: dict[str, Memlet] - out_memlets: dict[str, Memlet] + node: 
Union[nodes.Tasklet, FrontendTasklet] + in_memlets: Dict[str, Memlet] + out_memlets: Dict[str, Memlet] def as_string(self, indent: int = 0): in_memlets = ', '.join(f'{v}' for v in self.in_memlets.values()) @@ -733,9 +937,9 @@ def output_memlets(self, root: ScheduleTreeRoot | None = None, **kwargs) -> Meml @dataclass class LibraryCall(ScheduleTreeNode): - node: nodes.LibraryNode - in_memlets: dict[str, Memlet] | MemletSet - out_memlets: dict[str, Memlet] | MemletSet + node: Union[nodes.LibraryNode, FrontendLibrary] + in_memlets: Union[Dict[str, Memlet], Set[Memlet]] + out_memlets: Union[Dict[str, Memlet], Set[Memlet]] def as_string(self, indent: int = 0): if isinstance(self.in_memlets, set): @@ -746,11 +950,17 @@ def as_string(self, indent: int = 0): out_memlets = ', '.join(f'{v}' for v in self.out_memlets) else: out_memlets = ', '.join(f'{v}' for v in self.out_memlets.values()) - libname = type(self.node).__name__ - # Get the properties of the library node without its superclasses - own_properties = ', '.join(f'{k}={getattr(self.node, k)}' for k, v in self.node.__properties__.items() - if v.owner not in {nodes.Node, nodes.CodeNode, nodes.LibraryNode}) - return indent * INDENTATION + f'{out_memlets} = library {libname}[{own_properties}]({in_memlets})' + if isinstance(self.node, FrontendLibrary): + libname = self.node.name + own_properties = ', '.join(f'{k}={v}' for k, v in self.node.properties.items()) + else: + libname = type(self.node).__name__ + own_properties = ', '.join(f'{k}={getattr(self.node, k)}' for k, v in self.node.__properties__.items() + if v.owner not in {nodes.Node, nodes.CodeNode, nodes.LibraryNode}) + call = f'library {libname}[{own_properties}]({in_memlets})' + if not out_memlets: + return indent * INDENTATION + call + return indent * INDENTATION + f'{out_memlets} = {call}' def input_memlets(self, root: ScheduleTreeRoot | None = None, **kwargs) -> MemletSet: if isinstance(self.in_memlets, set): @@ -862,13 +1072,16 @@ class 
RefSetNode(ScheduleTreeNode): Reference set node. Sets a reference to a data container. """ target: str - memlet: Memlet - src_desc: data.Data | nodes.CodeNode + memlet: Optional[Memlet] + src_desc: Union[data.Data, nodes.CodeNode] ref_desc: data.Data + source_expr: Optional[str] = None def as_string(self, indent: int = 0): if isinstance(self.src_desc, nodes.CodeNode): return indent * INDENTATION + f'{self.target} = refset from {type(self.src_desc).__name__.lower()}' + if self.source_expr is not None: + return indent * INDENTATION + f'{self.target} = refset to {self.source_expr}' return indent * INDENTATION + f'{self.target} = refset to {self.memlet}' def input_memlets(self, root: ScheduleTreeRoot | None = None, **kwargs) -> MemletSet: @@ -896,6 +1109,18 @@ def output_memlets(self, root: ScheduleTreeRoot | None = None, **kwargs) -> Meml return MemletSet() +def clone_descriptor_with_shape(descriptor: data.Data, shape: Sequence[Any]) -> data.Data: + """ + Clone a data descriptor and update its shape if supported. + """ + result = copy.deepcopy(descriptor) + if hasattr(result, 'set_shape'): + result.set_shape(list(shape)) + elif hasattr(result, 'shape'): + result.shape = list(shape) + return result + + # Classes based on Python's AST NodeVisitor/NodeTransformer for schedule tree nodes class ScheduleNodeVisitor: diff --git a/dace/symbolic.py b/dace/symbolic.py index 74600aa84f..77925f61ce 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -873,8 +873,8 @@ def swalk(expr, enter_functions=False): _builtin_userfunctions = { - 'int_floor', 'int_ceil', 'abs', 'Abs', 'min', 'Min', 'max', 'Max', 'not', 'Not', 'Eq', 'NotEq', 'Ne', 'AND', 'OR', - 'pow', 'round' + 'int_floor', 'int_ceil', 'pyindex', 'abs', 'Abs', 'min', 'Min', 'max', 'Max', 'not', 'Not', 'Eq', 'NotEq', 'Ne', + 'AND', 'OR', 'pow', 'round' } @@ -1055,6 +1055,28 @@ def _eval_is_integer(self): return True +class pyindex(sympy.Function): + """Python-style wraparound for scalar element indices. 
+ + This is intentionally not used for slice bounds, where positive ``stop`` + values such as ``size`` must not wrap to zero. + """ + + @classmethod + def eval(cls, x, y): + if x.is_Number and y.is_Number: + return sympy.Mod(x, y) + if y.is_Number and y == 1: + return 0 + + def _eval_is_integer(self): + return True + + def _eval_is_nonnegative(self): + if self.args[1].is_nonnegative is True: + return True + + class OR(sympy.Function): @classmethod @@ -1789,6 +1811,10 @@ def _unary_minus(a): 'RightShift': right_shift, 'left_shift': left_shift, 'right_shift': right_shift, + 'pyindex': pyindex, + 'id': sympy.Symbol('id'), + 'diag': sympy.Symbol('diag'), + 'jn': sympy.Symbol('jn'), } _constants = { 'True': sympy.true, @@ -2241,6 +2267,10 @@ def _print_Function(self, expr): if str(expr.func) in self.arrays: indices = ", ".join(self._print(arg) for arg in expr.args) return f'{expr.func}[{indices}]' + if self.cpp_mode and str(expr.func) == 'int_floor': + return '((%s) / (%s))' % (self._print(expr.args[0]), self._print(expr.args[1])) + if self.cpp_mode and str(expr.func) == 'pyindex': + return 'py_mod(%s, %s)' % (self._print(expr.args[0]), self._print(expr.args[1])) if str(expr.func) == 'AND': return f'(({self._print(expr.args[0])}) and ({self._print(expr.args[1])}))' if str(expr.func) == 'OR': diff --git a/tests/codegen/allocation_lifetime_test.py b/tests/codegen/allocation_lifetime_test.py index b1278822a3..c52c7c699a 100644 --- a/tests/codegen/allocation_lifetime_test.py +++ b/tests/codegen/allocation_lifetime_test.py @@ -403,7 +403,7 @@ def test_persistent_loop_bound(): Code originates from Issue #1550. Tests both ``for`` and OpenMP parallel ``for`` loop bounds with persistent storage. 
""" - N = dace.symbol('N') + N = dace.symbol('N', dace.int64) @dace.program(auto_optimize=True) def tester(L: dace.float64[N, N], index: dace.uint64, active_size: dace.uint64): diff --git a/tests/codegen/control_flow_generation_test.py b/tests/codegen/control_flow_generation_test.py index f337570780..25bded347f 100644 --- a/tests/codegen/control_flow_generation_test.py +++ b/tests/codegen/control_flow_generation_test.py @@ -112,6 +112,13 @@ def tester(a: dace.float64[20]): assert 'goto' not in sdfg.generate_code()[0].code +def test_pyindex_codegen_prints_python_mod_for_scalar_indices(): + i = dace.symbol('i', dtype=dace.int32) + n = dace.symbol('N', dtype=dace.int32) + + assert dace.symbolic.symstr(dace.symbolic.pyindex(i, n), cpp_mode=True) == '(py_mod(i, N))' + + def test_extraneous_goto_nested(): @dace.program diff --git a/tests/numpy/array_creation_test.py b/tests/numpy/array_creation_test.py index fcb7343e40..19b34c1a3b 100644 --- a/tests/numpy/array_creation_test.py +++ b/tests/numpy/array_creation_test.py @@ -1,6 +1,7 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. 
import dace from dace.frontend.python.common import DaceSyntaxError +from dace.frontend.python.replacements.array_creation import _infer_arange import numpy as np from common import compare_numpy_output import pytest @@ -117,6 +118,72 @@ def test_array_literal(): return np.array([[1, 2], [3, 4]], dtype=np.float32) +def test_array_literal_inside_expression(): + + @dace.program + def literal_expr(A: dace.float64[3]): + return A + np.array([1.0, 2.0, 3.0], dtype=np.float64) + + A = np.random.rand(3) + result = literal_expr(A) + expected = A + np.array([1.0, 2.0, 3.0], dtype=np.float64) + assert np.allclose(result, expected) + + +def test_array_literal_from_dynamic_scalar_elements(): + + @dace.program + def dynamic_literal(A: dace.float64[1], B: dace.float64[4], i: dace.int32): + return np.array([A[0], B[i]], dtype=np.float64) + + A = np.random.rand(1) + B = np.random.rand(4) + i = np.int32(2) + result = dynamic_literal(A, B, i) + expected = np.array([A[0], B[i]], dtype=np.float64) + assert np.allclose(result, expected) + + +def test_list_literal_inside_array_expression(): + + @dace.program + def literal_expr(A: dace.float64[3]): + return A * [1.0, 2.0, 3.0] + + A = np.random.rand(3) + result = literal_expr(A) + expected = A * np.array([1.0, 2.0, 3.0], dtype=np.float64) + assert np.allclose(result, expected) + + +def test_constant_list_literal_inside_array_expression_materializes_as_one_constant_array(): + + @dace.program + def literal_expr(A: dace.float64[3]): + return A * [1.0, 2.0, 3.0] + + sdfg = literal_expr.to_sdfg(simplify=False) + constant_arrays = [value for _, (_, value) in sdfg.constants_prop.items() if isinstance(value, np.ndarray)] + assert any(np.array_equal(value, np.array([1.0, 2.0, 3.0], dtype=np.float64)) for value in constant_arrays) + + literal_tasklets = [ + node for state in sdfg.states() for node in state.nodes() + if isinstance(node, dace.sdfg.nodes.Tasklet) and '_literal_' in node.label + ] + assert not literal_tasklets + + +def 
test_broadcast_mixed_tuple_and_list_literals_inside_expression(): + + @dace.program + def literal_expr(): + return np.array([1, 2, 3]) * ((4, 5, 6), [1, 2, 3]) + + result = literal_expr() + expected = np.array([1, 2, 3]) * ((4, 5, 6), [1, 2, 3]) + assert np.allclose(result, expected) + + @compare_numpy_output() def test_arange_0(): return np.arange(10, dtype=np.int32) @@ -152,6 +219,60 @@ def test_arange_6(): return np.arange(2.5, 10, 3) +def test_arange_symbolic_stop(): + K = dace.symbol('K') + desc = _infer_arange({}, K, dtype=np.int32) + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (K, ) + assert desc.dtype == dace.int32 + + +def test_arange_scalar_stop(): + desc = _infer_arange({'n': dace.data.Scalar(dace.int32)}, 'n', dtype=np.int32) + assert isinstance(desc, dace.data.Array) + assert str(desc.shape[0]).startswith('__sym_n') + assert desc.dtype == dace.int32 + + @dace.program + def arange_scalar(n: dace.int32): + return np.sum(np.arange(n, dtype=np.int32)) + + result = arange_scalar(np.int32(7)) + expected = np.sum(np.arange(7, dtype=np.int32)) + assert result == expected + + +def test_arange_data_scalar_stop(): + desc = _infer_arange({'A[0]': dace.data.Scalar(dace.int32)}, 'A[0]', dtype=np.int32) + assert isinstance(desc, dace.data.Array) + assert str(desc.shape[0]).startswith('__sym_A_0_') + assert desc.dtype == dace.int32 + + @dace.program + def arange_data_scalar(A: dace.int32[1]): + return np.sum(np.arange(A[0], dtype=np.int32)) + + A = np.array([7], dtype=np.int32) + result = arange_data_scalar(A) + expected = np.sum(np.arange(A[0], dtype=np.int32)) + assert result == expected + + +def test_arange_data_scalar_stop_repromotes_after_write(): + + @dace.program + def arange_data_scalar_twice(A: dace.int32[1]): + first = np.sum(np.arange(A[0], dtype=np.int32)) + A[0] += 1 + second = np.sum(np.arange(A[0], dtype=np.int32)) + return first, second + + A = np.array([7], dtype=np.int32) + first, second = arange_data_scalar_twice(A) + 
assert first == np.sum(np.arange(7, dtype=np.int32)) + assert second == np.sum(np.arange(8, dtype=np.int32)) + + @compare_numpy_output() def test_linspace_1(): return np.linspace(2.5, 10, num=3) @@ -304,6 +425,9 @@ def ones_scalar_size(k: dace.int32): test_arange_4() test_arange_5() test_arange_6() + test_arange_symbolic_stop() + test_arange_scalar_stop() + test_arange_data_scalar_stop() test_linspace_1() test_linspace_2() test_linspace_3() @@ -317,3 +441,8 @@ def ones_scalar_size(k: dace.int32): test_zeros_symbolic_size_scalar() test_ones_scalar_size_scalar() test_ones_scalar_size() + test_array_literal_inside_expression() + test_array_literal_from_dynamic_scalar_elements() + test_list_literal_inside_array_expression() + test_constant_list_literal_inside_array_expression_materializes_as_one_constant_array() + test_broadcast_mixed_tuple_and_list_literals_inside_expression() diff --git a/tests/numpy/negative_indices_test.py b/tests/numpy/negative_indices_test.py index fa1fd1dcf6..3806aff111 100644 --- a/tests/numpy/negative_indices_test.py +++ b/tests/numpy/negative_indices_test.py @@ -14,6 +14,23 @@ def test_negative_index(): assert out[0] == A[-2] +@dace.program +def runtime_negative_index(A: dace.int64[10], i: dace.int64): + return A[i] + + +def test_runtime_negative_index(): + A = np.random.randint(0, 100, size=10, dtype=np.int64) + + with dace.config.set_temporary('frontend', 'runtime_negative_indices', value=True): + sdfg = runtime_negative_index.to_sdfg(A, np.int64(-2), simplify=False) + code = sdfg.generate_code()[0].clean_code + out = sdfg(A=A, i=np.int64(-2)) + + assert 'py_mod(__sym_i, 10)' in code + assert out[0] == A[-2] + + @dace.program def nested_negative_index(A: dace.int64[10]): out = np.ndarray([2], dtype=np.int64) @@ -55,6 +72,28 @@ def test_nested_negative_range(): assert np.array_equal(out[5:], A[-6:-1]) +@dace.program +def runtime_nested_negative_range(A: dace.int64[10], offset: dace.int64, offset2: dace.int64): + out = np.ndarray([10], 
dtype=np.int64) + for i in dace.map[0:2]: + out[i * 5:i * 5 + 5] = np.sum(A[offset:offset2]) + return out + + +def test_runtime_nested_negative_range(): + A = np.random.randint(0, 100, size=10, dtype=np.int64) + + with dace.config.set_temporary('frontend', 'runtime_negative_indices', value=True): + sdfg = runtime_nested_negative_range.to_sdfg(A, np.int64(-6), np.int64(-1), simplify=False) + runtime_code = sdfg.generate_code()[0].clean_code + out = sdfg(A=A, offset=np.int64(-6), offset2=np.int64(-1)) + + expected = np.full(5, np.sum(A[-6:-1]), dtype=np.int64) + assert 'py_mod(' not in runtime_code.split('reduce_1_1_6', 1)[-1] + assert np.array_equal(out[:5], expected) + assert np.array_equal(out[5:], expected) + + @dace.program def jacobi_2d(A: dace.float64[10, 10], B: dace.float64[10, 10]): for t in range(1, 10): diff --git a/tests/python_frontend/callback_autodetect_test.py b/tests/python_frontend/callback_autodetect_test.py index ee2ce7af21..3674f07454 100644 --- a/tests/python_frontend/callback_autodetect_test.py +++ b/tests/python_frontend/callback_autodetect_test.py @@ -1,5 +1,6 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. """ Tests automatic detection and baking of callbacks in the Python frontend. 
""" +import asyncio from typing import Dict, Union import dace import numpy as np @@ -141,6 +142,59 @@ def autocallback_method(A: dace.float64[N, N]): assert np.allclose(out, nd.q * A) +def test_async_dace_program_is_rejected(): + + @dace.program + async def async_prog(A: dace.float64[N]): + return A + + with pytest.raises(SyntaxError, match='Async @dace.program functions are unsupported'): + async_prog.to_sdfg() + + +def test_async_callback_without_running_loop(): + + async def async_scale(a): + await asyncio.sleep(0) + return 2.0 * a + + @dace.program + def autocallback_async(A: dace.float64[N], B: dace.float64[N]): + tmp: dace.float64[N] = async_scale(A) + B[:] = tmp + + A = np.random.rand(24) + B = np.zeros_like(A) + + with pytest.warns(match="Automatically creating callback"): + autocallback_async(A, B) + + assert np.allclose(B, 2.0 * A) + + +def test_async_callback_with_running_loop(): + + async def async_scale(a): + await asyncio.sleep(0) + return 2.0 * a + + @dace.program + def autocallback_async(A: dace.float64[N], B: dace.float64[N]): + tmp: dace.float64[N] = async_scale(A) + B[:] = tmp + + A = np.random.rand(24) + B = np.zeros_like(A) + + async def invoke(): + with pytest.warns(match="Automatically creating callback"): + autocallback_async(A, B) + + asyncio.run(invoke()) + + assert np.allclose(B, 2.0 * A) + + @dace.program def modcallback(A: dace.float64[N, N], B: dace.float64[N]): tmp: dace.float64[N] = np.median(A, axis=1) diff --git a/tests/python_frontend/conftest.py b/tests/python_frontend/conftest.py new file mode 100644 index 0000000000..735659f776 --- /dev/null +++ b/tests/python_frontend/conftest.py @@ -0,0 +1,30 @@ +import contextlib +import importlib.util +import pathlib +import sys +import tempfile +import uuid + +import pytest + + +@pytest.fixture +def temp_python_module(): + + @contextlib.contextmanager + def _load(module_source: str, module_name_prefix: str = 'dace_temp_module'): + with tempfile.TemporaryDirectory() as temp_dir: + 
module_path = pathlib.Path(temp_dir) / 'temp_module.py' + module_path.write_text(module_source) + + module_name = f'{module_name_prefix}_{uuid.uuid4().hex}' + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + try: + yield module + finally: + sys.modules.pop(module_name, None) + + return _load diff --git a/tests/python_frontend/memlet_parser_test.py b/tests/python_frontend/memlet_parser_test.py new file mode 100644 index 0000000000..1f15619e4a --- /dev/null +++ b/tests/python_frontend/memlet_parser_test.py @@ -0,0 +1,85 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +import ast + +import dace + +from dace.frontend.python.memlet_parser import parse_memlet_subset + + +def _subset_axis_values(subset): + axis_values = [] + for (begin, _, step), size in zip(subset.ranges, subset.size()): + axis_values.append([int(begin + step * i) for i in range(int(size))]) + return axis_values + + +def _expected_axis_values(shape, *layers): + axis_values = [list(range(extent)) for extent in shape] + remaining_axes = list(range(len(shape))) + new_axes = [] + + for layer in layers: + consumed = 0 + next_remaining_axes = [] + new_axes = [] + output_pos = 0 + + for item in layer: + if item is None: + new_axes.append(output_pos) + output_pos += 1 + continue + + axis = remaining_axes[consumed] + if isinstance(item, slice): + axis_values[axis] = axis_values[axis][item] + next_remaining_axes.append(axis) + else: + axis_values[axis] = [axis_values[axis][item]] + consumed += 1 + output_pos += 1 + + remaining_axes = next_remaining_axes + + return axis_values, new_axes + + +def test_parse_memlet_subset_nested_subscripts_keep_original_dimension_mapping(): + layer1 = (slice(0, 50, 2), 1, slice(None), slice(2, 40, 3), 4, slice(None), slice(5, 55, 5), slice(None), 8, + slice(None), slice(10, 60, 10), slice(None), 12, slice(None), 
slice(14, 62, 8), slice(None), 16, + slice(None), slice(18, 58, 4), slice(None)) + layer2 = (slice(None), 5, slice(1, 4), None, slice(None), 2, slice(None), slice(0, 2), 3, slice(None), slice(1, 3), + None, slice(None), 4, slice(None), slice(1, 5, 2), 6) + expr = ast.parse( + 'A[0:50:2, 1, :, 2:40:3, 4, :, 5:55:5, :, 8, :, 10:60:10, :, 12, :, 14:62:8, :, 16, :, 18:58:4, :][:, 5, 1:4, None, :, 2, :, 0:2, 3, :, 1:3, None, :, 4, :, 1:5:2, 6]', + mode='eval').body + array = dace.data.Array(dace.float64, [64] * 20) + + subset, new_axes, arrdims = parse_memlet_subset(array, expr, {'A': array}) + expected_axis_values, expected_new_axes = _expected_axis_values(array.shape, layer1, layer2) + + assert _subset_axis_values(subset) == expected_axis_values + assert new_axes == expected_new_axes + assert arrdims == {} + + +def test_parse_memlet_subset_three_nested_subscripts_keep_original_dimension_mapping(): + layer1 = (slice(0, 50, 2), 1, slice(None), slice(2, 40, 3), 4, slice(None), slice(5, 55, 5), slice(None), 8, + slice(None), slice(10, 60, 10), slice(None), 12, slice(None), slice(14, 62, 8), slice(None), 16, + slice(None), slice(18, 58, 4), slice(None)) + layer2 = (slice(None), 5, slice(1, 4), slice(None), 2, slice(None), slice(0, 2), 3, slice(None), slice(1, 3), + slice(None), 4, slice(None), slice(1, 5, 2), 6) + layer3 = (slice(2, 10, 2), 1, None, slice(5, 20, 3), slice(None), 1, slice(4, 9, 2), 0, slice(1, 5, + 2), slice(1, + 5), 0, None) + expr = ast.parse( + 'A[0:50:2, 1, :, 2:40:3, 4, :, 5:55:5, :, 8, :, 10:60:10, :, 12, :, 14:62:8, :, 16, :, 18:58:4, :][:, 5, 1:4, :, 2, :, 0:2, 3, :, 1:3, :, 4, :, 1:5:2, 6][2:10:2, 1, None, 5:20:3, :, 1, 4:9:2, 0, 1:5:2, 1:5, 0, None]', + mode='eval').body + array = dace.data.Array(dace.float64, [64] * 20) + + subset, new_axes, arrdims = parse_memlet_subset(array, expr, {'A': array}) + expected_axis_values, expected_new_axes = _expected_axis_values(array.shape, layer1, layer2, layer3) + + assert _subset_axis_values(subset) == 
expected_axis_values + assert new_axes == expected_new_axes + assert arrdims == {} diff --git a/tests/python_frontend/parallel_schedule_tree_test.py b/tests/python_frontend/parallel_schedule_tree_test.py new file mode 100644 index 0000000000..3bb6b69000 --- /dev/null +++ b/tests/python_frontend/parallel_schedule_tree_test.py @@ -0,0 +1,46 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. + +import dace +import numpy as np +import pytest + + +def test_parallel_schedule_tree_statement_nodes_raise_on_to_sdfg(): + + class AttrHolder: + + def __init__(self): + self.arr = np.zeros(4, dtype=np.float64) + + attr_holder = AttrHolder() + + @dace.program + def prog(A: dace.float64[4], out: dace.float64[4]): + attr_holder.arr = A + out[:] = attr_holder.arr + + with pytest.raises(RuntimeError, match=r'StatementNode'): + prog.to_sdfg(simplify=False) + + +def test_parallel_schedule_tree_warns_for_refsets_and_pythonclasses(): + + class Holder: + scalar: dace.float64 + + PythonHolder = dace.data.PythonClass.from_class(Holder) + + @dace.program + def prog(holder: PythonHolder, A: dace.float64[4], out: dace.float64[4]): + holder.new_data = A + out[:] = holder.new_data[:] + + with pytest.warns(UserWarning) as captured: + sdfg = prog.to_sdfg(simplify=False) + + messages = [str(record.message) for record in captured] + assert any('RefSetNode target "holder.new_data"' in message for message in messages) + assert any('PythonClass container "holder"' in message for message in messages) + assert any( + isinstance(descriptor, dace.data.Reference) + for _, _, descriptor in sdfg.arrays_recursive(include_nested_data=True)) diff --git a/tests/python_frontend/preparse_test.py b/tests/python_frontend/preparse_test.py index c46e71e88b..2e45aedc2e 100644 --- a/tests/python_frontend/preparse_test.py +++ b/tests/python_frontend/preparse_test.py @@ -3,7 +3,10 @@ import dace import numpy as np import os +import pytest +import sys import tempfile +from 
dace.frontend.python.common import DaceSyntaxError def test_nested_objects_same_name(): @@ -90,6 +93,62 @@ def outer(self, A): assert res.call_tree_length() == 2 +def test_type_alias_is_compile_time_only_in_dace_program(temp_python_module): + if sys.version_info < (3, 12): + pytest.skip('Type alias statements require Python 3.12+') + + with temp_python_module(''' +import dace + +@dace.program +def prog(A: dace.float32[4]): + type dtype = dace.float32 + tmp: dtype = 1 + A[0] = tmp +''', + module_name_prefix='dace_preparse_typealias') as module: + array = np.zeros(4, dtype=np.float32) + module.prog(array) + + assert array[0] == np.float32(1) + + +def test_generic_type_alias_is_rejected_in_dace_program(temp_python_module): + if sys.version_info < (3, 12): + pytest.skip('Type alias statements require Python 3.12+') + + with temp_python_module(''' +import dace + +@dace.program +def prog(A: dace.float32[4]): + type dtype[T] = T + return A +''', + module_name_prefix='dace_preparse_typealias') as module: + array = np.zeros(4, dtype=np.float32) + with pytest.raises(DaceSyntaxError, match='Generic type aliases'): + module.prog(array) + + +def test_type_var_tuple_alias_is_rejected_in_dace_program(temp_python_module): + if sys.version_info < (3, 12): + pytest.skip('Type alias statements require Python 3.12+') + + with temp_python_module(''' +import dace + +@dace.program +def prog(A: dace.float32[4]): + type dtype[*Ts] = tuple[*Ts] + return A +''', + module_name_prefix='dace_preparse_typealias') as module: + array = np.zeros(4, dtype=np.float32) + with pytest.raises(DaceSyntaxError, match='Generic type aliases'): + module.prog(array) + + def test_same_function_different_closure(): arrx = np.full([20], 1) arry = np.full([20], 2) diff --git a/tests/python_frontend/pythonclass_test.py b/tests/python_frontend/pythonclass_test.py new file mode 100644 index 0000000000..8dfb0415ae --- /dev/null +++ b/tests/python_frontend/pythonclass_test.py @@ -0,0 +1,205 @@ +# Copyright 2019-2026 
ETH Zurich and the DaCe authors. All rights reserved. + +import dace +import numpy as np + + +def test_pythonclass_scalar_rebind_and_new_field_codegen(): + + class Holder: + scalar: dace.float64 + + PythonHolder = dace.data.PythonClass.from_class(Holder) + + @dace.program + def prog(holder: PythonHolder, A: dace.float64[4]): + holder.scalar = A[0] + holder.new_field = A[1] + + sdfg = prog.to_sdfg(simplify=False) + + assert 'holder.scalar' in sdfg.arrays + assert 'holder.new_field' in sdfg.arrays + assert isinstance(sdfg.arrays['holder.scalar'], dace.data.Scalar) + assert isinstance(sdfg.arrays['holder.new_field'], dace.data.Scalar) + + assignment_targets = {'holder.scalar', 'holder.new_field'} + assignment_states = [] + for state in sdfg.nodes(): + for edge in state.edges(): + if getattr(edge.dst, 'data', None) in assignment_targets: + assignment_states.append(state) + break + + assert len(assignment_states) == 2 + assert all(not any(isinstance(node, dace.nodes.Tasklet) for node in state.nodes()) for state in assignment_states) + + code = sdfg.generate_code()[0].clean_code + assert 'dace_set_pyobject_attr(holder, "scalar",' in code + assert 'dace_set_pyobject_attr(holder, "new_field",' in code + + holder = Holder() + holder.scalar = -1.0 + values = np.array([3.5, 7.25, 0.0, 0.0], dtype=np.float64) + + prog(holder, values) + + assert holder.scalar == values[0] + assert holder.new_field == values[1] + + +def test_pythonclass_literal_scalar_assignment_uses_tasklet_output_codegen(): + + class Holder: + scalar: dace.float64 + + PythonHolder = dace.data.PythonClass.from_class(Holder) + + @dace.program + def prog(holder: PythonHolder): + holder.new_field = 4.25 + + sdfg = prog.to_sdfg(simplify=False) + + assert 'holder.new_field' in sdfg.arrays + assert isinstance(sdfg.arrays['holder.new_field'], dace.data.Scalar) + + assignment_states = [] + for state in sdfg.nodes(): + for edge in state.edges(): + if getattr(edge.dst, 'data', None) == 'holder.new_field': + 
assignment_states.append(state) + break + + assert len(assignment_states) == 1 + assert sum(isinstance(node, dace.nodes.Tasklet) for node in assignment_states[0].nodes()) == 1 + assert next(node for node in assignment_states[0].nodes() + if isinstance(node, dace.nodes.Tasklet)).language == dace.dtypes.Language.Python + assert not any(state.label.startswith('pythonclass_attr_barrier_') for state in sdfg.nodes()) + + code = sdfg.generate_code()[0].clean_code + assert 'dace_set_pyobject_attr(holder, "new_field",' in code + + holder = Holder() + holder.scalar = -1.0 + + prog(holder) + + assert holder.new_field == 4.25 + + +def test_pythonclass_array_field_access_codegen(): + + class Holder: + data: dace.float64[4] + + PythonHolder = dace.data.PythonClass.from_class(Holder) + + @dace.program + def prog(holder: PythonHolder): + for i in range(4): + holder.data[i] = holder.data[i] + 1.0 + + code = prog.to_sdfg(simplify=False).generate_code()[0].clean_code + assert 'dace_get_pyobject_attr_ptr(holder, "data")' in code + + holder = Holder() + holder.data = np.array([1.0, 2.5, -3.0, 0.25], dtype=np.float64) + expected = holder.data + 1.0 + + prog(holder) + + assert np.allclose(holder.data, expected) + + +def test_pythonclass_nested_array_field_access_codegen(): + + class Inner: + data: dace.float64[4] + + class Outer: + inner: Inner + + PythonOuter = dace.data.PythonClass.from_class(Outer) + + @dace.program + def prog(holder: PythonOuter): + for i in range(4): + holder.inner.data[i] = holder.inner.data[i] + 1.0 + + code = prog.to_sdfg(simplify=False).generate_code()[0].clean_code + assert 'dace_get_pyobject_attr_ptr(holder, "inner.data")' in code + + class InnerRuntime: + pass + + class OuterRuntime: + pass + + holder = OuterRuntime() + holder.inner = InnerRuntime() + holder.inner.data = np.array([1.0, 2.5, -3.0, 0.25], dtype=np.float64) + expected = holder.inner.data + 1.0 + + prog(holder) + + assert np.allclose(holder.inner.data, expected) + + +def 
test_pythonclass_new_array_field_assignment_uses_reference_set(): + + class Holder: + scalar: dace.float64 + + PythonHolder = dace.data.PythonClass.from_class(Holder) + + @dace.program + def prog(holder: PythonHolder, A: dace.float64[4], out: dace.float64[4]): + holder.new_data = A + out[:] = holder.new_data[:] + + sdfg = prog.to_sdfg(simplify=False) + + assert 'holder.new_data' in sdfg.arrays + assert isinstance(sdfg.arrays['holder.new_data'], dace.data.ArrayReference) + + set_edges = [] + assignment_state = None + for state in sdfg.nodes(): + for edge in state.edges(): + if getattr(edge.dst, 'data', None) == 'holder.new_data' and edge.dst_conn == 'set': + set_edges.append(edge) + assignment_state = state + + assert len(set_edges) == 1 + assert assignment_state is not None + assert not any(isinstance(node, dace.nodes.Tasklet) for node in assignment_state.nodes()) + assert not any(state.label.startswith('pythonclass_attr_barrier_') for state in sdfg.nodes()) + + code = sdfg.generate_code()[0].clean_code + assert 'dace_get_pyobject_attr_ptr(holder, "new_data") = A;' not in code + assert 'dace_get_pyobject_attr_ptr(holder, "new_data")' in code + assert 'dace_set_pyobject_attr_array(holder, "new_data", A, __shape,' in code + assert '__strides, 1);' in code + + holder = Holder() + holder.scalar = -1.0 + values = np.array([1.0, 2.5, -3.0, 0.25], dtype=np.float64) + out = np.zeros_like(values) + + prog(holder, values, out) + + assert isinstance(holder.new_data, np.ndarray) + assert np.allclose(out, values) + assert np.allclose(holder.new_data, values) + + values[0] = 9.5 + assert holder.new_data[0] == values[0] + + +if __name__ == '__main__': + test_pythonclass_scalar_rebind_and_new_field_codegen() + test_pythonclass_literal_scalar_assignment_uses_tasklet_output_codegen() + test_pythonclass_array_field_access_codegen() + test_pythonclass_nested_array_field_access_codegen() + test_pythonclass_new_array_field_assignment_uses_reference_set() diff --git 
a/tests/python_frontend/schedule_tree/array_literal_support_test.py b/tests/python_frontend/schedule_tree/array_literal_support_test.py new file mode 100644 index 0000000000..1032adeeb1 --- /dev/null +++ b/tests/python_frontend/schedule_tree/array_literal_support_test.py @@ -0,0 +1,48 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. + +import ast +import pytest + +import dace +from dace import data +from dace.frontend.python.schedule_tree.array_literal_support import infer_array_literal_descriptor + + +def test_infer_array_literal_descriptor_for_nested_list_constants(): + node = ast.parse('[[1, 2], [3, 4]]', mode='eval').body + descriptor = infer_array_literal_descriptor(node, lambda _: None, lambda *_: None, lambda: {}) + + assert isinstance(descriptor, data.Array) + assert descriptor.dtype == dace.int64 + assert tuple(descriptor.shape) == (2, 2) + + +def test_infer_array_literal_descriptor_for_dynamic_scalar_elements(): + node = ast.parse('[A[0], B[i]]', mode='eval').body + scalar_desc = data.Scalar(dace.float64, transient=True) + + def infer_descriptor(expr): + text = ast.unparse(expr) + if text in {'A[0]', 'B[i]'}: + return scalar_desc + return None + + descriptor = infer_array_literal_descriptor(node, infer_descriptor, lambda *_: None, lambda: {}) + + assert isinstance(descriptor, data.Array) + assert descriptor.dtype == dace.float64 + assert tuple(descriptor.shape) == (2, ) + + +def test_infer_numpy_array_literal_descriptor_respects_dtype_and_ndmin(): + node = ast.parse('np.array([1, 2], dtype=np.float32, ndmin=2)', mode='eval').body + context = lambda: {'np': __import__('numpy')} + descriptor = infer_array_literal_descriptor(node, lambda _: None, lambda *_: None, context) + + assert isinstance(descriptor, data.Array) + assert descriptor.dtype == dace.float32 + assert tuple(descriptor.shape) == (1, 2) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git 
a/tests/python_frontend/schedule_tree/attribute_rewriter_test.py b/tests/python_frontend/schedule_tree/attribute_rewriter_test.py new file mode 100644 index 0000000000..aac654f1cf --- /dev/null +++ b/tests/python_frontend/schedule_tree/attribute_rewriter_test.py @@ -0,0 +1,161 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. + +import ast +import pytest + +import dace +import numpy as np +from dace.frontend.python import astutils +from dace.frontend.python.schedule_tree import AttributeRewriter +from dace.sdfg.analysis.schedule_tree import treenodes as tn + + +def _rewrite_expression(source: str, context): + rewriter = AttributeRewriter(lambda: dict(context)) + expr = ast.parse(source, mode='eval').body + return astutils.unparse(rewriter.rewrite_expression(expr)) + + +def _rewrite_assignment(source: str, context): + rewriter = AttributeRewriter(lambda: dict(context)) + assign = ast.parse(source).body[0] + rewritten = rewriter.rewrite_assignment(assign.targets[0], assign.value) + return None if rewritten is None else astutils.unparse(rewritten) + + +def test_attribute_rewriter_rewrites_descriptor_loads_and_stores(): + + class ArrayDescriptor: + + def __set_name__(self, owner, name): + self.name = '_' + name + + def __get__(self, obj, objtype=None): + return getattr(obj, self.name) + + def __set__(self, obj, value): + setattr(obj, self.name, value) + + class DescriptorHolder: + arr = ArrayDescriptor() + + def __init__(self): + self.arr = None + + descriptor_holder = DescriptorHolder() + context = {'descriptor_holder': descriptor_holder} + + assert _rewrite_assignment('descriptor_holder.arr = A', + context) == ("type(descriptor_holder).__dict__['arr'].__set__(descriptor_holder, A)") + assert _rewrite_expression( + 'descriptor_holder.arr', + context) == ("type(descriptor_holder).__dict__['arr'].__get__(descriptor_holder, type(descriptor_holder))") + + +def test_attribute_rewriter_rewrites_custom_getattribute_and_setattr(): + + class 
Proxy: + + def __getattribute__(self, name): + return object.__getattribute__(self, name) + + def __setattr__(self, name, value): + object.__setattr__(self, name, value) + + proxy = Proxy() + context = {'proxy': proxy} + + assert _rewrite_expression('proxy.value', context) == "type(proxy).__getattribute__(proxy, 'value')" + assert _rewrite_assignment('proxy.value = A', context) == "type(proxy).__setattr__(proxy, 'value', A)" + + +def test_attribute_rewriter_preserves_plain_attribute_syntax(): + + class Holder: + + def __init__(self): + self.value = None + + holder = Holder() + context = {'holder': holder} + + assert _rewrite_expression('holder.value', context) == 'holder.value' + assert _rewrite_assignment('holder.value = A', context) is None + + +def test_attribute_rewriter_preserves_plain_method_syntax(): + + class Holder: + + def method(self, value): + return value + + @staticmethod + def static_method(value): + return value + + @classmethod + def class_method(cls, value): + return value + + holder = Holder() + context = {'holder': holder} + + assert _rewrite_expression('holder.method', context) == 'holder.method' + assert _rewrite_expression('holder.method(A)', context) == 'holder.method(A)' + assert _rewrite_expression('holder.static_method(A)', context) == 'holder.static_method(A)' + assert _rewrite_expression('holder.class_method(A)', context) == 'holder.class_method(A)' + + +def test_attribute_rewriter_preserves_mpi4py_method_syntax(): + MPI = pytest.importorskip('mpi4py.MPI') + + commworld = MPI.COMM_WORLD + context = {'commworld': commworld} + + assert _rewrite_expression('commworld.Bcast(A)', context) == 'commworld.Bcast(A)' + + +def test_schedule_tree_lowers_plain_object_registered_methods(): + MPI = pytest.importorskip('mpi4py.MPI') + + commworld = MPI.COMM_WORLD + + @dace.program + def comm_world_bcast(A: dace.int32[10]): + commworld.Bcast(A) + + stree = comm_world_bcast.to_schedule_tree(np.zeros((10, ), dtype=np.int32)) + + assert 
isinstance(stree.children[0], tn.LibraryCall) + assert stree.children[0].node.name == 'Bcast' + assert stree.children[0].node.properties['receiver_class'] == 'Intracomm' + assert stree.children[0].node.properties['receiver'] == 'commworld' + assert not any(isinstance(node, tn.StatementNode) for node in stree.preorder_traversal()) + + +def test_schedule_tree_infers_plain_object_registered_method_results(): + MPI = pytest.importorskip('mpi4py.MPI') + + commworld = MPI.COMM_WORLD + + @dace.program + def comm_world_isend(A: dace.int32[1]): + req = commworld.Isend(A, 0, 0) + return req + + stree = comm_world_isend.to_schedule_tree(np.zeros((1, ), dtype=np.int32)) + library_calls = [node for node in stree.preorder_traversal() if isinstance(node, tn.LibraryCall)] + + assert len(library_calls) == 1 + assert library_calls[0].node.name == 'Isend' + request_name = library_calls[0].out_memlets['out'].data + request_desc = stree.containers[request_name] + assert isinstance(request_desc, dace.data.Array) + assert tuple(request_desc.shape) == (1, ) + assert isinstance(request_desc.dtype, dace.dtypes.opaque) + assert not any(isinstance(node, tn.StatementNode) for node in stree.preorder_traversal()) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/python_frontend/schedule_tree/callable_support_test.py b/tests/python_frontend/schedule_tree/callable_support_test.py new file mode 100644 index 0000000000..90e3a204d6 --- /dev/null +++ b/tests/python_frontend/schedule_tree/callable_support_test.py @@ -0,0 +1,75 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. 
import ast

import pytest

from dace import data, dtypes
from dace.frontend.python.schedule_tree.callable_support import CallableArgumentSpecializer, CallableResolver
from dace.frontend.python.schedule_tree.lambda_support import LambdaResolver
from dace.frontend.python.schedule_tree.type_inference import _Binding


def _identity_lambda_ast():
    """Return the expression AST of ``lambda a: a`` (the lambda handed to ``LambdaResolver`` as ``f``)."""
    return ast.parse('lambda a: a', mode='eval').body


def _is_callback_descriptor(descriptor):
    """Return True when ``descriptor`` is a ``data.Scalar`` whose dtype is a ``dtypes.callback``."""
    if not isinstance(descriptor, data.Scalar):
        return False
    return isinstance(descriptor.dtype, dtypes.callback)


def test_callable_specializer_detects_callback_expressions():
    """The names ``f`` (given to the lambda resolver) and ``cb`` (bound as a callback) are callback expressions."""
    cb_descriptor = data.Scalar(dtypes.callback(None), transient=False)
    resolver = CallableResolver(callable_bindings={}, evaluation_context=lambda: {})
    specializer = CallableArgumentSpecializer(
        lambda_resolver=LambdaResolver({}, {'f': _identity_lambda_ast()}, {}),
        callable_resolver=resolver,
        bindings={'cb': _Binding(descriptor=cb_descriptor, kind='callback', structure=None)},
        infer_descriptor=lambda node: None,
        resolve_data_access=lambda node: None,
        is_callback_descriptor=_is_callback_descriptor,
        callback_specialization_value=lambda: cb_descriptor)

    for identifier in ('f', 'cb'):
        assert specializer.is_callback_expression(ast.Name(id=identifier, ctx=ast.Load()))


def test_callable_specializer_extracts_lambda_and_callable_bindings():
    """Check the four values returned by ``extract_call_specialization`` for ``inner(A, f, cb=cb, literal=5)``."""

    def cb(value):
        return value

    def inner(A, f, cb=None, literal=None):
        return A

    cb_descriptor = data.Scalar(dtypes.callback(None), transient=False)
    a_descriptor = data.Scalar(dtypes.float64, transient=True)

    def _evaluation_context():
        # A fresh dictionary on every call, exactly like the inline lambda it replaces.
        return {'inner': inner, 'cb': cb}

    def _infer_descriptor(node):
        # Only the name ``A`` has a known descriptor in this test.
        if isinstance(node, ast.Name) and node.id == 'A':
            return a_descriptor
        return None

    resolver = CallableResolver(callable_bindings={'inner': inner, 'cb': cb}, evaluation_context=_evaluation_context)
    specializer = CallableArgumentSpecializer(
        lambda_resolver=LambdaResolver({}, {'f': _identity_lambda_ast()}, {'cb': cb}),
        callable_resolver=resolver,
        bindings={},
        infer_descriptor=_infer_descriptor,
        resolve_data_access=lambda node: None,
        is_callback_descriptor=_is_callback_descriptor,
        callback_specialization_value=lambda: cb_descriptor)

    call = ast.parse('inner(A, f, cb=cb, literal=5)', mode='eval').body

    positional, keywords, lambdas, callables = specializer.extract_call_specialization(call, ast.unparse)

    # Positional arguments: the descriptor for ``A`` followed by the callback descriptor for ``f``.
    assert len(positional) == 2
    first, second = positional
    assert isinstance(first, data.Scalar)
    assert first.dtype == dtypes.float64
    assert first.transient is False
    assert isinstance(second, data.Scalar)
    assert isinstance(second.dtype, dtypes.callback)

    # Keyword arguments keep the callable object and the literal value.
    assert keywords['cb'] is cb
    assert keywords['literal'] == 5

    assert 'f' in lambdas
    assert callables == {'cb': cb}


if __name__ == '__main__':
    pytest.main([__file__])

# ---------------------------------------------------------------------------
# Patch header of the next file in this diff, preserved from the original text:
# diff --git a/tests/python_frontend/schedule_tree/callback_support_test.py b/tests/python_frontend/schedule_tree/callback_support_test.py
# new file mode 100644
# index 0000000000..e20febaf4f
# --- /dev/null
# +++ b/tests/python_frontend/schedule_tree/callback_support_test.py
# @@ -0,0 +1,79 @@
# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved.
import ast
from types import SimpleNamespace

from dace import data, dtypes

from dace.frontend.python.schedule_tree.callable_support import CallableResolver
from dace.frontend.python.schedule_tree.callback_support import CallbackHandler, CallbackOutliner
from dace.sdfg.analysis.schedule_tree import treenodes as tn


def test_callback_outliner_wraps_assignment_as_function_and_call():
    """Outlining a single assignment yields a zero-argument function and a matching call statement."""
    statement = ast.parse('it = iter(generator)').body[0]

    outlined = CallbackOutliner.outline(statement,
                                        callback_name='__stree_callback',
                                        input_names=[],
                                        output_names=['it'])
    function_code, call_code = outlined

    assert function_code.as_string.startswith('def __stree_callback():')
    assert 'it = iter(generator)' in function_code.as_string
    assert function_code.as_string.endswith('return it')
    assert call_code.as_string == 'it = __stree_callback()'


def test_callback_outliner_supports_statement_groups():
    """Outlining a list of statements yields one function that returns all requested outputs as a tuple."""
    statements = ast.parse('x = a + 1\ny = x + 1').body

    outlined = CallbackOutliner.outline(statements,
                                        callback_name='__stree_callback',
                                        input_names=['a'],
                                        output_names=['x', 'y'])
    function_code, call_code = outlined

    assert function_code.as_string.startswith('def __stree_callback(a):')
    for fragment in ('x = (a + 1)', 'y = (x + 1)'):
        assert fragment in function_code.as_string
    assert function_code.as_string.endswith('return (x, y)')
    assert call_code.as_string == '(x, y) = __stree_callback(a)'


def test_callback_handler_wraps_node_and_registers_unknown_outputs():
    """``wrap_node`` appends one ``PythonCallbackNode`` and registers a binding for the new output ``it``."""
    emitted = []
    bindings = {'generator': SimpleNamespace(descriptor=data.Scalar(dtypes.int64, transient=True), kind='scalar')}

    def _register_binding(name, descriptor, kind):
        bindings[name] = SimpleNamespace(descriptor=descriptor, kind=kind)

    def _fail_on_syntax_error(node, message):
        # Any syntax error reported by the handler fails the test immediately.
        raise AssertionError(message)

    def _fresh_transient_name(prefix='__stree_tmp'):
        return prefix

    def _is_pyobject_scalar(descriptor):
        return isinstance(getattr(descriptor, 'dtype', None), dtypes.pyobject)

    handler = CallbackHandler(bindings=bindings,
                              callback_mutated_global_names=set(),
                              callable_resolver=CallableResolver(callable_bindings={}, evaluation_context=lambda: {}),
                              evaluation_context=lambda: {},
                              append_node=emitted.append,
                              register_binding=_register_binding,
                              fresh_callback_name=lambda: '__stree_callback',
                              fresh_transient_name=_fresh_transient_name,
                              render_callback_code=ast.unparse,
                              collect_scope_declarations=lambda node: (set(), set()),
                              raise_syntax_error=_fail_on_syntax_error,
                              binding_kind_for_descriptor=lambda descriptor: 'scalar',
                              pyobject_scalar_descriptor=lambda: data.Scalar(dtypes.pyobject(), transient=True),
                              is_pyobject_scalar_descriptor=_is_pyobject_scalar,
                              is_iterator_protocol_call=lambda value: False,
                              is_iterator_next_call=lambda value: False)

    statement = ast.parse('it = iter(generator)').body[0]
    handler.wrap_node(statement, 'pyobject call')

    # Exactly one callback node was emitted through ``append_node``.
    assert len(emitted) == 1
    callback_node = emitted[0]
    assert isinstance(callback_node, tn.PythonCallbackNode)
    assert callback_node.reason == 'pyobject call'
    assert callback_node.input_names == ['generator']
    assert callback_node.output_names == ['it']
    assert callback_node.outlined_function_name == '__stree_callback'
    assert callback_node.outlined_function_code is not None
    assert callback_node.outlined_call_code is not None

    # The previously unknown output was registered as a pyobject scalar.
    registered = bindings['it']
    assert registered.kind == 'scalar'
    assert registered.descriptor.dtype == dtypes.pyobject()

# ---------------------------------------------------------------------------
# Patch header of the next file in this diff, preserved from the original text:
# diff --git a/tests/python_frontend/schedule_tree/desugaring_test.py b/tests/python_frontend/schedule_tree/desugaring_test.py
# new file mode 100644
# index 0000000000..dc7bdbd4a9
# --- /dev/null
# +++ b/tests/python_frontend/schedule_tree/desugaring_test.py
# @@ -0,0 +1,85 @@
# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved.
import ast

import dace
from dace.frontend.python import astutils
from dace.frontend.python.schedule_tree import callback_reason, desugar_schedule_tree_expansions

# NOTE(review): the indentation inside the multi-line source/expected string literals below is reproduced
# exactly as it appears in the reviewed text (a single space after each "\n"). It looks whitespace-collapsed;
# confirm these literals against the repository before relying on them.


def _desugar_module(source: str, *, global_vars=None, known_descriptors=None):
    """Parse ``source`` and return the module produced by ``desugar_schedule_tree_expansions``."""
    parsed = ast.parse(source)
    return desugar_schedule_tree_expansions(parsed,
                                            filename='',
                                            global_vars=dict(global_vars or {}),
                                            known_descriptors=known_descriptors)


def _desugar_statements(source: str, *, global_vars=None, known_descriptors=None):
    """Desugar ``source`` and return each top-level statement of the result unparsed to a string."""
    module = _desugar_module(source, global_vars=global_vars, known_descriptors=known_descriptors)
    return [astutils.unparse(stmt) for stmt in module.body]


def test_schedule_tree_desugaring_materializes_analyzable_tuple_assignment_rhs():
    """A tuple swap is expanded through a temporary tuple and per-element temporaries."""
    expected = [
        '__stree_tuple_tmp = (B, A)',
        '__stree_tuple_tmp_0 = B',
        '__stree_tuple_tmp_1 = A',
        'A = __stree_tuple_tmp_0',
        'B = __stree_tuple_tmp_1',
    ]

    assert _desugar_statements('A, B = B, A') == expected


def test_schedule_tree_desugaring_leaves_function_return_destructuring():
    """Destructuring the result of a call is left as a single statement."""
    assert _desugar_statements('A, B = make_pair()') == ['(A, B) = make_pair()']


def test_schedule_tree_desugaring_normalizes_parenthesized_function_return_destructuring():
    """The parenthesized and bare target forms desugar to the same AST (attributes ignored)."""
    actual = ast.dump(_desugar_module('(A, B) = make_pair()'), include_attributes=False)
    reference = ast.dump(ast.parse('A, B = make_pair()'), include_attributes=False)

    assert actual == reference


def test_schedule_tree_desugaring_preserves_short_circuit_nested_index_guard():
    """A nested index on the right of ``and`` is not hoisted into an ``__stree_idx`` temporary."""
    statements = _desugar_statements('if flag and A[b[i]] == 0:\n out[0] = 1')

    assert len(statements) == 1
    (only, ) = statements
    assert '__stree_idx' not in only
    assert 'A[b[i]]' in only


def test_schedule_tree_desugaring_rewrites_while_with_hoisted_index_to_guarded_infinite_loop():
    """A ``while`` test with a nested index becomes ``while True`` with the index hoisted and a guarded break."""
    expected = ['while True:\n __stree_idx = b[i]\n if (not (A[__stree_idx] == 0)):\n break\n i += 1']

    assert _desugar_statements('while A[b[i]] == 0:\n i += 1') == expected


def test_schedule_tree_desugaring_marks_while_else_with_hoisted_index_for_callback():
    """A ``while``/``else`` whose test needs hoisting stays one statement carrying a callback reason."""
    parsed = ast.parse('while A[b[i]] == 0:\n i += 1\nelse:\n out[0] = 1')
    result = desugar_schedule_tree_expansions(parsed, filename='', global_vars={})

    assert len(result.body) == 1
    assert callback_reason(result.body[0]) == 'while loop test outlining with else'


def test_schedule_tree_desugaring_canonicalizes_negative_array_index():
    """``A[-1]`` on a known 5-element array is rewritten relative to the array length."""
    descriptors = {'A': dace.float64[5]}

    assert _desugar_statements('tmp = A[-1]', known_descriptors=descriptors) == ['tmp = A[(5 - 1)]']


def test_schedule_tree_desugaring_canonicalizes_symbolic_negative_array_index():
    """``A[-i]`` with a positive integer symbol ``i`` is rewritten relative to the symbolic length ``n``."""
    n = dace.symbol('n')
    i = dace.symbol('i', integer=True, positive=True)

    result = _desugar_statements('tmp = A[-i]', global_vars={'i': i}, known_descriptors={'A': dace.float64[n]})

    assert result == ['tmp = A[(n - i)]']


def test_schedule_tree_desugaring_canonicalizes_negative_tuple_index_from_known_length():
    """A negative index into a tuple literal of known length resolves to the matching element temporary."""
    expected = ['t = (a, b, c)', 't_0 = a', 't_1 = b', 't_2 = c', 'tmp = t_2']

    assert _desugar_statements('t = (a, b, c)\ntmp = t[-1]') == expected

# ---------------------------------------------------------------------------
# Patch header of the next file in this diff, preserved from the original text:
# diff --git a/tests/python_frontend/schedule_tree/dict_support_test.py b/tests/python_frontend/schedule_tree/dict_support_test.py
# new file mode 100644
# index 0000000000..ce3fc18b21
# --- /dev/null
# +++ b/tests/python_frontend/schedule_tree/dict_support_test.py
# @@ -0,0 +1,116 @@
# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved.
import ast

import dace
from dace import data
from dace import dtypes
from dace.data.creation import create_datadescriptor
from dace.data.pydata import PythonDict
from dace.frontend.python.schedule_tree.dict_support import DictSupportContext, DictSupportLibrary, StaticDictBinding, \
    infer_dict_assignment_descriptor, infer_dict_literal_binding, infer_dict_literal_descriptor, \
    infer_dict_subscript_descriptor


def _scalar_for_constant(node, python_type, dtype):
    """Return a transient ``data.Scalar(dtype)`` if ``node`` is an ``ast.Constant`` of ``python_type``, else None."""
    if isinstance(node, ast.Constant) and isinstance(node.value, python_type):
        return data.Scalar(dtype, transient=True)
    return None


def _parse_expression(source):
    """Parse ``source`` in ``eval`` mode and return the expression node."""
    return ast.parse(source, mode='eval').body


def test_create_datadescriptor_infers_typed_python_dict():
    """A dictionary with string keys and float values yields string/float64 component descriptors."""
    result = create_datadescriptor({'a': 1.0, 'b': 2.0})

    assert isinstance(result, PythonDict)
    key_type, value_type = result.key_type, result.value_type
    assert isinstance(key_type, data.Scalar)
    assert key_type.dtype == dace.string
    assert isinstance(value_type, data.Scalar)
    assert value_type.dtype == dace.float64


def test_create_datadescriptor_infers_pyobject_for_heterogeneous_values():
    """Mixed float/string values yield a pyobject value descriptor while keys stay strings."""
    result = create_datadescriptor({'a': 1.0, 'b': 'two'})

    assert isinstance(result, PythonDict)
    key_type, value_type = result.key_type, result.value_type
    assert isinstance(key_type, data.Scalar)
    assert key_type.dtype == dace.string
    assert isinstance(value_type, data.Scalar)
    assert value_type.dtype == dtypes.pyobject()


def test_infer_dict_literal_descriptor_uses_pyobject_for_unknown_value():
    """With no descriptor inference and only float constants known, both components are pyobject."""
    literal = _parse_expression("{'left': value, 'right': 2.0}")

    result = infer_dict_literal_descriptor(literal, lambda node: None,
                                           lambda node, annotated: _scalar_for_constant(node, float, dace.float64))

    assert isinstance(result, PythonDict)
    assert result.key_type.dtype == dtypes.pyobject()
    assert result.value_type.dtype == dtypes.pyobject()


def test_infer_dict_literal_descriptor_falls_back_per_component():
    """Keys resolve to strings while the values, one of which is unknown, are pyobject."""
    literal = _parse_expression("{'left': value, 'right': 2.0}")

    result = infer_dict_literal_descriptor(literal, lambda node: _scalar_for_constant(node, str, dace.string),
                                           lambda node, annotated: _scalar_for_constant(node, float, dace.float64))

    assert isinstance(result, PythonDict)
    assert isinstance(result.key_type, data.Scalar)
    assert result.key_type.dtype == dace.string
    assert isinstance(result.value_type, data.Scalar)
    assert result.value_type.dtype == dtypes.pyobject()


def test_infer_dict_assignment_descriptor_widens_value_type():
    """Assigning a string into a string-to-float64 dictionary gives a pyobject value type."""
    original = PythonDict(data.Scalar(dace.string, transient=True),
                          data.Scalar(dace.float64, transient=True),
                          transient=True)
    subscript = _parse_expression("mapping['left']")
    assigned = _parse_expression("'two'")

    result = infer_dict_assignment_descriptor(original, subscript.slice, assigned, lambda node: None,
                                              lambda node, annotated: _scalar_for_constant(node, str, dace.string),
                                              lambda: {})

    assert isinstance(result, PythonDict)
    assert result.key_type.dtype == dace.string
    assert result.value_type.dtype == dtypes.pyobject()


def test_infer_dict_subscript_descriptor_uses_static_key_binding():
    """A static binding gives the per-key descriptor for a known key and None for an unknown key."""

    def _infer_scalar(current, annotated):
        # Float constants are checked first, then string constants; anything else is unknown.
        as_float = _scalar_for_constant(current, float, dace.float64)
        if as_float is not None:
            return as_float
        return _scalar_for_constant(current, str, dace.string)

    dict_descriptor = PythonDict(data.Scalar(dace.string, transient=True),
                                 data.Scalar(dtypes.pyobject(), transient=True),
                                 transient=True)
    literal = _parse_expression("{'left': 1.0, 'right': 'two'}")
    binding = infer_dict_literal_binding(literal, lambda current: None, _infer_scalar, lambda: {})

    left = infer_dict_subscript_descriptor(dict_descriptor, _parse_expression("'left'"), lambda: {}, binding)
    missing = infer_dict_subscript_descriptor(dict_descriptor, _parse_expression("'missing'"), lambda: {}, binding)

    assert isinstance(left, data.Scalar)
    assert left.dtype == dace.float64
    assert missing is None


def test_dict_support_library_routes_shared_inference():
    """``DictSupportLibrary`` produces a descriptor, a static binding and a subscript descriptor from one context."""
    library = DictSupportLibrary()
    context = DictSupportContext(
        infer_descriptor=lambda current: _scalar_for_constant(current, str, dace.string),
        infer_scalar_descriptor=lambda current, annotated: _scalar_for_constant(current, float, dace.float64),
        evaluation_context=lambda: {})
    literal = _parse_expression("{'left': 1.0, 'right': 2.0}")

    descriptor = library.infer_literal_descriptor(context, literal)
    binding = library.infer_literal_binding(context, literal)
    subscript = library.infer_subscript_descriptor(context, descriptor, _parse_expression("'left'"), binding)

    assert isinstance(descriptor, PythonDict)
    assert isinstance(binding, StaticDictBinding)
    assert isinstance(subscript, data.Scalar)
    assert subscript.dtype == dace.float64

# ---------------------------------------------------------------------------
# Patch header of the next file in this diff, preserved from the original text:
# diff --git a/tests/python_frontend/schedule_tree/dunder_support_test.py b/tests/python_frontend/schedule_tree/dunder_support_test.py
# new file mode 100644
# index 0000000000..7947fa0679
# --- /dev/null
# +++ b/tests/python_frontend/schedule_tree/dunder_support_test.py
# @@ -0,0 +1,409 @@
# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved.
+ +import ast +import math + +import dace +import pytest +from dace.frontend.python import astutils +from dace.frontend.python.common import SDFGConvertible +from dace.frontend.python.schedule_tree import desugar_schedule_tree_expansions +from dace.sdfg.analysis.schedule_tree import treenodes as tn + + +def _desugar_statements(source: str, *, global_vars=None, known_descriptors=None): + module = ast.parse(source) + desugared = desugar_schedule_tree_expansions(module, + filename='', + global_vars=dict(global_vars or {}), + known_descriptors=known_descriptors) + return [astutils.unparse(statement) for statement in desugared.body] + + +class DunderHost: + + def __call__(self, value): + return value + + def __add__(self, value): + return value + + def __radd__(self, value): + return value + + def __rmatmul__(self, value): + return value + + def __iadd__(self, value): + return value + + def __neg__(self): + return 1 + + def __pos__(self): + return 1 + + def __invert__(self): + return 1 + + def __eq__(self, value): + return False + + def __contains__(self, value): + return False + + def __getitem__(self, index): + return self + + def __setitem__(self, index, value): + return None + + def __delitem__(self, index): + return None + + def __hash__(self): + return 1 + + def __repr__(self): + return 'host' + + def __str__(self): + return 'host' + + def __bool__(self): + return True + + def __int__(self): + return 1 + + def __float__(self): + return 1.0 + + def __bytes__(self): + return b'host' + + def __complex__(self): + return 1j + + def __format__(self, spec): + return spec + + def __len__(self): + return 1 + + def __iter__(self): + return iter(()) + + def __reversed__(self): + return iter(()) + + def __next__(self): + return 1 + + def __divmod__(self, value): + return value + + def __rdivmod__(self, value): + return value + + def __abs__(self): + return 1 + + def __round__(self, digits=None): + return 1 + + def __trunc__(self): + return 1 + + def __floor__(self): + return 
1 + + def __ceil__(self): + return 1 + + def __dir__(self): + return [] + + +class ClassSubscriptable: + + @classmethod + def __class_getitem__(cls, value): + return value + + +class MetaCheck(type): + + def __instancecheck__(cls, value): + return True + + def __subclasscheck__(cls, value): + return True + + +class Checked(metaclass=MetaCheck): + pass + + +@pytest.mark.parametrize(('source', 'expected'), [ + ('return obj(A)', 'return obj.__call__(A)'), + ('return obj + A', 'return obj.__add__(A)'), + ('return A + obj', 'return obj.__radd__(A)'), + ('return A @ obj', 'return obj.__rmatmul__(A)'), + ('return -obj', 'return obj.__neg__()'), + ('return +obj', 'return obj.__pos__()'), + ('return ~obj', 'return obj.__invert__()'), + ('return obj == A', 'return obj.__eq__(A)'), + ('return A in obj', 'return obj.__contains__(A)'), + ('return obj[i]', 'return obj.__getitem__(i)'), + ('obj[i] = A', 'obj.__setitem__(i, A)'), + ('del obj[i]', 'obj.__delitem__(i)'), + ('return hash(obj)', 'return obj.__hash__()'), + ('return repr(obj)', 'return obj.__repr__()'), + ('return str(obj)', 'return obj.__str__()'), + ('return bool(obj)', 'return obj.__bool__()'), + ('return int(obj)', 'return obj.__int__()'), + ('return float(obj)', 'return obj.__float__()'), + ('return bytes(obj)', 'return obj.__bytes__()'), + ('return complex(obj)', 'return obj.__complex__()'), + ('return format(obj, spec)', 'return obj.__format__(spec)'), + ('return len(obj)', 'return obj.__len__()'), + ('return iter(obj)', 'return obj.__iter__()'), + ('return reversed(obj)', 'return obj.__reversed__()'), + ('return next(obj)', 'return obj.__next__()'), + ('return divmod(obj, A)', 'return obj.__divmod__(A)'), + ('return divmod(A, obj)', 'return obj.__rdivmod__(A)'), + ('return abs(obj)', 'return obj.__abs__()'), + ('return round(obj)', 'return obj.__round__()'), + ('return round(obj, digits)', 'return obj.__round__(digits)'), + ('return math.trunc(obj)', 'return obj.__trunc__()'), + ('return math.floor(obj)', 
'return obj.__floor__()'), + ('return math.ceil(obj)', 'return obj.__ceil__()'), + ('return dir(obj)', 'return obj.__dir__()'), + ('return T[A]', 'return T.__class_getitem__(A)'), + ('return isinstance(x, T)', 'return T.__instancecheck__(x)'), + ('return issubclass(U, T)', 'return T.__subclasscheck__(U)'), +]) +def test_schedule_tree_dunder_desugaring_rewrites_supported_sugar(source, expected): + statements = _desugar_statements(source, + global_vars={ + 'obj': DunderHost(), + 'math': math, + 'T': ClassSubscriptable, + 'Checked': Checked, + }) + + if 'instancecheck' in expected or 'subclasscheck' in expected: + statements = _desugar_statements(source, global_vars={'T': Checked}) + + assert statements == [expected] + + +def test_schedule_tree_dunder_desugaring_leaves_parseable_free_function_call_direct(): + + def callee(value): + return value + + statements = _desugar_statements('return callee(A)', global_vars={'callee': callee}) + + assert statements == ['return callee(A)'] + + +def test_schedule_tree_dunder_desugaring_prefers_direct_operator_on_distinct_objects(): + + class LeftDirect: + + def __matmul__(self, other): + return other + + class RightReflected: + + def __rmatmul__(self, other): + return other + + statements = _desugar_statements('return lhs @ rhs', global_vars={'lhs': LeftDirect(), 'rhs': RightReflected()}) + + assert statements == ['return lhs.__matmul__(rhs)'] + + +def test_schedule_tree_dunder_desugaring_uses_reflected_operator_when_left_is_missing_direct(): + + class LeftPlain: + pass + + class RightReflected: + + def __rmatmul__(self, other): + return other + + statements = _desugar_statements('return lhs @ rhs', global_vars={'lhs': LeftPlain(), 'rhs': RightReflected()}) + + assert statements == ['return rhs.__rmatmul__(lhs)'] + + +def test_schedule_tree_dunder_desugaring_leaves_class_construction_direct(): + + class Builder: + + def __init__(self, value): + self.value = value + + statements = _desugar_statements('return Builder(A)', 
global_vars={'Builder': Builder}) + + assert statements == ['return Builder(A)'] + + +def test_schedule_tree_dunder_desugaring_rewrites_augassign_before_lowering(): + statements = _desugar_statements('obj += A', global_vars={'obj': DunderHost()}) + + assert statements == ['obj = obj.__iadd__(A)'] + + +def test_schedule_tree_dunder_desugaring_rewrites_subscript_augassign_before_lowering(): + statements = _desugar_statements('obj[i] += A', global_vars={'obj': DunderHost(), 'i': 0}) + + assert statements == ['obj.__setitem__(i, obj.__getitem__(i).__iadd__(A))'] + + +def test_python_frontend_schedule_tree_callable_object_call_is_inlined(): + + class CallableObject: + + @dace.method + def __call__(self, A: dace.float64[8]): + return A + 1 + + callable_object = CallableObject() + + @dace.program + def outer(A: dace.float64[8]): + return callable_object(A) + + stree = outer.to_schedule_tree() + + assert isinstance(stree.children[0], tn.FunctionCallScope) + assert stree.children[0].call.callee_name == '__call__' + assert stree.children[0].call.arguments == {'A': 'A'} + assert isinstance(stree.children[1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_parseable_free_function_call_is_inlined(): + + def callee(A: dace.float64[8]): + return A + 1 + + @dace.program + def outer(A: dace.float64[8]): + return callee(A) + + stree = outer.to_schedule_tree() + + assert isinstance(stree.children[0], tn.FunctionCallScope) + assert stree.children[0].call.callee_name == 'callee' + assert stree.children[0].call.arguments == {'A': 'A'} + assert isinstance(stree.children[1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_dunder_add_is_inlined(): + + class Adder: + + @dace.method + def __add__(self, A: dace.float64[8]): + return A + 1 + + adder = Adder() + + @dace.program + def outer(A: dace.float64[8]): + return adder + A + + stree = outer.to_schedule_tree() + + assert isinstance(stree.children[0], tn.FunctionCallScope) + assert stree.children[0].call.callee_name == 
'__add__' + assert stree.children[0].call.arguments == {'A': 'A'} + assert isinstance(stree.children[1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_dunder_rmatmul_is_inlined(): + + class Reflector: + + @dace.method + def __rmatmul__(self, A: dace.float64[4, 4]): + return A + 1 + + reflector = Reflector() + + @dace.program + def outer(A: dace.float64[4, 4]): + return A @ reflector + + stree = outer.to_schedule_tree() + + assert isinstance(stree.children[0], tn.FunctionCallScope) + assert stree.children[0].call.callee_name == '__rmatmul__' + assert stree.children[0].call.arguments == {'A': 'A'} + assert isinstance(stree.children[1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_sdfg_call_stays_opaque(): + + @dace.program + def inner(A: dace.float64[8], B: dace.float64[8]): + return A + B + + sdfg_obj = inner.to_sdfg() + + @dace.program + def outer(A: dace.float64[8], B: dace.float64[8]): + return sdfg_obj(A, B) + + stree = outer.to_schedule_tree() + + assert not any(isinstance(node, tn.FunctionCallScope) for node in stree.preorder_traversal()) + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + assert isinstance(stree.children[0], tn.SDFGCallNode) + assert isinstance(stree.children[0].sdfg, dace.SDFG) + assert stree.children[0].sdfg.name == sdfg_obj.name + assert stree.children[0].call.callee_name.endswith('inner') + assert stree.children[0].call.arguments == {'A': 'A', 'B': 'B'} + assert stree.children[0].return_targets == ['__stree_retval'] + assert isinstance(stree.children[1], tn.ReturnNode) + assert stree.children[1].values[0] == '__stree_retval' + + +def test_python_frontend_schedule_tree_sdfg_convertible_call_stays_opaque(): + + class Convertible(SDFGConvertible): + + def __init__(self): + self.name = 'convertible' + + def __call__(self, *args, **kwargs): + raise AssertionError('SDFGConvertible should not execute during schedule-tree generation') + + def __sdfg__(self, A, B): + + 
@dace.program + def inner(X: dace.float64[8], Y: dace.float64[8]): + return X + Y + + return inner.to_sdfg(A, B) + + def __sdfg_signature__(self): + return ['A', 'B'], [] + + convertible = Convertible() + + @dace.program + def outer(A: dace.float64[8], B: dace.float64[8]): + return convertible(A, B) + + stree = outer.to_schedule_tree() + + assert not any(isinstance(node, tn.FunctionCallScope) for node in stree.preorder_traversal()) + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + assert isinstance(stree.children[0], tn.SDFGCallNode) + assert isinstance(stree.children[1], tn.ReturnNode) diff --git a/tests/python_frontend/schedule_tree/function_call_test.py b/tests/python_frontend/schedule_tree/function_call_test.py new file mode 100644 index 0000000000..d92e4f8161 --- /dev/null +++ b/tests/python_frontend/schedule_tree/function_call_test.py @@ -0,0 +1,1008 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. +"""Tests for function-call inlining in the schedule-tree frontend.""" + +import numpy as np +import pytest +import dace +from dace.frontend.python.schedule_tree import function_inlining +from dace.sdfg.analysis.schedule_tree import treenodes as tn + + +def test_basic_inlined_call(): + """A calls B with direct array args — verify FunctionCallScope + inlined body.""" + + @dace.program + def callee(X: dace.float64[4], Y: dace.float64[4]): + return X + Y + + @dace.program + def caller(A: dace.float64[4], B: dace.float64[4]): + C = callee(A, B) + return C + + stree = caller.to_schedule_tree() + + # Find the FunctionCallScope. + call_scopes = [c for c in stree.children if isinstance(c, tn.FunctionCallScope)] + assert len(call_scopes) == 1 + + scope = call_scopes[0] + assert scope.call.callee_name == 'callee' + assert scope.call.arguments == {'X': 'A', 'Y': 'B'} + # Body should be non-empty (the callee's inlined content). 
+ assert len(scope.children) >= 1 + + +def test_call_with_return_value(): + """x = callee(A) — verify ReturnNode replaced with assignment.""" + + @dace.program + def callee(A: dace.float64[4]): + return A + 1 + + @dace.program + def caller(A: dace.float64[4]): + x = callee(A) + return x + + stree = caller.to_schedule_tree() + + call_scopes = [c for c in stree.children if isinstance(c, tn.FunctionCallScope)] + assert len(call_scopes) == 1 + scope = call_scopes[0] + + # The callee's ReturnNode should have been rewritten to an assignment. + return_nodes = [c for c in scope.children if isinstance(c, tn.ReturnNode)] + assert len(return_nodes) == 0, 'ReturnNode should be rewritten to AssignNode' + + assign_nodes = [c for c in scope.children if isinstance(c, tn.AssignNode)] + assert len(assign_nodes) >= 1 + + +def test_multiple_calls_to_same_function(): + """callee(A); callee(B) — two separate FunctionCallScope nodes.""" + + @dace.program + def callee(X: dace.float64[4]): + return X + 1 + + @dace.program + def caller(A: dace.float64[4], B: dace.float64[4]): + x = callee(A) + y = callee(B) + return x + + stree = caller.to_schedule_tree() + + call_scopes = [n for n in stree.preorder_traversal() if isinstance(n, tn.FunctionCallScope)] + assert len(call_scopes) == 2 + assert call_scopes[0].call.callee_name == 'callee' + assert call_scopes[1].call.callee_name == 'callee' + + +def test_name_collision_renaming(): + """Caller and callee both have transient '__stree_tmp' — verify renaming.""" + + @dace.program + def callee(X: dace.float64[4], Y: dace.float64[4]): + return X + Y + + @dace.program + def caller(A: dace.float64[4], B: dace.float64[4]): + # This expression materializes into __stree_tmp in the caller. 
+ C = callee(A + 1, B) + return C + + stree = caller.to_schedule_tree() + + call_scopes = [n for n in stree.preorder_traversal() if isinstance(n, tn.FunctionCallScope)] + assert len(call_scopes) == 1 + scope = call_scopes[0] + + # The callee's internal temporary must NOT collide with the + # caller's __stree_tmp (used for A+1). + caller_container_names = set(stree.containers.keys()) + assert '__stree_tmp' in caller_container_names, 'caller should have __stree_tmp for A+1' + + +def test_nested_calls_a_b_c(): + """A -> B -> C — verify bottom-up inlining produces correct structure.""" + + @dace.program + def C_func(X: dace.float64[4]): + return X + 1 + + @dace.program + def B_func(X: dace.float64[4]): + return C_func(X) + + @dace.program + def A_func(X: dace.float64[4]): + return B_func(X) + + stree = A_func.to_schedule_tree() + + # A should have a FunctionCallScope for B. + a_calls = [n for n in stree.children if isinstance(n, tn.FunctionCallScope)] + assert len(a_calls) == 1 + assert a_calls[0].call.callee_name == 'B_func' + + # B's inlined body should contain a FunctionCallScope for C. + b_calls = [n for n in a_calls[0].children if isinstance(n, tn.FunctionCallScope)] + assert len(b_calls) == 1 + assert b_calls[0].call.callee_name == 'C_func' + + # C's inlined body should contain actual computation. 
+ assert len(b_calls[0].children) >= 1 + + +def test_function_inlining_progress_tracks_unique_callees(monkeypatch): + progress_calls = [] + + def tracking_progressbar(iterable, title=None, n=None, progress=None, time_threshold=5.0): + progress_calls.append({ + 'title': title, + 'n': n, + 'completed': 0, + }) + record = progress_calls[-1] + for item in iterable: + record['completed'] += 1 + yield item + + monkeypatch.setattr(function_inlining, 'optional_progressbar', tracking_progressbar) + + @dace.program + def callee_a(X: dace.float64[4]): + return X + 1 + + @dace.program + def callee_b(X: dace.float64[4]): + return X + 2 + + @dace.program + def caller(A: dace.float64[4]): + left = callee_a(A) + right = callee_b(A) + return left + right + + stree = caller.to_schedule_tree() + + call_scopes = [node for node in stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert len(call_scopes) == 2 + assert len(progress_calls) == 1 + assert progress_calls[0]['title'] == 'Parsing nested DaCe functions' + assert progress_calls[0]['n'] == 2 + assert progress_calls[0]['completed'] == 2 + + +def test_call_with_materialized_args(): + """callee(A+1, B+2) — verify temporaries feed into FunctionCallScope.""" + + @dace.program + def callee(X: dace.float64[4], Y: dace.float64[4]): + return X + Y + + @dace.program + def caller(A: dace.float64[4], B: dace.float64[4]): + return callee(A + 1, B + 2) + + stree = caller.to_schedule_tree() + + # Arguments are materialized into map scopes before the call. + maps = [c for c in stree.children if isinstance(c, tn.MapScope)] + assert len(maps) >= 2, 'A+1 and B+2 should each produce a MapScope' + + call_scopes = [c for c in stree.children if isinstance(c, tn.FunctionCallScope)] + assert len(call_scopes) == 1 + scope = call_scopes[0] + # Arguments should reference the materialized temporaries. 
+ assert '__stree_tmp' in scope.call.arguments.values() or any( + v.startswith('__stree_tmp') for v in scope.call.arguments.values()) + + +def test_call_with_keyword_arguments(): + """callee(Y=B, X=A) — verify argument mapping handles keywords.""" + + @dace.program + def callee(X: dace.float64[4], Y: dace.float64[4]): + return X + Y + + @dace.program + def caller(A: dace.float64[4], B: dace.float64[4]): + return callee(Y=B, X=A) + + stree = caller.to_schedule_tree() + + call_scopes = [c for c in stree.children if isinstance(c, tn.FunctionCallScope)] + assert len(call_scopes) == 1 + assert call_scopes[0].call.arguments == {'X': 'A', 'Y': 'B'} + + +def test_function_call_scope_as_string(): + """Verify the as_string() representation of FunctionCallScope.""" + + @dace.program + def callee(X: dace.float64[4]): + return X + 1 + + @dace.program + def caller(A: dace.float64[4]): + return callee(A) + + stree = caller.to_schedule_tree() + text = stree.as_string() + assert 'call callee(X=A):' in text + + +def test_bare_call_statement(): + """callee(A) as a bare statement — no return targets.""" + + @dace.program + def callee(out: dace.float64[4], X: dace.float64[4]): + out[:] = X + 1 + + @dace.program + def caller(A: dace.float64[4], B: dace.float64[4]): + callee(B, A) + return B + + stree = caller.to_schedule_tree() + + call_scopes = [c for c in stree.children if isinstance(c, tn.FunctionCallScope)] + assert len(call_scopes) == 1 + scope = call_scopes[0] + assert scope.call.callee_name == 'callee' + # Bare call should have no return targets. + assert scope._return_targets is None + # Body should still be inlined. 
+ assert len(scope.children) >= 1 + + +# -------------------------------------------------------------------- # +# Descriptor inference tests # +# -------------------------------------------------------------------- # + + +def test_descriptor_inference_numpy_sum(): + """numpy.sum(A, axis=0) should produce a LibraryCall with correct output shape.""" + + @dace.program + def prog(A: dace.float64[4, 5]): + x = np.sum(A, axis=0) + return x + + stree = prog.to_schedule_tree() + + # Should produce a LibraryCall for numpy.sum, not an opaque AssignNode. + lib_calls = [n for n in stree.preorder_traversal() if isinstance(n, tn.LibraryCall)] + assert len(lib_calls) >= 1, f'Expected LibraryCall for numpy.sum, got:\n{stree.as_string()}' + sum_call = [lc for lc in lib_calls if lc.node.name == 'numpy.sum'] + assert len(sum_call) == 1, f'Expected one numpy.sum LibraryCall, got:\n{stree.as_string()}' + + # Output container should have the reduced shape (5,). + out_memlet = list(sum_call[0].out_memlets.values())[0] + out_name = out_memlet.data + assert out_name in stree.containers + desc = stree.containers[out_name] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (5, ) + assert desc.dtype == dace.float64 + + +def test_descriptor_inference_numpy_sum_full_reduction(): + """numpy.sum(A) with no axis should produce a Scalar.""" + + @dace.program + def prog(A: dace.float64[4, 5]): + x = np.sum(A) + return x + + stree = prog.to_schedule_tree() + + lib_calls = [n for n in stree.preorder_traversal() if isinstance(n, tn.LibraryCall) and n.node.name == 'numpy.sum'] + assert len(lib_calls) == 1, f'Expected one numpy.sum LibraryCall, got:\n{stree.as_string()}' + out_name = list(lib_calls[0].out_memlets.values())[0].data + assert out_name in stree.containers + desc = stree.containers[out_name] + assert isinstance(desc, dace.data.Scalar) + assert desc.dtype == dace.float64 + + +def test_descriptor_inference_numpy_where(): + """numpy.where(cond, A, 1.0) should preserve the 
runtime replacement's x/y broadcasted shape.""" + + @dace.program + def prog(cond: dace.bool_[2, 1], A: dace.float32[2, 3]): + x = np.where(cond, A, 1.0) + return x + + stree = prog.to_schedule_tree() + + lib_calls = [n for n in stree.preorder_traversal() if isinstance(n, tn.LibraryCall)] + assert len(lib_calls) >= 1, f'Expected LibraryCall for numpy.where, got:\n{stree.as_string()}' + where_call = [lc for lc in lib_calls if lc.node.name == 'numpy.where'] + assert len(where_call) == 1, f'Expected one numpy.where LibraryCall, got:\n{stree.as_string()}' + + out_memlet = list(where_call[0].out_memlets.values())[0] + out_name = out_memlet.data + assert out_name in stree.containers + desc = stree.containers[out_name] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (2, 3) + assert desc.dtype == dace.float32 + + +def test_descriptor_inference_numpy_select(): + """numpy.select should match the runtime replacement's nested where descriptor shape and dtype.""" + + @dace.program + def prog(cond: dace.bool_[2, 1], A: dace.float32[2, 3]): + x = np.select([cond], [A], default=1.0) + return x + + stree = prog.to_schedule_tree() + + lib_calls = [n for n in stree.preorder_traversal() if isinstance(n, tn.LibraryCall)] + assert len(lib_calls) >= 1, f'Expected LibraryCall for numpy.select, got:\n{stree.as_string()}' + select_call = [lc for lc in lib_calls if lc.node.name == 'numpy.select'] + assert len(select_call) == 1, f'Expected one numpy.select LibraryCall, got:\n{stree.as_string()}' + + out_memlet = list(select_call[0].out_memlets.values())[0] + out_name = out_memlet.data + assert out_name in stree.containers + desc = stree.containers[out_name] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (2, 3) + assert desc.dtype == dace.float32 + + +def test_descriptor_inference_numpy_clip(): + """numpy.clip should infer through the same ufunc-based branching as the runtime replacement.""" + + @dace.program + def prog(A: dace.float32[2, 3]): 
+ x = np.clip(A, 1.0, 3.0) + return x + + stree = prog.to_schedule_tree() + + lib_calls = [n for n in stree.preorder_traversal() if isinstance(n, tn.LibraryCall)] + assert len(lib_calls) >= 1, f'Expected LibraryCall for numpy.clip, got:\n{stree.as_string()}' + clip_call = [lc for lc in lib_calls if lc.node.name == 'numpy.clip'] + assert len(clip_call) == 1, f'Expected one numpy.clip LibraryCall, got:\n{stree.as_string()}' + + out_memlet = list(clip_call[0].out_memlets.values())[0] + out_name = out_memlet.data + assert out_name in stree.containers + desc = stree.containers[out_name] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (2, 3) + assert desc.dtype == dace.float32 + + +def test_descriptor_inference_numpy_rot90(): + """numpy.rot90 should swap the selected axes for odd k values.""" + + @dace.program + def prog(A: dace.float64[2, 3]): + x = np.rot90(A) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (3, 2) + assert desc.dtype == dace.float64 + + +def test_descriptor_inference_numpy_fft(): + """numpy.fft.fft should preserve shape and promote real inputs to complex.""" + + @dace.program + def prog(A: dace.float32[8]): + x = np.fft.fft(A) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (8, ) + assert desc.dtype == dace.complex64 + + +def test_descriptor_inference_numpy_ifft(): + """numpy.fft.ifft should preserve shape and complex dtype.""" + + @dace.program + def prog(A: dace.complex64[8]): + x = np.fft.ifft(A) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (8, ) + assert desc.dtype == dace.complex64 + + +def 
test_descriptor_inference_numpy_linalg_inv(): + """numpy.linalg.inv should preserve matrix shape and dtype.""" + + @dace.program + def prog(A: dace.float64[4, 4]): + x = np.linalg.inv(A) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (4, 4) + assert desc.dtype == dace.float64 + + +def test_descriptor_inference_numpy_linalg_solve(): + """numpy.linalg.solve should infer the shape and dtype of the right-hand side.""" + + @dace.program + def prog(A: dace.float64[4, 4], B: dace.float64[4]): + x = np.linalg.solve(A, B) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (4, ) + assert desc.dtype == dace.float64 + + +def test_descriptor_inference_numpy_linalg_cholesky(): + """numpy.linalg.cholesky should preserve matrix shape and dtype.""" + + @dace.program + def prog(A: dace.float64[4, 4]): + x = np.linalg.cholesky(A) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (4, 4) + assert desc.dtype == dace.float64 + + +def test_descriptor_inference_numpy_dot(): + """numpy.dot should follow the current frontend replacement's matrix-multiplication branch for 2D inputs.""" + + @dace.program + def prog(A: dace.float64[4, 3], B: dace.float64[3, 2]): + x = np.dot(A, B) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (4, 2) + assert desc.dtype == dace.float64 + + +def test_descriptor_inference_numpy_tensordot(): + """numpy.tensordot should infer the non-contracted output modes from the runtime replacement rules.""" + + @dace.program + def prog(A: 
dace.float64[2, 3, 4], B: dace.float64[4, 3, 5]): + x = np.tensordot(A, B, axes=([2, 1], [0, 1])) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (2, 5) + assert desc.dtype == dace.float64 + + +def test_descriptor_inference_numpy_einsum(): + """numpy.einsum should infer its output shape from the parsed output subscripts.""" + + @dace.program + def prog(A: dace.float64[4, 3], B: dace.float64[3, 2]): + x = np.einsum('ik,kj->ij', A, B) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (4, 2) + assert desc.dtype == dace.float64 + + +def test_descriptor_inference_numpy_einsum_multi_contraction(): + """numpy.einsum should preserve only the non-contracted modes for multi-dimensional contractions.""" + + A_dim, B_dim, C_dim, D_dim, E_dim = (dace.symbol(name) for name in ('A_dim', 'B_dim', 'C_dim', 'D_dim', 'E_dim')) + + @dace.program + def prog(A: dace.float64[A_dim, B_dim, C_dim, D_dim], B: dace.float64[B_dim, D_dim, C_dim, E_dim]): + x = np.einsum('abcd,bdce->ae', A, B) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (A_dim, E_dim) + assert desc.dtype == dace.float64 + + +def test_descriptor_inference_numpy_einsum_repeated_output_index(): + """numpy.einsum should allow repeated output labels like i->ii for diagonal expansion.""" + + vec_len = dace.symbol('vec_len') + + @dace.program + def prog(A: dace.float64[vec_len]): + x = np.einsum('i->ii', A) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (vec_len, vec_len) + assert 
desc.dtype == dace.float64 + + +def test_descriptor_inference_numpy_einsum_contracts_away_input(): + """numpy.einsum should handle outputs that keep labels from only one input, like j,k->k.""" + + reduced_dim, kept_dim = (dace.symbol(name) for name in ('reduced_dim', 'kept_dim')) + + @dace.program + def prog(A: dace.float64[reduced_dim], B: dace.float64[kept_dim]): + x = np.einsum('j,k->k', A, B) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (kept_dim, ) + assert desc.dtype == dace.float64 + + +def test_descriptor_inference_numpy_mean(): + """numpy.mean should promote integer input to float64.""" + + @dace.program + def prog(A: dace.int32[10]): + x = np.mean(A) + return x + + stree = prog.to_schedule_tree() + + lib_calls = [n for n in stree.preorder_traversal() if isinstance(n, tn.LibraryCall) and n.node.name == 'numpy.mean'] + assert len(lib_calls) == 1, f'Expected numpy.mean LibraryCall, got:\n{stree.as_string()}' + out_name = list(lib_calls[0].out_memlets.values())[0].data + desc = stree.containers[out_name] + assert isinstance(desc, dace.data.Scalar) + assert desc.dtype == dace.float64 + + +def test_descriptor_inference_numpy_reshape(): + """numpy.reshape should produce array with the new shape.""" + + @dace.program + def prog(A: dace.float64[3, 4]): + x = np.reshape(A, (12, )) + return x + + stree = prog.to_schedule_tree() + + # numpy.reshape may be lowered as a TaskletNode or LibraryCall — either is fine. + # The important thing is the output descriptor has the correct shape. 
+ assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (12, ) + + +def test_descriptor_inference_numpy_transpose(): + """numpy.transpose should reverse axes by default.""" + + @dace.program + def prog(A: dace.float64[3, 5]): + x = np.transpose(A) + return x + + stree = prog.to_schedule_tree() + + # The important thing is the output descriptor has the reversed shape. + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (5, 3) + + +def test_descriptor_inference_numpy_vstack(): + + @dace.program + def prog(A: dace.float64[2, 3], B: dace.float64[2, 3]): + x = np.vstack((A, B)) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (4, 3) + + +def test_descriptor_inference_numpy_split_structured_result(): + + @dace.program + def prog(A: dace.float64[6]): + left, right = np.split(A, 2) + return left + + stree = prog.to_schedule_tree() + + assert 'left' in stree.containers + left_desc = stree.containers['left'] + assert isinstance(left_desc, dace.data.Array) + assert tuple(left_desc.shape) == (3, ) + + assert 'right' in stree.containers + right_desc = stree.containers['right'] + assert isinstance(right_desc, dace.data.Array) + assert tuple(right_desc.shape) == (3, ) + + def test_attribute_inference_size_scalar(): + + @dace.program + def prog(a: dace.float64[3, 5]): + x = a.size + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Scalar) + assert desc.dtype == dace.int64 + + +def test_descriptor_inference_len_is_scalar(): + + @dace.program + def prog(A: dace.float64[4, 5]): + n = len(A) + return n + + stree = prog.to_schedule_tree() + + assert 'n' in stree.containers + desc = 
stree.containers['n'] + assert isinstance(desc, dace.data.Scalar) + assert desc.dtype == dace.int64 + + +def test_descriptor_inference_linspace_retstep_structured_result(): + + @dace.program + def prog(): + space, step = np.linspace(2.5, 10.0, num=3, retstep=True) + return space + + stree = prog.to_schedule_tree() + + assert not any(isinstance(node, tn.StatementNode) for node in stree.preorder_traversal()) + lib_calls = [node for node in stree.preorder_traversal() if isinstance(node, tn.LibraryCall)] + assert len(lib_calls) == 1 + assert lib_calls[0].node.name == 'numpy.linspace' + assert set(lib_calls[0].out_memlets) == {'out0', 'out1'} + assert lib_calls[0].out_memlets['out0'].data == 'space' + assert lib_calls[0].out_memlets['out1'].data == 'step' + + assert 'space' in stree.containers + space_desc = stree.containers['space'] + assert isinstance(space_desc, dace.data.Array) + assert tuple(space_desc.shape) == (3, ) + assert space_desc.dtype == dace.float64 + + assert 'step' in stree.containers + step_desc = stree.containers['step'] + assert isinstance(step_desc, dace.data.Scalar) + assert step_desc.dtype == dace.float64 + + +# -------------------------------------------------------------------- # +# Method descriptor inference tests # +# -------------------------------------------------------------------- # + + +def test_method_inference_sum_scalar(): + """a.sum() should produce a Scalar descriptor via method inference.""" + + @dace.program + def prog(a: dace.float64[8]): + return a.sum() + + stree = prog.to_schedule_tree() + + lib_calls = [n for n in stree.preorder_traversal() if isinstance(n, tn.LibraryCall)] + assert len(lib_calls) >= 1, f'Expected LibraryCall for a.sum(), got:\n{stree.as_string()}' + + +def test_method_inference_sum_with_axis(): + """a.sum(axis=0) should produce a reduced Array descriptor.""" + + @dace.program + def prog(a: dace.float64[3, 4]): + x = a.sum(axis=0) + return x + + stree = prog.to_schedule_tree() + + lib_calls = [n for n in 
stree.preorder_traversal() if isinstance(n, tn.LibraryCall)] + assert len(lib_calls) >= 1, f'Expected LibraryCall for a.sum(axis=0), got:\n{stree.as_string()}' + + # Check the output container has the correct reduced shape. + out_name = list(lib_calls[0].out_memlets.values())[0].data + assert out_name in stree.containers + desc = stree.containers[out_name] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (4, ) + + +def test_method_inference_reshape(): + """a.reshape((12,)) should propagate the new shape.""" + + @dace.program + def prog(a: dace.float64[3, 4]): + x = a.reshape((12, )) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (12, ) + + +# -------------------------------------------------------------------- # +# Attribute descriptor inference tests # +# -------------------------------------------------------------------- # + + +def test_attribute_inference_T(): + """a.T should produce an Array with reversed shape.""" + + @dace.program + def prog(a: dace.float64[3, 5]): + x = a.T + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (5, 3) + + +# -------------------------------------------------------------------- # +# Operator descriptor inference tests # +# -------------------------------------------------------------------- # + + +def test_operator_inference_matmul(): + """A @ B should use the operator descriptor registry.""" + + @dace.program + def prog(A: dace.float64[4, 3], B: dace.float64[3, 2]): + return A @ B + + stree = prog.to_schedule_tree() + + lib_calls = [n for n in stree.preorder_traversal() if isinstance(n, tn.LibraryCall) and n.node.name == 'MatMul'] + assert len(lib_calls) >= 1, f'Expected MatMul LibraryCall, got:\n{stree.as_string()}' + out_name = 
list(lib_calls[0].out_memlets.values())[0].data + desc = stree.containers[out_name] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (4, 2) + + +def test_operator_inference_add_broadcast(): + """A + B should infer the broadcasted output descriptor.""" + + @dace.program + def prog(A: dace.float64[4, 1], B: dace.float64[1, 3]): + x = A + B + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (4, 3) + assert desc.dtype == dace.float64 + + +def test_operator_inference_compare_bool_array(): + """A < 0 should infer a boolean output array.""" + + @dace.program + def prog(A: dace.float64[4]): + x = A < 0.0 + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (4, ) + assert desc.dtype == dace.bool_ + + +def test_operator_inference_unary_negate_array(): + """-A should preserve the array shape and dtype class.""" + + @dace.program + def prog(A: dace.float64[4]): + x = -A + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (4, ) + assert desc.dtype == dace.float64 + + +def test_operator_inference_boolop_scalar_and(): + """Scalar boolean `and` should infer a boolean scalar result.""" + + @dace.program + def prog(a: dace.bool_, b: dace.bool_): + x = a and b + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Scalar) + assert desc.dtype == dace.bool_ + + +# -------------------------------------------------------------------- # +# Nested inference test # +# -------------------------------------------------------------------- # + + +def 
test_nested_inference_sum_of_matmul(): + """np.sum(A @ B) should chain MatMul + numpy.sum LibraryCalls.""" + + @dace.program + def prog(A: dace.float64[4, 3], B: dace.float64[3, 2]): + return np.sum(A @ B) + + stree = prog.to_schedule_tree() + + matmul_calls = [n for n in stree.preorder_traversal() if isinstance(n, tn.LibraryCall) and n.node.name == 'MatMul'] + sum_calls = [n for n in stree.preorder_traversal() if isinstance(n, tn.LibraryCall) and n.node.name == 'numpy.sum'] + + assert len(matmul_calls) >= 1, f'Expected MatMul LibraryCall, got:\n{stree.as_string()}' + assert len(sum_calls) >= 1, f'Expected numpy.sum LibraryCall, got:\n{stree.as_string()}' + + # numpy.sum result should be a Scalar. + sum_out_name = list(sum_calls[0].out_memlets.values())[0].data + desc = stree.containers[sum_out_name] + assert isinstance(desc, dace.data.Scalar) + + +def test_descriptor_inference_custom_arraylike(): + """np.multiply should infer descriptors for operands that implement __array__.""" + + class CustomArrayLike: + + def __array__(self, dtype=None): + return np.eye(2, 5, dtype=dtype if dtype is not None else np.float64) + + custom = CustomArrayLike() + + @dace.program + def prog(): + x = np.multiply(custom, 2) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (2, 5) + assert desc.dtype == dace.float64 + + +def test_descriptor_inference_numpy_asarray_custom_arraylike(): + """np.asarray should preserve shape and dtype for custom __array__ objects.""" + + class CustomArrayLike: + + def __array__(self, dtype=None): + return np.eye(2, 5, dtype=dtype if dtype is not None else np.float64) + + custom = CustomArrayLike() + + @dace.program + def prog(): + x = np.asarray(custom) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert 
tuple(desc.shape) == (2, 5) + assert desc.dtype == dace.float64 + + +def test_descriptor_inference_custom_array_interface(): + """Objects with __array_interface__ should infer directly as array inputs.""" + + class CustomArrayInterfaceLike: + + def __init__(self): + self._array = np.zeros((2, 5), dtype=np.float64) + + @property + def __array_interface__(self): + return self._array.__array_interface__ + + custom = CustomArrayInterfaceLike() + + @dace.program + def prog(): + x = np.transpose(custom) + return x + + stree = prog.to_schedule_tree() + + assert 'x' in stree.containers + desc = stree.containers['x'] + assert isinstance(desc, dace.data.Array) + assert tuple(desc.shape) == (5, 2) + assert desc.dtype == dace.float64 + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/python_frontend/schedule_tree/lambda_devirtualization_test.py b/tests/python_frontend/schedule_tree/lambda_devirtualization_test.py new file mode 100644 index 0000000000..d9d340bff7 --- /dev/null +++ b/tests/python_frontend/schedule_tree/lambda_devirtualization_test.py @@ -0,0 +1,118 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import pytest + +import dace +from dace.sdfg.analysis.schedule_tree import treenodes as tn + + +def test_local_lambda_array_call_devirtualizes(): + + @dace.program + def prog(A: dace.float64[4], B: dace.float64[4]): + f = lambda a, b: a + b + return f(A, B) + + stree = prog.to_schedule_tree() + + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + assert not any(isinstance(node, tn.FunctionCallScope) for node in stree.preorder_traversal()) + assert any(isinstance(node, tn.MapScope) for node in stree.children) + + +def test_global_lambda_array_call_devirtualizes(): + f = lambda a, b: a + b + + @dace.program + def prog(A: dace.float64[4], B: dace.float64[4]): + return f(A, B) + + stree = prog.to_schedule_tree() + + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + assert not any(isinstance(node, tn.FunctionCallScope) for node in stree.preorder_traversal()) + assert any(isinstance(node, tn.MapScope) for node in stree.children) + + +def test_lambda_capture_devirtualizes(): + offset = 3.0 + + @dace.program + def prog(A: dace.float64[4]): + f = lambda a: a + offset + return f(A) + + stree = prog.to_schedule_tree() + + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + assert any(isinstance(node, tn.MapScope) for node in stree.children) + + +def test_lambda_body_call_to_dace_program_becomes_call_scope(): + + @dace.program + def callee(A: dace.float64[4], B: dace.float64[4]): + return A + B + + f = lambda a, b: callee(a, b) + + @dace.program + def prog(A: dace.float64[4], B: dace.float64[4]): + return f(A, B) + + stree = prog.to_schedule_tree() + + calls = [node for node in stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert len(calls) == 1 + assert calls[0].call.callee_name == 'callee' + + +def test_lambda_argument_to_nested_program_devirtualizes(): + + @dace.program + def inner(A: dace.float64[4], B: 
dace.float64[4], f): + return f(A, B) + + @dace.program + def outer(A: dace.float64[4], B: dace.float64[4]): + f = lambda a, b: a + b + g = f + return inner(A, B, g) + + stree = outer.to_schedule_tree() + + assert 'f' in stree.containers + assert 'g' in stree.containers + assert isinstance(stree.containers['f'].dtype, dace.dtypes.callback) + assert isinstance(stree.containers['g'].dtype, dace.dtypes.callback) + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + calls = [node for node in stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert len(calls) == 1 + assert any(isinstance(node, tn.MapScope) for node in calls[0].preorder_traversal()) + + +def test_external_lambda_argument_to_nested_program_stays_callback_typed(): + external = eval('lambda a, b: a + b') + + @dace.program + def inner(A: dace.float64, B: dace.float64, f): + return f(A, B) + + @dace.program + def outer(A: dace.float64, B: dace.float64, external: dace.callback(dace.float64, dace.float64, dace.float64)): + f = external + return inner(A, B, f) + + stree = outer.to_schedule_tree() + assert 'f' in stree.containers + assert isinstance(stree.containers['f'], dace.data.Scalar) + assert isinstance(stree.containers['f'].dtype, dace.dtypes.callback) + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + calls = [node for node in stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert len(calls) == 1 + assert 'tasklet(f[0], A[0], B[0])' in calls[0].as_string() + assert 'assign __stree_retval = __stree_tmp' in calls[0].as_string() + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/python_frontend/schedule_tree/lambda_support_test.py b/tests/python_frontend/schedule_tree/lambda_support_test.py new file mode 100644 index 0000000000..9ee52e8062 --- /dev/null +++ b/tests/python_frontend/schedule_tree/lambda_support_test.py @@ -0,0 +1,46 @@ +# Copyright 
2019-2026 ETH Zurich and the DaCe authors. All rights reserved. + +import ast + +from dace.frontend.python import astutils +from dace.frontend.python.schedule_tree.lambda_support import LambdaResolver + + +def test_lambda_resolver_inlines_named_lambda_calls(): + lambda_bindings = {'f': ast.parse('lambda a, b: a + b', mode='eval').body} + resolver = LambdaResolver({}, lambda_bindings, {}) + + rewritten = resolver.inline_known_lambda_calls(ast.parse('f(A, B)', mode='eval').body) + + assert isinstance(rewritten, ast.BinOp) + assert astutils.unparse(rewritten.left) == 'A' + assert astutils.unparse(rewritten.right) == 'B' + + +def test_lambda_resolver_recovers_global_lambda_with_capture(): + offset = 3.0 + f = lambda a: a + offset + resolver = LambdaResolver({'f': f}, {}, {}) + + lambda_node = resolver.resolve_known_lambda_node(ast.Name(id='f', ctx=ast.Load())) + + assert lambda_node is not None + assert isinstance(lambda_node.body, ast.BinOp) + assert astutils.unparse(lambda_node.body.left) == 'a' + assert astutils.unparse(lambda_node.body.right) == '3.0' + + +def test_lambda_resolver_exposes_callable_capture_through_globals(): + + def callee(a, b): + return a + b + + f = lambda a, b: callee(a, b) + resolver = LambdaResolver({'f': f}, {}, {}) + + lambda_node = resolver.resolve_known_lambda_node(ast.Name(id='f', ctx=ast.Load())) + + assert lambda_node is not None + assert isinstance(lambda_node.body, ast.Call) + assert astutils.unparse(lambda_node.body.func) == 'callee' + assert resolver.globals['callee'] is callee diff --git a/tests/python_frontend/schedule_tree/nested_function_test.py b/tests/python_frontend/schedule_tree/nested_function_test.py new file mode 100644 index 0000000000..8390eab7cf --- /dev/null +++ b/tests/python_frontend/schedule_tree/nested_function_test.py @@ -0,0 +1,450 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import ast +import inspect + +import dace +import numpy as np +import pytest +from typing import Optional, Dict, Any, Tuple, Sequence + +from dace.frontend.python.common import DaceSyntaxError +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.frontend.python.common import ScheduleTreeConvertible + +__schedule_tree_callback_scale = 5 + + +def test_simple_nested_function_becomes_call_scope(): + + @dace.program + def prog(A: dace.float64[4], B: dace.float64[4]): + + def helper(x, y): + return x + y + + return helper(A, B) + + stree = prog.to_schedule_tree() + + assert 'helper' in stree.containers + assert isinstance(stree.containers['helper'].dtype, dace.dtypes.callback) + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + calls = [node for node in stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert len(calls) == 1 + assert calls[0].call.callee_name == 'helper' + assert any(isinstance(node, tn.MapScope) for node in calls[0].preorder_traversal()) + + +def test_nested_function_capture_becomes_call_scope(): + offset = 2.0 + + @dace.program + def prog(A: dace.float64[4]): + + def helper(x): + return x + offset + + return helper(A) + + stree = prog.to_schedule_tree() + + assert 'helper' in stree.containers + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + calls = [node for node in stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert len(calls) == 1 + assert calls[0].call.callee_name == 'helper' + assert any(isinstance(node, tn.MapScope) for node in calls[0].preorder_traversal()) + + +def test_nested_function_raise_is_inlined(): + + @dace.program + def prog(A: dace.float64[4]): + + def helper(x): + raise ValueError(x[0]) + + helper(A) + + stree = prog.to_schedule_tree() + + raise_nodes = [node for node in stree.preorder_traversal() if isinstance(node, tn.RaiseNode)] + assert len(raise_nodes) == 1 + 
assert len(raise_nodes[0].args) == 1 + assert raise_nodes[0].args[0].as_string != 'x[0]' + assert raise_nodes[0].args[0].as_string.endswith('[0]') + + +def test_nested_function_body_call_to_dace_program_becomes_call_scope(): + + @dace.program + def callee(A: dace.float64[4], B: dace.float64[4]): + return A + B + + @dace.program + def prog(A: dace.float64[4], B: dace.float64[4]): + + def helper(x, y): + return callee(x, y) + + return helper(A, B) + + stree = prog.to_schedule_tree() + + calls = [node for node in stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert len(calls) == 2 + assert [call.call.callee_name for call in calls] == ['helper', 'callee'] + + +def test_nested_function_argument_to_nested_program_becomes_call_regions(): + + @dace.program + def inner(A: dace.float64[4], B: dace.float64[4], f): + return f(A, B) + + @dace.program + def outer(A: dace.float64[4], B: dace.float64[4]): + + def helper(x, y): + return x + y + + alias = helper + return inner(A, B, alias) + + stree = outer.to_schedule_tree() + + assert 'helper' in stree.containers + assert 'alias' in stree.containers + assert isinstance(stree.containers['helper'].dtype, dace.dtypes.callback) + assert isinstance(stree.containers['alias'].dtype, dace.dtypes.callback) + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + calls = [node for node in stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert len(calls) == 2 + assert [call.call.callee_name for call in calls] == ['inner', 'helper'] + assert any(isinstance(node, tn.MapScope) for node in calls[1].preorder_traversal()) + + +def test_multistatement_nested_function_becomes_call_scope(): + + @dace.program + def prog(A: dace.float64[4]): + + def helper(x): + y = x + 1 + return y + + return helper(A) + + stree = prog.to_schedule_tree() + + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + calls = [node for node 
in stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert len(calls) == 1 + assert calls[0].call.callee_name == 'helper' + + +def test_nested_function_nonlocal_rebinds_outer_reference(): + + @dace.program + def prog(A: dace.float64[4], B: dace.float64[4]): + tmp = A + + def helper(): + nonlocal tmp + tmp = B + + helper() + return tmp + + stree = prog.to_schedule_tree() + + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + calls = [node for node in stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert len(calls) == 1 + refsets = [node for node in calls[0].preorder_traversal() if isinstance(node, tn.RefSetNode)] + assert len(refsets) == 1 + assert refsets[0].target == 'tmp' + + +def test_nested_function_nonlocal_external_capture_is_added_to_closure(): + + def make_prog(): + captured = np.ones(4, dtype=np.float64) + + @dace.program + def prog(A: dace.float64[4]): + + def helper(): + nonlocal captured + return captured + A + + return helper() + + return prog + + stree = make_prog().to_schedule_tree() + + assert 'captured' in stree.containers + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + calls = [node for node in stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert len(calls) == 1 + assert any(isinstance(node, tn.MapScope) for node in calls[0].preorder_traversal()) + + +def test_nested_function_global_capture_is_added_to_closure(): + globals()['__schedule_tree_nested_global_capture'] = np.ones(4, dtype=np.float64) + + try: + + @dace.program + def prog(A: dace.float64[4]): + + def helper(): + global __schedule_tree_nested_global_capture + return __schedule_tree_nested_global_capture + A + + return helper() + + stree = prog.to_schedule_tree() + + finally: + del globals()['__schedule_tree_nested_global_capture'] + + assert '__schedule_tree_nested_global_capture' in stree.containers + assert not 
any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + calls = [node for node in stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert len(calls) == 1 + assert any(isinstance(node, tn.MapScope) for node in calls[0].preorder_traversal()) + + +def test_nested_function_with_nonlocal_callback_fallback_is_rejected(): + + def passthrough(fn): + return fn + + @dace.program + def prog(A: dace.float64[4], B: dace.float64[4]): + tmp = A + + @passthrough + def helper(): + nonlocal tmp + tmp = B + + helper() + return tmp + + with pytest.raises(DaceSyntaxError, match='nonlocal'): + prog.to_schedule_tree() + + +def test_nested_function_with_global_callback_fallback_is_allowed(): + globals()['__schedule_tree_callback_global'] = 2.0 + + def passthrough(fn): + return fn + + try: + + @dace.program + def prog(A: dace.float64[4]): + + @passthrough + def helper(x): + global __schedule_tree_callback_global + return x + __schedule_tree_callback_global + + return helper(A) + + stree = prog.to_schedule_tree() + + finally: + del globals()['__schedule_tree_callback_global'] + + callbacks = [node for node in stree.preorder_traversal() if isinstance(node, tn.PythonCallbackNode)] + assert len(callbacks) == 2 + assert sorted(callback.reason for callback in callbacks) == ['nested function', 'pyobject call'] + + +def test_nested_function_with_global_callback_fallback_is_rejected_if_enclosing_program_uses_global(): + globals()['__schedule_tree_callback_global'] = 5 + + def passthrough(fn): + return fn + + try: + + @dace.program + def prog(A: dace.float64[4]): + + @passthrough + def helper(x): + global __schedule_tree_callback_global + __schedule_tree_callback_global = 6 + + helper(A) + return __schedule_tree_callback_global - 1 + + with pytest.raises(DaceSyntaxError, match='used in the enclosing program'): + prog.to_schedule_tree() + + finally: + del globals()['__schedule_tree_callback_global'] + + +def 
test_decorated_nested_function_stays_callback(): + + def passthrough(fn): + return fn + + @dace.program + def prog(A: dace.float64[4]): + + @passthrough + def helper(x): + return x + 1 + + return helper(A) + + stree = prog.to_schedule_tree() + + callbacks = [node for node in stree.preorder_traversal() if isinstance(node, tn.PythonCallbackNode)] + assert len(callbacks) == 2 + assert sorted(callback.reason for callback in callbacks) == ['nested function', 'pyobject call'] + + +def test_decorated_nested_function_callback_outliner_recovers_callable_handle(): + + def passthrough(fn): + return fn + + @dace.program + def prog(): + offset = 7 + + @passthrough + def helper(x, /, y=2, *, twist=1): + return ((x + offset) * (y + twist)) + __schedule_tree_callback_scale + + return 0 + + stree = prog.to_schedule_tree() + + callbacks = [node for node in stree.preorder_traversal() if isinstance(node, tn.PythonCallbackNode)] + assert len(callbacks) == 1 + callback = callbacks[0] + assert callback.reason == 'nested function' + assert callback.input_names == ['offset'] + assert callback.output_names == ['helper'] + assert callback.outlined_function_name is not None + assert callback.outlined_function_code is not None + assert callback.outlined_call_code is not None + + namespace = { + 'passthrough': passthrough, + '__schedule_tree_callback_scale': __schedule_tree_callback_scale, + } + + outlined_function_module = ast.fix_missing_locations( + ast.Module(body=list(callback.outlined_function_code.code), type_ignores=[])) + exec(compile(outlined_function_module, '', 'exec'), namespace, namespace) + + callback_factory = namespace[callback.outlined_function_name] + assert callable(callback_factory) + + recovered_direct = callback_factory(7) + assert callable(recovered_direct) + assert str(inspect.signature(recovered_direct)) == '(x, /, y=2, *, twist=1)' + assert recovered_direct(3, y=4, twist=2) == 65 + + namespace['offset'] = 7 + + outlined_call_module = ast.fix_missing_locations( + 
ast.Module(body=list(callback.outlined_call_code.code), type_ignores=[])) + exec(compile(outlined_call_module, '', 'exec'), namespace, namespace) + + recovered_from_callsite = namespace['helper'] + assert callable(recovered_from_callsite) + assert recovered_from_callsite(4) == 38 + assert recovered_from_callsite(3, y=4, twist=2) == 65 + + +def test_regression_convertibles(): + + class B(ScheduleTreeConvertible): + + def __init__(self) -> None: + self._dace_program = dace.method(self.__call__).__get__(self) + + def __schedule_tree__(self, + *args, + lambda_bindings: Optional[Dict[str, ast.AST]] = None, + callable_bindings: Optional[Dict[str, Any]] = None, + **kwargs) -> tn.ScheduleTreeRoot: + + return self._dace_program.to_schedule_tree(*args, **kwargs) + + def __schedule_tree_signature__(self) -> Tuple[Sequence[str], Sequence[str]]: + return (["input_0", "input_1"], []) + + def __call__(self, input_0, input_1): + input_0[:] += input_1[:] + + class A(ScheduleTreeConvertible): + + def __init__(self): + self._b = B() + self._tmp_input = np.zeros((10, )) + self._dace_program = dace.method(self.__call__).__get__(self) + + def __schedule_tree__(self, + *args, + lambda_bindings: Optional[Dict[str, ast.AST]] = None, + callable_bindings: Optional[Dict[str, Any]] = None, + **kwargs) -> tn.ScheduleTreeRoot: + + return self._dace_program.to_schedule_tree(*args, **kwargs) + + def __schedule_tree_signature__(self) -> Tuple[Sequence[str], Sequence[str]]: + + return (["input_0", "input_optional"], []) + + def __call__(self, input_0, input_optional=None): + + if input_optional is None: + self._b(input_0, self._tmp_input) + else: + self._b(input_0, input_optional) + + arr_0 = np.ones((10, )) + arr_1 = np.ones((10, )) + + a = A() + stree = a._dace_program.to_schedule_tree(arr_0) + rendered = stree.as_string() + + assert 'object at' not in rendered + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + calls = [node for node in 
stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert len(calls) == 1 + assert calls[0].call.callee_name == 'B' + assert calls[0].call.arguments == {'input_0': 'input_0', 'input_1': '__g_self__tmp_input'} + assert any(isinstance(node, tn.MapScope) for node in calls[0].preorder_traversal()) + + stree = a._dace_program.to_schedule_tree(arr_0, arr_1) + rendered = stree.as_string() + + assert 'object at' not in rendered + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + calls = [node for node in stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert len(calls) == 2 + assert calls[0].call.callee_name == 'B' + assert calls[0].call.arguments == {'input_0': 'input_0', 'input_1': '__g_self__tmp_input'} + assert calls[1].call.callee_name == 'B' + assert calls[1].call.arguments == {'input_0': 'input_0', 'input_1': 'input_optional'} + assert any(isinstance(node, tn.MapScope) for node in calls[0].preorder_traversal()) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/python_frontend/schedule_tree/numpy_frontend_test.py b/tests/python_frontend/schedule_tree/numpy_frontend_test.py new file mode 100644 index 0000000000..d93b393c8e --- /dev/null +++ b/tests/python_frontend/schedule_tree/numpy_frontend_test.py @@ -0,0 +1,826 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import numpy as np +import pytest + +import dace +from dace.sdfg.analysis.schedule_tree import treenodes as tn + +mynp = np + + +def test_python_frontend_schedule_tree_numpy_elementwise_assignment_and_update(): + + @dace.program + def computed(A: dace.float64[8], B: dace.float64[8], out: dace.float64[8]): + out[:] = A[:] + B[:] + out[:] += A[:] + + stree = computed.to_schedule_tree() + + assert isinstance(stree.children[0], tn.MapScope) + assert isinstance(stree.children[0].node, tn.FrontendMap) + assert isinstance(stree.children[0].children[0], tn.TaskletNode) + assert isinstance(stree.children[0].children[0].node, tn.FrontendTasklet) + assert len(stree.children[0].children[0].in_memlets) == 2 + assert isinstance(stree.children[1], tn.MapScope) + assert isinstance(stree.children[1].node, tn.FrontendMap) + assert isinstance(stree.children[1].children[0], tn.TaskletNode) + assert isinstance(stree.children[1].children[0].node, tn.FrontendTasklet) + assert len(stree.children[1].children[0].in_memlets) == 2 + + +def test_python_frontend_schedule_tree_numpy_broadcast_map(): + + @dace.program + def computed(A: dace.float64[2, 3], B: dace.float64[3], out: dace.float64[2, 3]): + out[:] = A[:] + B + + stree = computed.to_schedule_tree() + + assert isinstance(stree.children[0], tn.MapScope) + assert isinstance(stree.children[0].node, tn.FrontendMap) + assert len(stree.children[0].node.params) == 2 + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert str(tasklet.in_memlets['in1'].subset) == '__i1' + + +def test_python_frontend_schedule_tree_numpy_broadcast_column_assignment(): + + @dace.program + def computed(A: dace.float64[5, 3], B: dace.float64[5, 1], out: dace.float64[5, 3]): + out[:] = A - B + + stree = computed.to_schedule_tree() + + assert len(stree.children) == 1 + assert isinstance(stree.children[0], tn.MapScope) + assert stree.children[0].node.params == ['__i0', '__i1'] + tasklet = stree.children[0].children[0] + assert 
isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = (in0 - in1)' + assert str(tasklet.in_memlets['in0'].subset) == '__i0, __i1' + assert str(tasklet.in_memlets['in1'].subset) == '__i0, 0' + assert str(tasklet.out_memlets['out'].subset) == '__i0, __i1' + + +def test_python_frontend_schedule_tree_numpy_broadcast_prepended_dimension_assignment(): + + @dace.program + def computed(A: dace.float64[5, 3], B: dace.float64[2, 5, 1], out: dace.float64[2, 5, 3]): + out[:] = A - B + + stree = computed.to_schedule_tree() + + assert len(stree.children) == 1 + assert isinstance(stree.children[0], tn.MapScope) + assert stree.children[0].node.params == ['__i0', '__i1', '__i2'] + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = (in0 - in1)' + assert str(tasklet.in_memlets['in0'].subset) == '__i1, __i2' + assert str(tasklet.in_memlets['in1'].subset) == '__i0, __i1, 0' + assert str(tasklet.out_memlets['out'].subset) == '__i0, __i1, __i2' + + +def test_python_frontend_schedule_tree_numpy_broadcast_both_axes_assignment(): + + @dace.program + def computed(A: dace.float64[5, 1], B: dace.float64[1, 3], out: dace.float64[5, 3]): + out[:] = A + B + + stree = computed.to_schedule_tree() + + assert len(stree.children) == 1 + assert isinstance(stree.children[0], tn.MapScope) + assert stree.children[0].node.params == ['__i0', '__i1'] + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = (in0 + in1)' + assert str(tasklet.in_memlets['in0'].subset) == '__i0, 0' + assert str(tasklet.in_memlets['in1'].subset) == '0, __i1' + assert str(tasklet.out_memlets['out'].subset) == '__i0, __i1' + + +def test_python_frontend_schedule_tree_numpy_map_inside_loop_with_scalar_index(): + + @dace.program + def computed(A: dace.float64[8], out: dace.float64[8]): + ref1: dace.data.ArrayReference(A.dtype, A.shape) = A + ref2: 
dace.data.ArrayReference(out.dtype, out.shape) = out + for i in range(8): + ref2[:] = ref1[:] + i * ref1[2] + + stree = computed.to_schedule_tree() + + assert isinstance(stree.children[0], tn.RefSetNode) + assert isinstance(stree.children[1], tn.RefSetNode) + assert isinstance(stree.children[2], tn.LoopScope) + assert isinstance(stree.children[2].children[0], tn.MapScope) + tasklet = stree.children[2].children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = (in0 + (i * in1))' + + +def test_python_frontend_schedule_tree_advanced_indexing_is_copy_not_view(): + + @dace.program + def advanced_prog(A: dace.float64[8], ind: dace.int32[4]): + basic = A[1:5] + advanced = A[ind] + return advanced + + stree = advanced_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.ViewNode) + assert isinstance(stree.children[1], tn.MapScope) + assert isinstance(stree.children[1].children[0], tn.TaskletNode) + assert stree.children[1].children[0].node.code.as_string == 'out = in0[idx0_0]' + assert isinstance(stree.children[2], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_advanced_indexing_expression_map(): + + @dace.program + def advanced_prog(A: dace.float64[8], ind: dace.int32[4], B: dace.float64[4], out: dace.float64[4]): + out[:] = A[ind] + B[:] + + stree = advanced_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.MapScope) + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert len(tasklet.in_memlets) == 3 + assert tasklet.node.code.as_string == 'out = (in0[idx0_0] + in1)' + + +def test_python_frontend_schedule_tree_multidim_advanced_indexing_expression_map(): + + @dace.program + def advanced_prog(A: dace.float64[6, 6], I: dace.int32[4], J: dace.int32[4], out: dace.float64[4]): + out[:] = A[I, J] + + stree = advanced_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.MapScope) + tasklet = stree.children[0].children[0] + 
assert isinstance(tasklet, tn.TaskletNode) + assert len(tasklet.in_memlets) == 3 + assert tasklet.node.code.as_string == 'out = in0[idx0_0, idx0_1]' + + +def test_python_frontend_schedule_tree_advanced_indexing_target_assign(): + + @dace.program + def advanced_prog(A: dace.float64[8], ind: dace.int32[4]): + A[ind] = 2 + + stree = advanced_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.MapScope) + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert len(tasklet.in_memlets) == 1 + assert tasklet.node.code.as_string == 'out[outidx_0] = 2' + + +def test_python_frontend_schedule_tree_advanced_indexing_target_augassign(): + + @dace.program + def advanced_prog(A: dace.float64[8], ind: dace.int32[4], B: dace.float64[4]): + A[ind] += B + + stree = advanced_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.MapScope) + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert len(tasklet.in_memlets) == 3 + assert tasklet.node.code.as_string == 'out[outidx_0] = (cur + in0)' + assert str(tasklet.in_memlets['outidx_0'].subset) == '__i0' + assert str(tasklet.in_memlets['in0'].subset) == '__i0' + assert str(tasklet.in_memlets['cur'].subset) == '0:8' + + +def test_python_frontend_schedule_tree_advanced_indexing_target_mixed_range(): + + @dace.program + def advanced_prog(A: dace.float64[20, 20, 20], ind: dace.int32[4]): + A[1:2, ind, 3:10] = 2 + + stree = advanced_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.MapScope) + assert len(stree.children[0].node.params) == 2 + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out[outidx_0] = 2' + + +def test_python_frontend_schedule_tree_boolean_mask_target_augassign(): + + @dace.program + def advanced_prog(A: dace.float64[20, 30], barr: dace.bool_[20, 30]): + A[barr] += 5 + + stree = advanced_prog.to_schedule_tree() + + assert 
isinstance(stree.children[0], tn.MapScope) + assert len(stree.children[0].node.params) == 2 + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert 'if mask:' in tasklet.node.code.as_string + assert 'out = (cur + 5)' in tasklet.node.code.as_string + + +def test_python_frontend_schedule_tree_boolean_mask_target_inline_assign(): + + @dace.program + def advanced_prog(A: dace.float64[20, 30]): + A[A > 15] = 2 + + stree = advanced_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.MapScope) + assert len(stree.children[0].node.params) == 2 + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert 'if (in100 > 15):' in tasklet.node.code.as_string + assert 'out = 2' in tasklet.node.code.as_string + + +def test_python_frontend_schedule_tree_boolean_mask_read_named_library_call(): + + @dace.program + def advanced_prog(A: dace.float64[20, 30], barr: dace.bool_[20, 30]): + return A[barr] + + stree = advanced_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.LibraryCall) + assert stree.children[0].node.name == 'boolean_mask_gather' + assert set(stree.children[0].in_memlets.keys()) == {'data', 'mask'} + result_name = stree.children[0].out_memlets['out'].data + result_desc = stree.containers[result_name] + assert isinstance(result_desc, dace.data.Array) + assert len(result_desc.shape) == 1 + assert result_desc.total_size == 600 + assert str(result_desc.shape[0]).startswith('__stree_mask_nnz') + assert isinstance(stree.children[1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_boolean_mask_read_inline_library_call(): + + @dace.program + def advanced_prog(A: dace.float64[20, 30], B: dace.float64[20, 30]): + return A[(A > 15) & (B < 20)] + + stree = advanced_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.LibraryCall) + assert stree.children[0].node.name == 'boolean_mask_gather' + assert 'mask_expr' in stree.children[0].node.properties + 
assert 'in100' in stree.children[0].node.properties['mask_expr'] + assert 'in101' in stree.children[0].node.properties['mask_expr'] + assert isinstance(stree.children[1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_numpy_indirection_update_lowering(): + + M, N = (dace.symbol(name) for name in ['M', 'N']) + + @dace.program + def indirection(A: dace.float64[M], x: dace.int32[N]): + A[:] = 1.0 + for j in range(1, N): + A[x[j]] += A[x[j - 1]] + + stree = indirection.to_schedule_tree() + + assert len(stree.children) == 2 + assert isinstance(stree.children[0], tn.MapScope) + assert isinstance(stree.children[0].children[0], tn.TaskletNode) + assert stree.children[0].children[0].node.code.as_string == 'out = 1.0' + assert isinstance(stree.children[1], tn.LoopScope) + assert [type(child) for child in stree.children[1].children] == [tn.TaskletNode, tn.TaskletNode, tn.TaskletNode] + first_idx = stree.children[1].children[0] + second_idx = stree.children[1].children[1] + update = stree.children[1].children[2] + assert first_idx.node.code.as_string == '__stree_idx = x[j]' + assert str(first_idx.in_memlets['in0'].subset) == 'j' + assert str(first_idx.out_memlets['out'].subset) == '0' + assert second_idx.node.code.as_string == '__stree_idx1 = x[(j - 1)]' + assert str(second_idx.in_memlets['in0'].subset) == 'j - 1' + assert str(second_idx.out_memlets['out'].subset) == '0' + assert update.node.code.as_string == 'out = (in0 + in1)' + assert str(update.in_memlets['in0'].subset) == '__stree_idx' + assert str(update.in_memlets['in1'].subset) == '__stree_idx1' + assert str(update.out_memlets['out'].subset) == '__stree_idx' + + +def test_python_frontend_schedule_tree_numpy_nested_indirection_copy_lowering(): + + @dace.program + def nested(A: dace.float64[50], f: dace.int32[40], g: dace.int32[30], out: dace.float64[1]): + out[0] = A[f[g[0]]] + + stree = nested.to_schedule_tree() + + assert [type(child) for child in stree.children] == [tn.TaskletNode, tn.TaskletNode, 
tn.CopyNode] + first_idx = stree.children[0] + second_idx = stree.children[1] + copy_node = stree.children[2] + assert first_idx.node.code.as_string == '__stree_idx = g[0]' + assert str(first_idx.in_memlets['in0'].subset) == '0' + assert str(first_idx.out_memlets['out'].subset) == '0' + assert second_idx.node.code.as_string == '__stree_idx1 = f[__stree_idx]' + assert str(second_idx.in_memlets['in0'].subset) == '__stree_idx' + assert str(second_idx.out_memlets['out'].subset) == '0' + assert str(copy_node.memlet) == 'A[__stree_idx1] -> [0]' + + +def test_python_frontend_schedule_tree_numpy_newaxis_map(): + + @dace.program + def computed(A: dace.float64[2], B: dace.float64[3], out: dace.float64[2, 3]): + out[:] = A[:, None] + B[None, :] + + stree = computed.to_schedule_tree() + + assert isinstance(stree.children[0], tn.MapScope) + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = (in0 + in1)' + assert str(tasklet.in_memlets['in0'].subset) == '__i0' + assert str(tasklet.in_memlets['in1'].subset) == '__i1' + + +def test_python_frontend_schedule_tree_numpy_explicit_newaxis_map(): + + @dace.program + def computed(A: dace.float64[2], B: dace.float64[3], out: dace.float64[2, 3]): + out[:] = A[:, np.newaxis] + B[np.newaxis, :] + + stree = computed.to_schedule_tree() + + assert isinstance(stree.children[0], tn.MapScope) + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = (in0 + in1)' + assert str(tasklet.in_memlets['in0'].subset) == '__i0' + assert str(tasklet.in_memlets['in1'].subset) == '__i1' + + +def test_python_frontend_schedule_tree_numpy_explicit_newaxis_return_shape(): + + @dace.program + def indexing_test(A: dace.float64[20, 30]): + return A[:, np.newaxis, np.newaxis, :] + + stree = indexing_test.to_schedule_tree() + + assert len(stree.children) == 2 + assert isinstance(stree.children[0], tn.MapScope) + 
assert stree.children[0].node.params == ['__i0', '__i1', '__i2', '__i3'] + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = in0' + assert str(tasklet.in_memlets['in0'].subset) == '__i0, __i3' + result_name = tasklet.out_memlets['out'].data + result_desc = stree.containers[result_name] + assert isinstance(result_desc, dace.data.Array) + assert tuple(result_desc.shape) == (20, 1, 1, 30) + assert isinstance(stree.children[1], tn.ReturnNode) + assert stree.children[1].values[0] == result_name + + +def test_python_frontend_schedule_tree_numpy_multiple_newaxis_return_shape(): + + @dace.program + def indexing_test(A: dace.float64[10, 20, 30]): + return A[np.newaxis, :, np.newaxis, np.newaxis, :, np.newaxis, :, np.newaxis] + + stree = indexing_test.to_schedule_tree() + + assert len(stree.children) == 2 + assert isinstance(stree.children[0], tn.MapScope) + assert stree.children[0].node.params == ['__i0', '__i1', '__i2', '__i3', '__i4', '__i5', '__i6', '__i7'] + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = in0' + assert str(tasklet.in_memlets['in0'].subset) == '__i1, __i4, __i6' + result_name = tasklet.out_memlets['out'].data + result_desc = stree.containers[result_name] + assert isinstance(result_desc, dace.data.Array) + assert tuple(result_desc.shape) == (1, 10, 1, 1, 20, 1, 30, 1) + assert isinstance(stree.children[1], tn.ReturnNode) + assert stree.children[1].values[0] == result_name + + +def test_python_frontend_schedule_tree_numpy_ellipsis_return_shape(): + + @dace.program + def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): + return A[1:5, ..., 0] + + stree = indexing_test.to_schedule_tree() + + assert len(stree.children) == 2 + assert isinstance(stree.children[0], tn.MapScope) + assert stree.children[0].node.params == ['__i0', '__i1', '__i2', '__i3'] + tasklet = stree.children[0].children[0] + assert 
isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = in0' + assert str(tasklet.in_memlets['in0'].subset) == '__i0 + 1, __i1, __i2, __i3, 0' + result_name = tasklet.out_memlets['out'].data + result_desc = stree.containers[result_name] + assert isinstance(result_desc, dace.data.Array) + assert tuple(result_desc.shape) == (4, 5, 5, 5) + assert isinstance(stree.children[1], tn.ReturnNode) + assert stree.children[1].values[0] == result_name + + +def test_python_frontend_schedule_tree_numpy_advanced_indexing_with_newaxes_return_shape(): + + @dace.program + def indexing_test(A: dace.float64[6, 6, 6, 6, 6, 6, 6], indices: dace.int32[3, 3], indices2: dace.int32[3, 3, 3]): + return A[None, :5, indices, indices2, ..., 1:6:3, 4, np.newaxis] + + stree = indexing_test.to_schedule_tree() + + assert len(stree.children) == 2 + assert isinstance(stree.children[0], tn.MapScope) + assert stree.children[0].node.params == ['__i0', '__i1', '__i2', '__i3', '__i4', '__i5', '__i6', '__i7', '__i8'] + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = in0[idx0_0, idx0_1]' + assert str(tasklet.in_memlets['in0'].subset) == '__i4, 0:6, 0:6, __i5, __i6, 3*__i7 + 1, 4' + assert str(tasklet.in_memlets['idx0_0'].subset) == '__i1, __i2' + assert str(tasklet.in_memlets['idx0_1'].subset) == '__i0, __i1, __i2' + result_name = tasklet.out_memlets['out'].data + result_desc = stree.containers[result_name] + assert isinstance(result_desc, dace.data.Array) + assert tuple(result_desc.shape) == (3, 3, 3, 1, 5, 6, 6, 2, 1) + assert isinstance(stree.children[1], tn.ReturnNode) + assert stree.children[1].values[0] == result_name + + +def test_python_frontend_schedule_tree_numpy_ufunc_map(): + + @dace.program + def called(A: dace.float64[8], out: dace.float64[8]): + out[:] = np.sqrt(A[:]) + + stree = called.to_schedule_tree() + + assert len(stree.children) == 1 + assert isinstance(stree.children[0], 
tn.MapScope) + assert isinstance(stree.children[0].node, tn.FrontendMap) + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = numpy.sqrt(in0)' + + +def test_python_frontend_schedule_tree_numpy_multi_output_ufunc_map(): + + @dace.program + def called(A: dace.int32[8], B: dace.int32[8]): + Q, R = np.divmod(A, B) + return Q, R + + stree = called.to_schedule_tree() + + assert not any(isinstance(node, tn.StatementNode) for node in stree.preorder_traversal()) + assert isinstance(stree.children[0], tn.MapScope) + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == '(out0, out1) = numpy.divmod(in0, in1)' + assert set(tasklet.out_memlets) == {'out0', 'out1'} + + +def test_python_frontend_schedule_tree_numpy_batched_matmul_library_call(): + + @dace.program + def mmmtest(a: dace.float64[3, 34, 32], b: dace.float64[3, 32, 31]): + return a @ b + + stree = mmmtest.to_schedule_tree() + + assert isinstance(stree.children[0], tn.LibraryCall) + assert stree.children[0].node.name == 'MatMul' + result_name = stree.children[0].out_memlets['out'].data + result_desc = stree.containers[result_name] + assert isinstance(result_desc, dace.data.Array) + assert tuple(result_desc.shape) == (3, 34, 31) + assert isinstance(stree.children[1], tn.ReturnNode) + assert stree.children[1].values[0] == '__stree_tmp' + + +def test_python_frontend_schedule_tree_numpy_batched_matmul_stationary_left_library_call(): + + @dace.program + def mmmtest(a: dace.float64[34, 32], b: dace.float64[3, 32, 31]): + return a @ b + + stree = mmmtest.to_schedule_tree() + + assert isinstance(stree.children[0], tn.LibraryCall) + assert stree.children[0].node.name == 'MatMul' + result_name = stree.children[0].out_memlets['out'].data + result_desc = stree.containers[result_name] + assert isinstance(result_desc, dace.data.Array) + assert tuple(result_desc.shape) == (3, 34, 31) 
+ assert isinstance(stree.children[1], tn.ReturnNode) + assert stree.children[1].values[0] == '__stree_tmp' + + +def test_python_frontend_schedule_tree_numpy_bitxor_pseudoscalar_dtype_inference(): + + @dace.program + def scalar_bitxor_prog(A: dace.int64[5, 5], B: dace.int64[1]): + return A ^ B + + stree = scalar_bitxor_prog.to_schedule_tree() + + assert len(stree.children) == 2 + assert isinstance(stree.children[0], tn.MapScope) + assert stree.children[0].node.params == ['__i0', '__i1'] + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = (in0 ^ in1)' + assert str(tasklet.in_memlets['in0'].subset) == '__i0, __i1' + assert str(tasklet.in_memlets['in1'].subset) == '0' + result_desc = stree.containers[tasklet.out_memlets['out'].data] + assert tuple(result_desc.shape) == (5, 5) + assert result_desc.dtype == dace.int64 + assert isinstance(stree.children[1], tn.ReturnNode) + assert stree.children[1].values[0] == '__stree_tmp' + + +def test_python_frontend_schedule_tree_numpy_compare_pseudoscalar_dtype_inference(): + + @dace.program + def scalar_lt_prog(A: dace.int64[5, 5], B: dace.int64[1]): + return A < B + + stree = scalar_lt_prog.to_schedule_tree() + + assert len(stree.children) == 2 + assert isinstance(stree.children[0], tn.MapScope) + assert stree.children[0].node.params == ['__i0', '__i1'] + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = (in0 < in1)' + result_desc = stree.containers[tasklet.out_memlets['out'].data] + assert tuple(result_desc.shape) == (5, 5) + assert result_desc.dtype == dace.bool_ + assert isinstance(stree.children[1], tn.ReturnNode) + assert stree.children[1].values[0] == '__stree_tmp' + + +def test_python_frontend_schedule_tree_numpy_transpose_stays_library_call(): + + @dace.program + def called(A: dace.float64[3, 5]): + return np.transpose(A) + + stree = called.to_schedule_tree() + 
+ assert isinstance(stree.children[0], tn.LibraryCall) + assert stree.children[0].node.name == 'numpy.transpose' + assert isinstance(stree.children[1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_numpy_method_reshape_stays_library_call(): + + @dace.program + def called(A: dace.float64[3, 4]): + return A.reshape((12, )) + + stree = called.to_schedule_tree() + + assert isinstance(stree.children[0], tn.LibraryCall) + assert stree.children[0].node.name == 'reshape' + assert stree.children[0].node.properties['receiver_class'] == 'Array' + assert stree.children[0].node.properties['access_kind'] == 'method' + assert isinstance(stree.children[1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_numpy_attribute_stays_library_call(): + + @dace.program + def called(A: dace.float64[3, 5]): + return A.T + + stree = called.to_schedule_tree() + + assert isinstance(stree.children[0], tn.LibraryCall) + assert stree.children[0].node.name == 'T' + assert stree.children[0].node.properties['receiver_class'] == 'Array' + assert stree.children[0].node.properties['access_kind'] == 'attribute' + assert isinstance(stree.children[1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_numpy_arange_symbolic_library_call(): + + K = dace.symbol('K') + + @dace.program + def called(): + return np.arange(K, dtype=np.int32) + + stree = called.to_schedule_tree(K=8) + + assert isinstance(stree.children[0], tn.LibraryCall) + assert stree.children[0].node.name == 'numpy.arange' + assert stree.children[0].node.properties['start'] == '0' + assert stree.children[0].node.properties['stop'] == 'K' + result_name = stree.children[0].out_memlets['out'].data + result_desc = stree.containers[result_name] + assert isinstance(result_desc, dace.data.Array) + assert tuple(result_desc.shape) == (K, ) + assert result_desc.dtype == dace.int32 + assert isinstance(stree.children[1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_numpy_arange_scalar_argument_promotes_symbol(): + 
+ @dace.program + def called(n: dace.int32): + return np.arange(n, dtype=np.int32) + + stree = called.to_schedule_tree(np.int32(8)) + + assert isinstance(stree.children[0], tn.TaskletNode) + assert stree.children[0].node.code.as_string == 'out = n' + assert isinstance(stree.children[1], tn.AssignNode) + assert stree.children[1].name.startswith('__sym_n') + assert isinstance(stree.children[2], tn.LibraryCall) + assert stree.children[2].node.name == 'numpy.arange' + assert stree.children[2].node.properties['stop'] == stree.children[1].name + result_name = stree.children[2].out_memlets['out'].data + result_desc = stree.containers[result_name] + assert isinstance(result_desc, dace.data.Array) + assert str(result_desc.shape[0]).startswith('__sym_n') + assert isinstance(stree.children[3], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_numpy_arange_data_scalar_argument_promotes_symbol(): + + @dace.program + def called(A: dace.int32[1]): + return np.arange(A[0], dtype=np.int32) + + stree = called.to_schedule_tree(np.array([8], dtype=np.int32)) + + assert isinstance(stree.children[0], tn.TaskletNode) + assert stree.children[0].node.code.as_string == 'out = in0' + assert str(stree.children[0].in_memlets['in0'].subset) == '0' + assert isinstance(stree.children[1], tn.AssignNode) + assert stree.children[1].name.startswith('__sym_A_0_') + assert isinstance(stree.children[2], tn.LibraryCall) + assert stree.children[2].node.properties['stop'] == stree.children[1].name + result_name = stree.children[2].out_memlets['out'].data + result_desc = stree.containers[result_name] + assert isinstance(result_desc, dace.data.Array) + assert str(result_desc.shape[0]).startswith('__sym_A_0_') + assert isinstance(stree.children[3], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_nested_numpy_attributes_are_materialized(): + + @dace.program + def called(A: dace.float64[3, 5]): + return A.T.T + + stree = called.to_schedule_tree() + + library_calls = [node for node in 
stree.children if isinstance(node, tn.LibraryCall)] + assert len(library_calls) == 2 + assert all(node.node.name == 'T' for node in library_calls) + assert all(node.node.properties['access_kind'] == 'attribute' for node in library_calls) + assert isinstance(stree.children[-1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_nested_numpy_attribute_method_chain_is_materialized(): + + @dace.program + def called(A: dace.float64[3, 5]): + return A.T.T.ravel() + + stree = called.to_schedule_tree() + + library_calls = [node for node in stree.children if isinstance(node, tn.LibraryCall)] + assert [node.node.name for node in library_calls] == ['T', 'T', 'ravel'] + assert library_calls[0].node.properties['access_kind'] == 'attribute' + assert library_calls[1].node.properties['access_kind'] == 'attribute' + assert library_calls[2].node.properties['access_kind'] == 'method' + assert isinstance(stree.children[-1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_numpy_array_literal_is_materialized_before_elementwise_lowering(): + + @dace.program + def computed(A: dace.float64[3], out: dace.float64[3]): + out[:] = A + np.array([1.0, 2.0, 3.0]) + + stree = computed.to_schedule_tree() + + assert isinstance(stree.containers['__stree_tmp'], dace.data.Array) + assert tuple(stree.containers['__stree_tmp'].shape) == (3, ) + assert isinstance(stree.children[0], tn.TaskletNode) + assert stree.children[0].node.code.as_string in { + 'out = np.array([1.0, 2.0, 3.0])', + 'out = numpy.array([1.0, 2.0, 3.0])', + } + assert isinstance(stree.children[1], tn.MapScope) + tasklet = stree.children[1].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = (in0 + in1)' + assert len(tasklet.in_memlets) == 2 + + +def test_python_frontend_schedule_tree_list_literal_in_array_expression_is_materialized_as_array(): + + @dace.program + def computed(A: dace.float64[3], out: dace.float64[3]): + out[:] = A * [1.0, 2.0, 3.0] + + stree = 
computed.to_schedule_tree() + + assert isinstance(stree.containers['__stree_tmp'], dace.data.Array) + assert tuple(stree.containers['__stree_tmp'].shape) == (3, ) + assert isinstance(stree.children[0], tn.TaskletNode) + assert stree.children[0].node.code.as_string in { + 'out = np.array([1.0, 2.0, 3.0])', + 'out = numpy.array([1.0, 2.0, 3.0])', + } + assert isinstance(stree.children[1], tn.MapScope) + tasklet = stree.children[1].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = (in0 * in1)' + assert len(tasklet.in_memlets) == 2 + + +def test_python_frontend_schedule_tree_aliased_numpy_array_literal_is_materialized_once(): + + @dace.program + def computed(A: dace.float64[3], out: dace.float64[3]): + out[:] = A + mynp.array([1.0, 2.0, 3.0]) + + stree = computed.to_schedule_tree() + + assert isinstance(stree.containers['__stree_tmp'], dace.data.Array) + assert tuple(stree.containers['__stree_tmp'].shape) == (3, ) + assert isinstance(stree.children[0], tn.TaskletNode) + assert stree.children[0].node.code.as_string in { + 'out = mynp.array([1.0, 2.0, 3.0])', + 'out = numpy.array([1.0, 2.0, 3.0])', + } + assert isinstance(stree.children[1], tn.MapScope) + tasklet = stree.children[1].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = (in0 + in1)' + assert len(tasklet.in_memlets) == 2 + + +def test_python_frontend_schedule_tree_numpy_compiletime_full_slice_lowers_to_map(): + + @dace.program + def sliceprog(A: dace.float64[20], slc: dace.compiletime): + A[slc] += 5 + + stree = sliceprog.to_schedule_tree(slc=slice(None, None, None)) + + assert len(stree.children) == 1 + assert isinstance(stree.children[0], tn.MapScope) + assert stree.children[0].node.params == ['__i0'] + assert stree.children[0].node.ranges == [('0', '20', '1')] + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = (in0 + 
5)' + assert str(tasklet.in_memlets['in0'].subset) == '__i0' + assert str(tasklet.out_memlets['out'].subset) == '__i0' + + +def test_python_frontend_schedule_tree_numpy_literal_slice_lowers_to_map(): + + @dace.program + def slicer(A: dace.float64[20]): + A[slice(2, 10, 2)] = 2 + + stree = slicer.to_schedule_tree() + + assert len(stree.children) == 1 + assert isinstance(stree.children[0], tn.MapScope) + assert stree.children[0].node.params == ['__i0'] + assert stree.children[0].node.ranges == [('2', '10', '2')] + tasklet = stree.children[0].children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert tasklet.node.code.as_string == 'out = 2' + assert str(tasklet.out_memlets['out'].subset) == '__i0' + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/python_frontend/schedule_tree/pyobject_fallback_test.py b/tests/python_frontend/schedule_tree/pyobject_fallback_test.py new file mode 100644 index 0000000000..ece81432d9 --- /dev/null +++ b/tests/python_frontend/schedule_tree/pyobject_fallback_test.py @@ -0,0 +1,321 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import ast +import sys + +import dace +import pytest +from dace import data, dtypes +from dace.data.pydata import PythonClass +from dace.frontend.python import preprocessing +from dace.frontend.python.schedule_tree import ScheduleTreeTypeInference +from dace.sdfg.analysis.schedule_tree import treenodes as tn + + +def _assert_pyobject_scalar(descriptor: data.Data) -> None: + assert isinstance(descriptor, data.Scalar) + assert descriptor.dtype == dtypes.pyobject() + + +def _assert_string_scalar(descriptor: data.Data) -> None: + assert isinstance(descriptor, data.Scalar) + assert descriptor.dtype == dtypes.string + + +def _infer_schedule_tree_bindings(program, argtypes=None): + argtypes = dict(argtypes or {'A': dace.float64[4]}) + modules = {name: value.__name__ for name, value in program.global_vars.items() if dtypes.ismodule(value)} + modules['builtins'] = '' + parsed_ast, _ = preprocessing.preprocess_dace_program(program.f, + argtypes, + dict(program.global_vars), + modules, + resolve_functions=program.resolve_functions, + default_args=set(), + normalize_generic_for_loops=True, + preserve_object_attributes=True, + disallowed_stmts=set()) + return ScheduleTreeTypeInference(parsed_ast.program_globals, argtypes).infer(parsed_ast.preprocessed_ast) + + +def test_schedule_tree_type_inference_opaque_assignments_use_pyobject_scalar(): + + class Box: + + def __init__(self): + self.value = None + + @dace.program + def prog(A: dace.float64[4]): + box = Box() + tmp = box.value + A[0] = A[0] + + bindings = _infer_schedule_tree_bindings(prog) + + _assert_pyobject_scalar(bindings['box'].descriptor) + _assert_pyobject_scalar(bindings['tmp'].descriptor) + + +def test_python_frontend_schedule_tree_opaque_assignments_use_pyobject_scalar(): + + class Box: + + def __init__(self): + self.value = None + + @dace.program + def prog(A: dace.float64[4]): + box = Box() + tmp = box.value + A[0] = A[0] + + stree = prog.to_schedule_tree() + + _assert_pyobject_scalar(stree.containers['box']) + 
_assert_pyobject_scalar(stree.containers['tmp']) + assert isinstance(stree.children[0], tn.AssignNode) + assert isinstance(stree.children[1], tn.TaskletNode) + + +def test_python_frontend_schedule_tree_callback_outputs_use_pyobject_scalar(): + + @dace.program + def prog(A: dace.float64[4]): + import math as m + tmp = m + A[0] = A[0] + + stree = prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.PythonCallbackNode) + _assert_pyobject_scalar(stree.containers['m']) + _assert_pyobject_scalar(stree.containers['tmp']) + + +def test_schedule_tree_type_inference_direct_class_annotated_alias_array_field_only_stays_structure(): + + class Holder: + arr: dace.float64[4] + + @dace.program + def prog(holder: Holder, A: dace.float64[4]): + alias: Holder = holder + alias.arr[:] = A[:] + + bindings = _infer_schedule_tree_bindings(prog, { + 'holder': dace.data.Structure.from_class(Holder), + 'A': dace.float64[4] + }) + + assert isinstance(bindings['alias'].descriptor, dace.data.Structure) + assert not isinstance(bindings['alias'].descriptor, PythonClass) + + +def test_schedule_tree_type_inference_direct_class_annotated_alias_scalar_rebinding_uses_pythonclass(): + + class Holder: + scalar: dace.float64 + arr: dace.float64[4] + + @dace.program + def prog(holder: Holder, A: dace.float64[4]): + alias: Holder = holder + alias.scalar = A[0] + + bindings = _infer_schedule_tree_bindings(prog, { + 'holder': dace.data.Structure.from_class(Holder), + 'A': dace.float64[4] + }) + + assert isinstance(bindings['alias'].descriptor, PythonClass) + + +def test_schedule_tree_type_inference_direct_class_annotated_alias_new_field_uses_pythonclass(): + + class Holder: + arr: dace.float64[4] + + @dace.program + def prog(holder: Holder, A: dace.float64[4]): + alias: Holder = holder + alias.new_field = A[0] + + bindings = _infer_schedule_tree_bindings(prog, { + 'holder': dace.data.Structure.from_class(Holder), + 'A': dace.float64[4] + }) + + assert isinstance(bindings['alias'].descriptor, 
PythonClass) + + +def test_schedule_tree_type_inference_dict_same_key_update_widens_value_type(): + + @dace.program + def prog(A: dace.float64[2]): + mapping = {'left': A[0], 'right': A[1]} + mapping['left'] = 'two' + value = mapping['left'] + return 0.0 + + bindings = _infer_schedule_tree_bindings(prog, {'A': dace.float64[2]}) + + assert bindings['mapping'].descriptor.value_type.dtype == dtypes.pyobject() + _assert_string_scalar(bindings['value'].descriptor) + + +def test_schedule_tree_type_inference_dict_known_static_reads_stay_precise(): + + @dace.program + def prog(A: dace.float64[2]): + mapping = {'left': A[0], 'right': 'two'} + left = mapping['left'] + right = mapping['right'] + return 0.0 + + bindings = _infer_schedule_tree_bindings(prog, {'A': dace.float64[2]}) + + assert bindings['mapping'].descriptor.value_type.dtype == dtypes.pyobject() + assert isinstance(bindings['left'].descriptor, data.Scalar) + assert bindings['left'].descriptor.dtype == dace.float64 + _assert_string_scalar(bindings['right'].descriptor) + + +def test_schedule_tree_type_inference_constant_scalars_use_literal_descriptors(): + + @dace.program + def prog(A: dace.float64[4]): + text = 'bla' + number = 5.03 + A[0] = A[0] + + bindings = _infer_schedule_tree_bindings(prog) + + _assert_string_scalar(bindings['text'].descriptor) + assert isinstance(bindings['number'].descriptor, data.Scalar) + assert bindings['number'].descriptor.dtype == dace.float64 + + +def test_python_frontend_schedule_tree_constant_scalars_use_literal_descriptors(): + + @dace.program + def prog(A: dace.float64[4]): + text = 'bla' + number = 5.03 + A[0] = A[0] + + stree = prog.to_schedule_tree() + + _assert_string_scalar(stree.containers['text']) + assert isinstance(stree.containers['number'], data.Scalar) + assert stree.containers['number'].dtype == dace.float64 + + +def test_python_frontend_schedule_tree_runtime_fstring_callback_outputs_use_string_scalar(): + + @dace.program + def prog(i: dace.int32): + return 
f'value={i}' + + stree = prog.to_schedule_tree() + + callbacks = [node for node in stree.preorder_traversal() if isinstance(node, tn.PythonCallbackNode)] + + assert len(callbacks) == 1 + assert callbacks[0].reason == 'f-string' + assert len(callbacks[0].output_names) == 1 + result_name = callbacks[0].output_names[0] + _assert_string_scalar(stree.containers[result_name]) + assert isinstance(stree.children[-1], tn.ReturnNode) + assert stree.children[-1].values[0] == result_name + + +def test_schedule_tree_type_inference_nested_generic_conflicts_do_not_leak(): + if sys.version_info < (3, 12): + pytest.skip('Generic function type parameters require Python 3.12+') + + source = ''' +def prog[T, *Ts](A): + tmp = A + + def inner[T, *Ts](x: T, y: tuple[*Ts]): + tmp = 1 + return x + + return tmp +''' + + module = ast.parse(source) + function = module.body[0] + bindings = ScheduleTreeTypeInference({'dace': dace}, {'A': dace.float64[4]}).infer(function) + + assert 'tmp' in bindings + assert isinstance(bindings['tmp'].descriptor, data.Array) + assert bindings['tmp'].descriptor.dtype == dace.float64 + assert tuple(bindings['tmp'].descriptor.shape) == (4, ) + + +def test_schedule_tree_type_inference_distinguishes_list_and_tuple_indices(): + + @dace.program + def list_prog(A: dace.float64[5, 6]): + tmp = A[[1, 2]] + + @dace.program + def tuple_prog(A: dace.float64[5, 6]): + tmp = A[(1, 2)] + + list_bindings = _infer_schedule_tree_bindings(list_prog, {'A': dace.float64[5, 6]}) + tuple_bindings = _infer_schedule_tree_bindings(tuple_prog, {'A': dace.float64[5, 6]}) + + assert isinstance(list_bindings['tmp'].descriptor, data.Array) + assert tuple(list_bindings['tmp'].descriptor.shape) == (2, 6) + assert list_bindings['tmp'].descriptor.dtype == dace.float64 + assert isinstance(tuple_bindings['tmp'].descriptor, data.Scalar) + assert tuple_bindings['tmp'].descriptor.dtype == dace.float64 + assert tuple_bindings['tmp'].kind == 'scalar' + + +def 
test_schedule_tree_type_inference_distinguishes_list_and_tuple_indices_with_symbolic_shape(): + n = dace.symbol('n') + + @dace.program + def list_prog(A: dace.float64[5, n]): + tmp = A[[1, 2]] + + @dace.program + def tuple_prog(A: dace.float64[5, n]): + tmp = A[(1, 2)] + + list_bindings = _infer_schedule_tree_bindings(list_prog, {'A': dace.float64[5, n]}) + tuple_bindings = _infer_schedule_tree_bindings(tuple_prog, {'A': dace.float64[5, n]}) + + assert tuple(list_bindings['tmp'].descriptor.shape) == (2, n) + assert isinstance(tuple_bindings['tmp'].descriptor, data.Scalar) + + +def test_schedule_tree_type_inference_symbolic_static_slice_shape(): + n = dace.symbol('n') + + @dace.program + def slice_prog(A: dace.float64[n]): + tmp = A[1:n:2] + + bindings = _infer_schedule_tree_bindings(slice_prog, {'A': dace.float64[n]}) + + assert isinstance(bindings['tmp'].descriptor, data.Array) + assert str(bindings['tmp'].descriptor.shape[0]) == 'ceiling(n/2 - 1/2)' + + +def test_schedule_tree_type_inference_ellipsis_shape(): + n = dace.symbol('n') + + @dace.program + def ellipsis_prog(A: dace.float64[4, n, 6, 7]): + tmp = A[1:3, ..., 0] + + bindings = _infer_schedule_tree_bindings(ellipsis_prog, {'A': dace.float64[4, n, 6, 7]}) + + assert isinstance(bindings['tmp'].descriptor, data.Array) + assert bindings['tmp'].descriptor.dtype == dace.float64 + assert tuple(bindings['tmp'].descriptor.shape) == (2, n, 6) diff --git a/tests/python_frontend/schedule_tree/python_frontend_test.py b/tests/python_frontend/schedule_tree/python_frontend_test.py new file mode 100644 index 0000000000..45b4e24d05 --- /dev/null +++ b/tests/python_frontend/schedule_tree/python_frontend_test.py @@ -0,0 +1,2388 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import ast +import contextlib +import numpy as np +import pytest +import sys +import warnings +from typing import Optional + +import dace +from dace import dtypes +from dace.data.pydata import PythonDict, PythonList +from dace.frontend.python.common import DaceSyntaxError, SDFGConvertible, ScheduleTreeConvertible +from dace.sdfg.analysis.schedule_tree import treenodes as tn + + +def test_python_frontend_schedule_tree_structured_control_flow(): + + @dace.program + def structured(A: dace.float64[20]): + tmp = A[:] + for i in range(20): + if i < 10: + continue + else: + break + return tmp + + stree = structured.to_schedule_tree() + + assert isinstance(stree.children[0], tn.ViewNode) + assert isinstance(stree.children[1], tn.LoopScope) + assert 'i' in stree.symbols + assert isinstance(stree.children[1].loop, tn.FrontendLoop) + assert isinstance(stree.children[1].children[0], tn.IfScope) + assert isinstance(stree.children[1].children[0].children[0], tn.ContinueNode) + assert isinstance(stree.children[1].children[1], tn.ElseScope) + assert isinstance(stree.children[1].children[1].children[0], tn.BreakNode) + assert isinstance(stree.children[2], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_root_repository(): + offset = 3.0 + + @dace.program + def structured(A: dace.float64[20]): + return A + offset + + stree = structured.to_schedule_tree() + + assert isinstance(stree, tn.ScheduleTreeRoot) + assert stree.name.endswith('_structured') + assert stree.arg_names == ['A'] + assert 'A' in stree.containers + assert 'offset' in stree.constants + + +def test_python_frontend_schedule_tree_allocations_and_cache(): + + @dace.program + def alloc_copy(A: dace.float64[4]): + tmp = np.empty_like(A) + tmp[:] = A[:] + return tmp + + stree_first = alloc_copy.to_schedule_tree(use_cache=True) + stree_second = alloc_copy.to_schedule_tree(use_cache=True) + + assert stree_first is not stree_second + assert 'tmp' in stree_first.containers + assert 
(isinstance(stree_first.children[0], tn.LibraryCall) + and stree_first.children[0].node.name == 'numpy.empty_like') + assert isinstance(stree_first.children[1], tn.CopyNode) + assert isinstance(stree_first.children[2], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_references(): + + @dace.program + def refs(A: dace.float64[4], B: dace.float64[4], flag: dace.bool_): + ref: dace.data.ArrayReference(A.dtype, A.shape) = A + if flag: + ref = B + return ref + + stree = refs.to_schedule_tree() + + assert isinstance(stree.children[0], tn.RefSetNode) + assert isinstance(stree.children[1], tn.IfScope) + assert isinstance(stree.children[1].children[0], tn.RefSetNode) + assert isinstance(stree.children[2], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_unannotated_branch_references(): + + @dace.program + def refs(A: dace.float64[20], B: dace.float64[20], i: dace.int32[1], out: dace.float64[20]): + if i[0] < 5: + ref = A + else: + ref = B + out[:] = ref + + stree = refs.to_schedule_tree() + + assert isinstance(stree.children[0], tn.IfScope) + assert isinstance(stree.children[0].children[0], tn.RefSetNode) + assert isinstance(stree.children[1], tn.ElseScope) + assert isinstance(stree.children[1].children[0], tn.RefSetNode) + assert isinstance(stree.children[2], tn.CopyNode) + + +def test_python_frontend_schedule_tree_map_scope(): + + @dace.program + def mapped(A: dace.float64[8]): + for i in dace.map[0:8]: + A[i] = A[i] + + stree = mapped.to_schedule_tree() + + assert len(stree.children) == 1 + assert isinstance(stree.children[0], tn.MapScope) + assert isinstance(stree.children[0].node, tn.FrontendMap) + assert isinstance(stree.children[0].children[0], tn.CopyNode) + + +def test_python_frontend_schedule_tree_function_call_return(): + + @dace.program + def inner(A: dace.float64[8], B: dace.float64[8]): + return np.sum(A + B) + + @dace.program + def outer(A: dace.float64[8], B: dace.float64[8]): + return inner(A + 1, B + 2) + + stree = 
outer.to_schedule_tree() + + assert len(stree.children) == 4 + assert isinstance(stree.children[0], tn.MapScope) + assert isinstance(stree.children[1], tn.MapScope) + # The nested call is inlined into a FunctionCallScope. + call_scope = stree.children[2] + assert isinstance(call_scope, tn.FunctionCallScope) + assert call_scope.call.callee_name == 'inner' + assert call_scope.call.arguments == {'A': '__stree_tmp', 'B': '__stree_tmp1'} + assert len(call_scope.children) >= 1 + assert isinstance(stree.children[3], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_nested_program_calls_are_not_executed(monkeypatch): + + @dace.program + def inner(A: dace.float64[8], B: dace.float64[8]): + return np.sum(A + B) + + @dace.program + def outer(A: dace.float64[8], B: dace.float64[8]): + return inner(A + 1, B + 2) + + from dace.frontend.python import parser as dace_parser + + original_call = dace_parser.DaceProgram.__call__ + seen = [] + + def _guard(self, *args, **kwargs): + if self is inner: + seen.append((args, kwargs)) + raise AssertionError('nested program executed during schedule-tree generation') + return original_call(self, *args, **kwargs) + + monkeypatch.setattr(dace_parser.DaceProgram, '__call__', _guard) + + stree = outer.to_schedule_tree() + + assert seen == [] + assert isinstance(stree.children[2], tn.FunctionCallScope) + assert isinstance(stree.children[3], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_constant_return_materializes_descriptor(): + + @dace.program + def constant(): + return 5 + + stree = constant.to_schedule_tree() + + assert isinstance(stree.children[-2], tn.TaskletNode) + assert stree.children[-2].node.code.as_string == '__stree_retval = 5' + assert isinstance(stree.children[-1], tn.ReturnNode) + assert stree.children[-1].values == ['__stree_retval'] + + +def test_python_frontend_schedule_tree_symbolic_return_materializes_descriptor(): + n = dace.symbol('n') + + @dace.program + def symbolic_constant(): + return n + 1 + + stree 
= symbolic_constant.to_schedule_tree() + + assert isinstance(stree.children[-2], tn.TaskletNode) + assert stree.children[-2].node.code.as_string == '__stree_retval = (n + 1)' + assert isinstance(stree.children[-1], tn.ReturnNode) + assert stree.children[-1].values == ['__stree_retval'] + + +def test_python_frontend_schedule_tree_function_call_assignment(): + + @dace.program + def inner(A: dace.float64[8], B: dace.float64[8]): + return A + B + + @dace.program + def outer(A: dace.float64[8], B: dace.float64[8], out: dace.float64[8]): + out[:] = inner(A + 1, B + 2) + + stree = outer.to_schedule_tree() + + assert len(stree.children) == 3 + assert isinstance(stree.children[0], tn.MapScope) + assert isinstance(stree.children[1], tn.MapScope) + # The nested call is inlined into a FunctionCallScope. + call_scope = stree.children[2] + assert isinstance(call_scope, tn.FunctionCallScope) + assert call_scope.call.callee_name == 'inner' + assert call_scope.call.arguments == {'A': '__stree_tmp', 'B': '__stree_tmp1'} + # The callee's body should be inlined with the return rewritten + # as an assignment to the caller's target. 
+ assert len(call_scope.children) >= 1 + + +def test_python_frontend_schedule_tree_external_schedule_tree_convertible_call(): + + class Convertible(ScheduleTreeConvertible): + + def __init__(self): + + @dace.program + def inner(A: dace.float64[8], B: dace.float64[8]): + return A + B + + self.inner = inner + + def __schedule_tree__(self, *args, lambda_bindings=None, callable_bindings=None, **kwargs): + return self.inner.__schedule_tree__(*args, + lambda_bindings=lambda_bindings, + callable_bindings=callable_bindings, + **kwargs) + + def __schedule_tree_signature__(self): + return (['A', 'B'], []) + + convertible = Convertible() + + @dace.program + def outer(A: dace.float64[8], B: dace.float64[8]): + return convertible(A + 1, B + 2) + + stree = outer.to_schedule_tree() + + assert len(stree.children) == 4 + assert isinstance(stree.children[0], tn.MapScope) + assert isinstance(stree.children[1], tn.MapScope) + call_scope = stree.children[2] + assert isinstance(call_scope, tn.FunctionCallScope) + assert call_scope.call.callee_name == 'Convertible' + assert call_scope.call.arguments == {'A': '__stree_tmp', 'B': '__stree_tmp1'} + assert len(call_scope.children) >= 1 + assert isinstance(stree.children[3], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_callable_object_call_is_inlined(): + + class CallableObject: + + def __call__(self, A: dace.float64[8]): + return A + 1 + + callable_object = CallableObject() + + @dace.program + def outer(A: dace.float64[8]): + return callable_object(A) + + stree = outer.to_schedule_tree() + + assert isinstance(stree.children[0], tn.FunctionCallScope) + assert stree.children[0].call.callee_name == '__call__' + assert stree.children[0].call.arguments == {'A': 'A'} + assert isinstance(stree.children[1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_parseable_free_function_call_is_inlined(): + + def callee(A: dace.float64[8]): + return A + 1 + + @dace.program + def outer(A: dace.float64[8]): + return callee(A) + + stree 
= outer.to_schedule_tree() + + assert isinstance(stree.children[0], tn.FunctionCallScope) + assert stree.children[0].call.callee_name == 'callee' + assert stree.children[0].call.arguments == {'A': 'A'} + assert isinstance(stree.children[1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_dunder_add_is_inlined(): + + class Adder: + + @dace.method + def __add__(self, A: dace.float64[8]): + return A + 1 + + adder = Adder() + + @dace.program + def outer(A: dace.float64[8]): + return adder + A + + stree = outer.to_schedule_tree() + + assert isinstance(stree.children[0], tn.FunctionCallScope) + assert stree.children[0].call.callee_name == '__add__' + assert stree.children[0].call.arguments == {'A': 'A'} + assert isinstance(stree.children[1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_sdfg_call_stays_opaque(): + + @dace.program + def inner(A: dace.float64[8], B: dace.float64[8]): + return A + B + + sdfg_obj = inner.to_sdfg() + + @dace.program + def outer(A: dace.float64[8], B: dace.float64[8]): + return sdfg_obj(A, B) + + stree = outer.to_schedule_tree() + + assert not any(isinstance(node, tn.FunctionCallScope) for node in stree.preorder_traversal()) + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + assert isinstance(stree.children[0], tn.SDFGCallNode) + assert isinstance(stree.children[0].sdfg, dace.SDFG) + assert stree.children[0].sdfg.name == sdfg_obj.name + assert stree.children[0].call.callee_name.endswith('inner') + assert stree.children[0].call.arguments == {'A': 'A', 'B': 'B'} + assert stree.children[0].return_targets == ['__stree_retval'] + assert isinstance(stree.children[1], tn.ReturnNode) + assert stree.children[1].values[0] == '__stree_retval' + + +def test_python_frontend_schedule_tree_sdfg_convertible_call_stays_opaque(): + + class Convertible(SDFGConvertible): + + def __init__(self): + self.name = 'convertible' + + def __call__(self, *args, **kwargs): + raise 
AssertionError('SDFGConvertible should not execute during schedule-tree generation') + + def __sdfg__(self, A, B): + + @dace.program + def inner(X: dace.float64[8], Y: dace.float64[8]): + return X + Y + + return inner.to_sdfg(A, B) + + def __sdfg_signature__(self): + return (['A', 'B'], []) + + def __sdfg_closure__(self, reevaluate=None): + return {} + + convertible = Convertible() + + @dace.program + def outer(A: dace.float64[8], B: dace.float64[8]): + return convertible(A, B) + + stree = outer.to_schedule_tree() + + assert not any(isinstance(node, tn.FunctionCallScope) for node in stree.preorder_traversal()) + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + assert isinstance(stree.children[0], tn.SDFGCallNode) + assert isinstance(stree.children[0].sdfg, dace.SDFG) + assert stree.children[0].call.callee_name == 'convertible' + assert stree.children[0].call.arguments == {'A': 'A', 'B': 'B'} + assert stree.children[0].return_targets == ['__stree_retval'] + assert isinstance(stree.children[1], tn.ReturnNode) + assert stree.children[1].values[0] == '__stree_retval' + + +def test_python_frontend_schedule_tree_return_materializes_array_expression(): + + @dace.program + def returned(A: dace.float64[8], B: dace.float64[8]): + return A + B + + stree = returned.to_schedule_tree() + + assert len(stree.children) == 2 + assert isinstance(stree.children[0], tn.MapScope) + assert isinstance(stree.children[1], tn.ReturnNode) + assert stree.children[1].values[0] == '__stree_tmp' + + +def test_python_frontend_schedule_tree_compile_time_fstring_stays_direct(): + + prefix = 'value=' + + @dace.program + def returned(): + tmp = f'{prefix}5' + return tmp + + stree = returned.to_schedule_tree() + + assert stree.containers['tmp'].dtype == dace.dtypes.string + assert isinstance(stree.children[0], tn.AssignNode) + assert isinstance(stree.children[1], tn.ReturnNode) + assert not any(isinstance(node, tn.PythonCallbackNode) for node in 
stree.preorder_traversal()) + + +def test_python_frontend_schedule_tree_matmul_chain_library_calls(): + + @dace.program + def chained(A: dace.float64[4, 3], B: dace.float64[3, 2], C: dace.float64[2, 5]): + return A @ B @ C + + stree = chained.to_schedule_tree() + + assert len(stree.children) == 3 + assert isinstance(stree.children[0], tn.LibraryCall) + assert isinstance(stree.children[1], tn.LibraryCall) + assert isinstance(stree.children[2], tn.ReturnNode) + assert isinstance(stree.children[0].node, tn.FrontendLibrary) + assert isinstance(stree.children[1].node, tn.FrontendLibrary) + assert stree.children[0].node.name == 'MatMul' + assert stree.children[1].node.name == 'MatMul' + assert stree.children[2].values[0] == '__stree_tmp1' + + +def test_python_frontend_schedule_tree_reduction_calls(): + + @dace.program + def np_sum_prog(a: dace.float64[8]): + return np.sum(a) + + @dace.program + def method_sum_prog(a: dace.float64[8]): + return a.sum() + + np_sum_tree = np_sum_prog.to_schedule_tree() + method_sum_tree = method_sum_prog.to_schedule_tree() + + # np.sum(a) should now be materialized as a LibraryCall writing to a + # scalar temporary, followed by a ReturnNode referencing it. + assert isinstance(np_sum_tree.children[0], tn.LibraryCall) + assert np_sum_tree.children[0].node.name == 'numpy.sum' + assert isinstance(np_sum_tree.children[1], tn.ReturnNode) + # a.sum() is method syntax — now covered by the method descriptor registry, + # so it should also be a LibraryCall followed by a ReturnNode. 
+ assert isinstance(method_sum_tree.children[0], tn.LibraryCall) + assert isinstance(method_sum_tree.children[1], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_method_field_read_is_not_left_as_raw_self_attribute(): + + class FieldReader: + + def __init__(self): + self._tmp_field = np.arange(8, dtype=np.float64) + + @dace.method + def reduce(self): + return np.sum(self._tmp_field) + + stree = FieldReader().reduce.to_schedule_tree() + + assert 'self._tmp_field' not in stree.as_string() + assert any( + isinstance(node, tn.LibraryCall) and node.node.name == 'numpy.sum' for node in stree.preorder_traversal()) + + +def test_python_frontend_schedule_tree_nested_method_self_containers_do_not_conflict(): + + class Inner: + + def __init__(self, size: int, offset: float): + self._tmp = np.arange(size, dtype=np.float64) + offset + + @dace.method + def reduce(self): + return np.sum(self._tmp) + + class Outer: + + def __init__(self): + self._tmp = np.arange(4, dtype=np.float64) + self.inner1 = Inner(5, 1.0) + self.inner2 = Inner(6, 2.0) + + @dace.method + def reduce(self): + return np.sum(self._tmp) + self.inner1.reduce() + self.inner2.reduce() + + stree = Outer().reduce.to_schedule_tree() + + closure_shapes = sorted( + descriptor.shape[0] for name, descriptor in stree.containers.items() + if name.startswith('__g_') and len(descriptor.shape) == 1 and descriptor.shape[0] in (4, 5, 6)) + + assert closure_shapes == [4, 5, 6] + assert 'self._tmp' not in stree.as_string() + assert [type(child) for child in stree.children + ] == [tn.LibraryCall, tn.FunctionCallScope, tn.FunctionCallScope, tn.TaskletNode, tn.ReturnNode] + assert stree.children[-1].values[0] == stree.children[-2].out_memlets['out'].data + + +def test_python_frontend_schedule_tree_method_unresolved_new_field_assignment_uses_copy_node(): + + class FieldWriter: + + @dace.method + def write(self, A: dace.float64[8]): + self.new_field = A[0] + return A[0] + + stree = FieldWriter().write.to_schedule_tree() + + 
assert isinstance(stree.children[0], tn.CopyNode) + assert stree.children[0].target == 'self.new_field' + assert 'self' in stree.containers + assert isinstance(stree.containers['self'], dace.data.pydata.PythonClass) + assert stree.containers['self'].name == 'FieldWriter' + assert isinstance(stree.containers['self'].members['new_field'], dace.data.Scalar) + assert all('new_field' not in name for name in stree.containers) + + +def test_python_frontend_schedule_tree_descriptor_and_attribute_access(): + + class ArrayDescriptor: + + def __set_name__(self, owner, name): + self.name = '_' + name + + def __get__(self, obj, objtype=None): + return getattr(obj, self.name) + + def __set__(self, obj, value): + setattr(obj, self.name, value) + + class DescriptorHolder: + arr = ArrayDescriptor() + + def __init__(self): + self.arr = np.zeros(8, dtype=np.float64) + + descriptor_holder = DescriptorHolder() + + @dace.program + def descriptor_prog(A: dace.float64[8], out: dace.float64[8]): + descriptor_holder.arr = A + out[:] = descriptor_holder.arr + + class AttrHolder: + + def __init__(self): + self.arr = np.zeros(8, dtype=np.float64) + + attr_holder = AttrHolder() + + @dace.program + def attr_prog(A: dace.float64[8], out: dace.float64[8]): + attr_holder.arr = A + out[:] = attr_holder.arr + + descriptor_tree = descriptor_prog.to_schedule_tree() + attribute_tree = attr_prog.to_schedule_tree() + + assert isinstance(descriptor_tree.children[0], tn.StatementNode) + assert descriptor_tree.children[ + 0].code.as_string == "type(descriptor_holder).__dict__['arr'].__set__(descriptor_holder, A)" + assert isinstance(descriptor_tree.children[1], tn.StatementNode) + assert descriptor_tree.children[1].code.as_string == ( + "out[:] = type(descriptor_holder).__dict__['arr'].__get__(descriptor_holder, type(descriptor_holder))") + assert isinstance(attribute_tree.children[0], tn.StatementNode) + assert attribute_tree.children[0].code.as_string == 'attr_holder.arr = A' + assert 
isinstance(attribute_tree.children[1], tn.StatementNode) + assert attribute_tree.children[1].code.as_string == 'out[:] = attr_holder.arr' + + +def test_python_frontend_schedule_tree_descriptor_setter_protocol_is_preserved(): + + class OffsetDescriptor: + + def __set_name__(self, owner, name): + self.name = '_' + name + + def __get__(self, obj, objtype=None): + return getattr(obj, self.name) + + def __set__(self, obj, value): + setattr(obj, self.name, value + 1) + + class DescriptorHolder: + arr = OffsetDescriptor() + + def __init__(self): + self.arr = np.zeros(8, dtype=np.float64) + + descriptor_holder = DescriptorHolder() + + @dace.program + def descriptor_prog(A: dace.float64[8], out: dace.float64[8]): + descriptor_holder.arr = A + out[:] = descriptor_holder.arr + + stree = descriptor_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.StatementNode) + assert stree.children[0].code.as_string == "type(descriptor_holder).__dict__['arr'].__set__(descriptor_holder, A)" + assert isinstance(stree.children[1], tn.StatementNode) + assert stree.children[1].code.as_string == ( + "out[:] = type(descriptor_holder).__dict__['arr'].__get__(descriptor_holder, type(descriptor_holder))") + + +def test_python_frontend_schedule_tree_structure_scalar_field_assignment_errors_to_use_pythonclass(): + + class Holder: + scalar: dace.float64 + arr: dace.float64[8] + + Struct = dace.data.Structure.from_class(Holder) + + @dace.program + def prog(holder: Struct, A: dace.float64[8]): + holder.scalar = A[0] + holder.arr[:] = A[:] + + with pytest.raises(DaceSyntaxError, match=r'non-array field "scalar".*PythonClass'): + prog.to_schedule_tree() + + +def test_python_frontend_schedule_tree_structure_new_field_assignment_errors_to_use_pythonclass(): + + class Holder: + arr: dace.float64[8] + + Struct = dace.data.Structure.from_class(Holder) + + @dace.program + def prog(holder: Struct, A: dace.float64[8]): + holder.new_field = A[0] + holder.arr[:] = A[:] + + with 
pytest.raises(DaceSyntaxError, match=r'Creating field "new_field".*PythonClass'): + prog.to_schedule_tree() + + +def test_python_frontend_schedule_tree_direct_class_scalar_field_assignment_uses_pythonclass(): + + class Holder: + scalar: dace.float64 + arr: dace.float64[8] + + @dace.program + def prog(holder: Holder, A: dace.float64[8]): + holder.scalar = A[0] + holder.arr[:] = A[:] + + with warnings.catch_warnings(): + warnings.simplefilter('error') + stree = prog.to_schedule_tree() + + assert isinstance(stree.containers['holder'], dace.data.pydata.PythonClass) + assert isinstance(stree.children[0], tn.CopyNode) + assert stree.children[0].target == 'holder.scalar' + assert isinstance(stree.children[1], tn.CopyNode) + + +def test_python_frontend_schedule_tree_direct_class_new_field_assignment_uses_pythonclass(): + + class Holder: + arr: dace.float64[8] + + @dace.program + def prog(holder: Holder, A: dace.float64[8]): + holder.new_field = A[0] + holder.arr[:] = A[:] + + with warnings.catch_warnings(): + warnings.simplefilter('error') + stree = prog.to_schedule_tree() + + assert isinstance(stree.containers['holder'], dace.data.pydata.PythonClass) + assert isinstance(stree.children[0], tn.CopyNode) + assert stree.children[0].target == 'holder.new_field' + + +def test_python_frontend_schedule_tree_direct_class_new_array_field_assignment_uses_refset(): + + class Holder: + scalar: dace.float64 + + @dace.program + def prog(holder: Holder, A: dace.float64[8], out: dace.float64[8]): + holder.new_data = A + out[:] = holder.new_data[:] + + with warnings.catch_warnings(): + warnings.simplefilter('error') + stree = prog.to_schedule_tree() + + assert isinstance(stree.containers['holder'], dace.data.pydata.PythonClass) + assert isinstance(stree.children[0], tn.RefSetNode) + assert stree.children[0].target == 'holder.new_data' + assert isinstance(stree.children[1], tn.CopyNode) + + +def test_python_frontend_schedule_tree_direct_class_literal_new_field_assignment_uses_tasklet(): + + 
class Holder: + scalar: dace.float64 + + @dace.program + def prog(holder: Holder): + holder.new_field = 4.25 + + with warnings.catch_warnings(): + warnings.simplefilter('error') + stree = prog.to_schedule_tree() + + assert isinstance(stree.containers['holder'], dace.data.pydata.PythonClass) + assert isinstance(stree.children[0], tn.TaskletNode) + assert stree.children[0].node.code.as_string == 'out = 4.25' + + +def test_python_frontend_schedule_tree_direct_class_array_field_only_stays_structure(): + + class Holder: + arr: dace.float64[8] + + @dace.program + def prog(holder: Holder, A: dace.float64[8], out: dace.float64[8]): + holder.arr[:] = A[:] + out[:] = holder.arr + + with warnings.catch_warnings(): + warnings.simplefilter('error') + stree = prog.to_schedule_tree() + + assert isinstance(stree.containers['holder'], dace.data.Structure) + assert not isinstance(stree.containers['holder'], dace.data.pydata.PythonClass) + assert isinstance(stree.children[0], tn.CopyNode) + assert stree.children[0].target == 'holder.arr' + assert isinstance(stree.children[1], tn.CopyNode) + + +def test_python_frontend_schedule_tree_direct_class_nested_scalar_field_assignment_uses_pythonclass(): + + class Inner: + scalar: dace.float64 + + class Outer: + inner: Inner + + @dace.program + def prog(wrapper: Outer, A: dace.float64[8]): + wrapper.inner.scalar = A[0] + + with warnings.catch_warnings(): + warnings.simplefilter('error') + stree = prog.to_schedule_tree() + + assert isinstance(stree.containers['wrapper'], dace.data.pydata.PythonClass) + assert isinstance(stree.children[0], tn.CopyNode) + assert stree.children[0].target == 'wrapper.inner.scalar' + + +def test_python_frontend_schedule_tree_structure_annotated_field_scalar_assignment_warns(): + + class Inner: + scalar: dace.float64 + arr: dace.float64[8] + + InnerStruct = dace.data.Structure.from_class(Inner) + + class Outer: + inner: InnerStruct + + @dace.program + def prog(wrapper: Outer, A: dace.float64[8]): + 
wrapper.inner.scalar = A[0] + wrapper.inner.arr[:] = A[:] + + with pytest.warns(UserWarning, match=r'non-array field "scalar".*PythonClass'): + stree = prog.to_schedule_tree() + + assert isinstance(stree.containers['wrapper'], dace.data.Structure) + assert not isinstance(stree.containers['wrapper'], dace.data.pydata.PythonClass) + assert isinstance(stree.children[0], tn.CopyNode) + assert stree.children[0].target == 'wrapper.inner.scalar' + assert isinstance(stree.children[1], tn.CopyNode) + + +def test_python_frontend_schedule_tree_structure_annotated_field_new_field_assignment_warns(): + + class Inner: + arr: dace.float64[8] + + InnerStruct = dace.data.Structure.from_class(Inner) + + class Outer: + inner: InnerStruct + + @dace.program + def prog(wrapper: Outer, A: dace.float64[8]): + wrapper.inner.new_field = A[0] + + with pytest.warns(UserWarning, match=r'Creating field "new_field".*PythonClass'): + stree = prog.to_schedule_tree() + + assert isinstance(stree.containers['wrapper'], dace.data.Structure) + assert not isinstance(stree.containers['wrapper'], dace.data.pydata.PythonClass) + assert isinstance(stree.children[0], tn.StatementNode) + assert stree.children[0].code.as_string == 'wrapper.inner.new_field = A[0]' + + +def test_python_frontend_schedule_tree_nested_calls_propagate_pythonclass_to_top_level(): + + class Holder: + arr: dace.float64[8] + + @dace.program + def leaf(holder: Holder, A: dace.float64[8]): + holder.new_field = A[0] + + @dace.program + def mid(holder: Holder, A: dace.float64[8]): + leaf(holder, A) + + @dace.program + def top(holder: Holder, A: dace.float64[8]): + mid(holder, A) + + with warnings.catch_warnings(): + warnings.simplefilter('error') + stree = top.to_schedule_tree() + + assert isinstance(stree.containers['holder'], dace.data.pydata.PythonClass) + call_scopes = [node for node in stree.preorder_traversal() if isinstance(node, tn.FunctionCallScope)] + assert [scope.call.callee_name for scope in call_scopes] == ['mid', 'leaf'] + + 
+def test_python_frontend_schedule_tree_direct_class_annotated_alias_scalar_field_assignment_uses_pythonclass():
+    """Scalar field write through an annotated alias of a class-typed argument.
+
+    Expects the ``alias`` container to be a ``PythonClass``, the alias binding to be
+    a ``RefSetNode`` and the field write to be a ``CopyNode`` targeting ``alias.scalar``.
+    """
+
+    class Holder:
+        scalar: dace.float64
+
+    @dace.program
+    def prog(holder: Holder, A: dace.float64[8]):
+        alias: Holder = holder
+        alias.scalar = A[0]
+
+    # Warnings are escalated to exceptions, so a lowering that warns fails the test.
+    with warnings.catch_warnings():
+        warnings.simplefilter('error')
+        tree = prog.to_schedule_tree()
+
+    assert isinstance(tree.containers['alias'], dace.data.pydata.PythonClass)
+
+    alias_binding = tree.children[0]
+    assert isinstance(alias_binding, tn.RefSetNode)
+    assert alias_binding.target == 'alias'
+
+    field_write = tree.children[1]
+    assert isinstance(field_write, tn.CopyNode)
+    assert field_write.target == 'alias.scalar'
+
+
+def test_python_frontend_schedule_tree_direct_class_annotated_alias_new_field_assignment_uses_pythonclass():
+    """Write to a field the class does not annotate, through an annotated alias.
+
+    Expects the ``alias`` container to be a ``PythonClass``, the alias binding to be
+    a ``RefSetNode`` and the field write to be a ``CopyNode`` targeting ``alias.new_field``.
+    """
+
+    class Holder:
+        arr: dace.float64[8]
+
+    @dace.program
+    def prog(holder: Holder, A: dace.float64[8]):
+        alias: Holder = holder
+        alias.new_field = A[0]
+
+    # Warnings are escalated to exceptions, so a lowering that warns fails the test.
+    with warnings.catch_warnings():
+        warnings.simplefilter('error')
+        tree = prog.to_schedule_tree()
+
+    assert isinstance(tree.containers['alias'], dace.data.pydata.PythonClass)
+
+    alias_binding = tree.children[0]
+    assert isinstance(alias_binding, tn.RefSetNode)
+    assert alias_binding.target == 'alias'
+
+    field_write = tree.children[1]
+    assert isinstance(field_write, tn.CopyNode)
+    assert field_write.target == 'alias.new_field'
+
+
+def test_python_frontend_schedule_tree_optional_none_branch():
+    """An ``is None`` test on an ``Optional`` array argument yields an if/else scope pair.
+
+    The first child of each scope is expected to be a ``CopyNode``.
+    """
+
+    @dace.program
+    def optional_none_prog(field: Optional[dace.float64[8]], A: dace.float64[8], out: dace.float64[8]):
+        if field is None:
+            out[:] = A[:]
+        else:
+            out[:] = field[:]
+
+    tree = optional_none_prog.to_schedule_tree()
+
+    if_scope = tree.children[0]
+    assert isinstance(if_scope, tn.IfScope)
+    assert if_scope.condition.as_string == '(field is None)'
+    assert isinstance(if_scope.children[0], tn.CopyNode)
+
+    else_scope = tree.children[1]
+    assert isinstance(else_scope, tn.ElseScope)
+    assert isinstance(else_scope.children[0], tn.CopyNode)
+
+
+def test_python_frontend_schedule_tree_list_comprehension():
+    """A list comprehension produces a schedule tree root with a top-level loop scope."""
+
+    @dace.program
+    def list_comp_prog(A: dace.float64[8]):
+        tmp = [A[i] for i in range(4)]
+        return tmp
+
+    tree = list_comp_prog.to_schedule_tree()
+
+    # The comprehension is expected to be desugared into an explicit loop during
+    # preprocessing (preceded by a list initialisation); only the presence of a
+    # top-level loop is actually checked here.
+    assert isinstance(tree, tn.ScheduleTreeRoot)
+    assert any(isinstance(child, tn.LoopScope) for child in tree.children)
+
+
+def test_python_frontend_schedule_tree_linked_object_reference():
+    """Reading an array through a chain of object attributes binds a reference.
+
+    Expects exactly two children: a ``RefSetNode`` whose source expression is
+    ``linked.next.arr``, followed by a ``CopyNode``.
+    """
+
+    class Node:
+
+        def __init__(self, arr, next_node=None):
+            self.arr = arr
+            self.next = next_node
+
+    linked = Node(np.zeros(8, dtype=np.float64), Node(np.ones(8, dtype=np.float64)))
+
+    @dace.program
+    def linked_prog(out: dace.float64[8]):
+        ref = linked.next.arr
+        out[:] = ref
+
+    tree = linked_prog.to_schedule_tree()
+
+    assert len(tree.children) == 2
+    ref_binding, copy_out = tree.children
+    assert isinstance(ref_binding, tn.RefSetNode)
+    assert ref_binding.source_expr == 'linked.next.arr'
+    assert isinstance(copy_out, tn.CopyNode)
+
+
+def test_python_frontend_schedule_tree_normalized_loop_iterators():
+    """Array, ``zip`` and ``enumerate`` iteration each yield a single frontend loop.
+
+    Every program is expected to produce one ``LoopScope`` wrapping a ``FrontendLoop``;
+    the type of the first node in the loop body differs per program.
+    """
+
+    @dace.program
+    def array_iter_prog(A: dace.float64[4], out: dace.float64[4]):
+        for val in A:
+            out[0] = val
+
+    @dace.program
+    def zip_prog(A: dace.float64[4], B: dace.float64[4], out: dace.float64[4]):
+        for a, b in zip(A, B):
+            out[0] = a + b
+
+    @dace.program
+    def enumerate_prog(A: dace.float64[4], out: dace.float64[4]):
+        for i, val in enumerate(A):
+            out[i] = val
+
+    @dace.program
+    def enumerate_zip_flat_prog(A: dace.float64[4], B: dace.float64[4], out: dace.float64[4]):
+        for i, pair in enumerate(zip(A, B)):
+            out[i] = pair[0] + pair[1]
+
+    @dace.program
+    def enumerate_zip_unpack_prog(A: dace.float64[4], B: dace.float64[4], out: dace.float64[4]):
+        for i, (a, b) in enumerate(zip(A, B)):
+            out[i] = a + b
+
+    # All five programs are converted up front, in this order, before any assertion
+    # runs; each tree is paired with the node type expected first in its loop body.
+    expectations = [
+        (array_iter_prog.to_schedule_tree(), tn.CopyNode),
+        (zip_prog.to_schedule_tree(), tn.TaskletNode),
+        (enumerate_prog.to_schedule_tree(), tn.CopyNode),
+        (enumerate_zip_flat_prog.to_schedule_tree(), tn.TaskletNode),
+        (enumerate_zip_unpack_prog.to_schedule_tree(), tn.TaskletNode),
+    ]
+
+    for tree, first_body_type in expectations:
+        assert len(tree.children) == 1
+        loop_scope = tree.children[0]
+        assert isinstance(loop_scope, tn.LoopScope)
+        assert isinstance(loop_scope.loop, tn.FrontendLoop)
+        assert isinstance(loop_scope.children[0], first_body_type)
+
+
+def test_python_frontend_schedule_tree_generic_iterator_fallback():
+    """A ``dace.nounroll`` loop over a user-defined iterable uses iterator-state assignments.
+
+    Expects four ``AssignNode`` children before the loop, a loop conditioned on
+    ``__dace_iter_has_next_2``, and a loop body made of a ``CopyNode`` followed by
+    three ``AssignNode`` updates of the iterator-state names.
+    """
+
+    class CounterIterable:
+
+        def __iter__(self):
+            return iter([1.0, 2.0, 3.0])
+
+    counter = CounterIterable()
+
+    @dace.program
+    def iter_prog(out: dace.float64[4]):
+        for val in dace.nounroll(counter):
+            out[0] = val
+
+    tree = iter_prog.to_schedule_tree()
+
+    # Indexed access (rather than slicing) keeps a too-short tree an error.
+    for index in range(4):
+        assert isinstance(tree.children[index], tn.AssignNode)
+
+    loop_scope = tree.children[4]
+    assert isinstance(loop_scope, tn.LoopScope)
+    assert isinstance(loop_scope.loop, tn.FrontendLoop)
+    assert loop_scope.loop.loop_condition.as_string == '__dace_iter_has_next_2'
+    assert isinstance(loop_scope.children[0], tn.CopyNode)
+
+    # Body positions 1..3 are expected to update the iterator state, in this order.
+    expected_updates = ('__dace_iter_next_1', '__dace_iter_has_next_2', '__dace_iter_value_3')
+    for position, expected_name in enumerate(expected_updates, start=1):
+        update = loop_scope.children[position]
+        assert isinstance(update, tn.AssignNode)
+        assert update.name == expected_name
+
+
+def test_python_frontend_schedule_tree_generic_iterator_inference_has_no_runtime_side_effect():
+    """Building the schedule tree must not call ``__iter__`` on the captured iterable.
+
+    The iterable counts its ``__iter__`` invocations; the count is expected to
+    still be zero after ``to_schedule_tree()`` returns.
+    """
+
+    class CounterIterable:
+
+        def __init__(self):
+            self.iter_calls = 0
+
+        def __iter__(self):
+            self.iter_calls += 1
+            return iter([1.0, 2.0, 3.0])
+
+    counter = CounterIterable()
+
+    @dace.program
+    def iter_prog(out: dace.float64[4]):
+        for val in dace.nounroll(counter):
+            out[0] = val
+
+    tree = iter_prog.to_schedule_tree()
+
+    assert counter.iter_calls == 0
+
+    loop_scope = tree.children[4]
+    assert isinstance(loop_scope, tn.LoopScope)
+    assert isinstance(loop_scope.children[0], tn.CopyNode)
+
+
+def test_python_frontend_schedule_tree_generic_iterator_tuple_value():
+    """A ``dace.nounroll`` loop over an iterable of pairs, with the pair indexed in the body.
+
+    Expects four ``AssignNode`` children before the loop and a ``TaskletNode`` as
+    the first node of the loop body.
+    """
+
+    class PairIterable:
+
+        def __iter__(self):
+            return iter([(1.0, 2.0), (3.0, 4.0)])
+
+    pairs = PairIterable()
+
+    @dace.program
+    def iter_prog(out: dace.float64[4]):
+        for pair in dace.nounroll(pairs):
+            out[0] = pair[0] + pair[1]
+
+    tree = iter_prog.to_schedule_tree()
+
+    # Indexed access (rather than slicing) keeps a too-short tree an error.
+    for index in range(4):
+        assert isinstance(tree.children[index], tn.AssignNode)
+
+    loop_scope = tree.children[4]
+    assert isinstance(loop_scope, tn.LoopScope)
+    assert isinstance(loop_scope.children[0], tn.TaskletNode)
+
+
+def test_python_frontend_schedule_tree_generic_iterator_fallback_destructuring(): + + class PairIterable: + + def __iter__(self): + return iter([(1.0, 2.0), (3.0, 4.0)]) + + pairs = PairIterable() + + @dace.program + def iter_prog(out: dace.float64[4]): + for a, b in dace.nounroll(pairs): + out[0] = a + b + + stree = iter_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.AssignNode) + assert isinstance(stree.children[1], tn.AssignNode) + assert isinstance(stree.children[2], tn.AssignNode) + assert isinstance(stree.children[3], tn.AssignNode) + assert isinstance(stree.children[4], tn.LoopScope) + assert isinstance(stree.children[4].loop, tn.FrontendLoop) + assert isinstance(stree.children[4].children[0], tn.StatementNode) + assert stree.children[4].children[0].code.as_string == '(a, b) = __dace_iter_value_3' + assert isinstance(stree.children[4].children[1], tn.TaskletNode) + assert isinstance(stree.children[4].children[2], tn.AssignNode) + assert isinstance(stree.children[4].children[3], tn.AssignNode) + assert isinstance(stree.children[4].children[4], tn.AssignNode) + + +def test_python_frontend_schedule_tree_generic_iterator_generator_object(): + + def reverse_range(sz): + cur = sz + for _ in range(sz): + yield float(cur) + cur -= 1 + + generator = reverse_range(3) + + @dace.program + def iter_prog(out: dace.float64[4]): + for val in dace.nounroll(generator): + out[0] = val + + stree = iter_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.AssignNode) + assert isinstance(stree.children[1], tn.AssignNode) + assert isinstance(stree.children[2], tn.AssignNode) + assert isinstance(stree.children[3], tn.AssignNode) + assert isinstance(stree.children[4], tn.LoopScope) + assert isinstance(stree.children[4].loop, tn.FrontendLoop) + assert isinstance(stree.children[4].children[0], tn.CopyNode) + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + assert next(generator) == 3.0 + + +def 
test_python_frontend_schedule_tree_tuple_of_arrays_unrolls(): + + @dace.program + def iter_prog(a: dace.float64[2, 3, 4], b: dace.float64[2, 3, 4], c: dace.float64[2, 3, 4], out: dace.float64[2, 3, + 4]): + for arr in (a, b, c): + out[:] = arr + + with pytest.warns(UserWarning, match=r'implicitly unrolled'): + stree = iter_prog.to_schedule_tree() + + assert [type(child) for child in stree.children] == [tn.CopyNode, tn.CopyNode, tn.CopyNode] + assert [child.memlet.data for child in stree.children] == ['a', 'b', 'c'] + + +def test_python_frontend_schedule_tree_nounroll_array_annotation_binds_reference(): + + @dace.program + def iter_prog(a: dace.float64[2, 3, 4], b: dace.float64[2, 3, 4], c: dace.float64[2, 3, 4], out: dace.float64[2, 3, + 4]): + list_of_arrays = [a, b, c] + for arr in dace.nounroll(list_of_arrays): + arr: dace.float64[2, 3, 4] + out[:] = arr + + stree = iter_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.CopyNode) + assert stree.children[0].target == 'list_of_arrays_0' + assert stree.children[0].memlet.data == 'a' + assert isinstance(stree.children[1], tn.CopyNode) + assert stree.children[1].target == 'list_of_arrays_1' + assert stree.children[1].memlet.data == 'b' + assert isinstance(stree.children[2], tn.CopyNode) + assert stree.children[2].target == 'list_of_arrays_2' + assert stree.children[2].memlet.data == 'c' + assert isinstance(stree.containers['list_of_arrays'], PythonList) + assert isinstance(stree.children[3], tn.AssignNode) + assert stree.children[3].name == '__dace_iter_0' + assert isinstance(stree.children[4], tn.AssignNode) + assert stree.children[4].name == '__dace_iter_next_1' + assert isinstance(stree.children[5], tn.AssignNode) + assert stree.children[5].name == '__dace_iter_has_next_2' + assert isinstance(stree.children[6], tn.AssignNode) + assert stree.children[6].name == '__dace_iter_value_3' + assert isinstance(stree.children[7], tn.LoopScope) + assert stree.children[7].loop.loop_condition.as_string == 
'__dace_iter_has_next_2' + assert isinstance(stree.children[7].children[0], tn.RefSetNode) + assert stree.children[7].children[0].target == 'arr' + assert isinstance(stree.children[7].children[1], tn.CopyNode) + assert stree.children[7].children[1].memlet.data == 'arr' + assert isinstance(stree.children[7].children[2], tn.AssignNode) + assert stree.children[7].children[2].name == '__dace_iter_next_1' + assert isinstance(stree.children[7].children[3], tn.AssignNode) + assert stree.children[7].children[3].name == '__dace_iter_has_next_2' + assert isinstance(stree.children[7].children[4], tn.AssignNode) + assert stree.children[7].children[4].name == '__dace_iter_value_3' + assert '__dace_iter_end_' not in stree.as_string() + + +def test_python_frontend_schedule_tree_typed_dict_literal_and_static_read(): + + @dace.program + def dict_prog(A: dace.float64[2]): + mapping = {'left': A[0], 'right': A[1]} + value = mapping['left'] + return value + + stree = dict_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.StatementNode) + assert stree.children[0].code.as_string == "mapping = {'left': A[0], 'right': A[1]}" + assert isinstance(stree.containers['mapping'], PythonDict) + assert isinstance(stree.containers['mapping'].key_type, dace.data.Scalar) + assert stree.containers['mapping'].key_type.dtype == dace.string + assert isinstance(stree.containers['mapping'].value_type, dace.data.Scalar) + assert stree.containers['mapping'].value_type.dtype == dace.float64 + assert isinstance(stree.children[1], tn.AssignNode) + assert stree.children[1].name == 'value' + assert stree.children[1].value.as_string == "mapping['left']" + assert isinstance(stree.containers['value'], dace.data.Scalar) + assert stree.containers['value'].dtype == dace.float64 + assert isinstance(stree.children[2], tn.ReturnNode) + + +def test_python_frontend_schedule_tree_heterogeneous_dict_values_fallback_to_pyobject(): + + @dace.program + def dict_prog(A: dace.float64[2]): + mapping = {'left': 
A[0], 'right': 'two'} + return 0.0 + + stree = dict_prog.to_schedule_tree() + + assert isinstance(stree.containers['mapping'], PythonDict) + assert stree.containers['mapping'].key_type.dtype == dace.string + assert stree.containers['mapping'].value_type.dtype == dtypes.pyobject() + + +def test_python_frontend_schedule_tree_heterogeneous_dict_known_static_reads_stay_precise(): + + @dace.program + def dict_prog(A: dace.float64[2]): + mapping = {'left': A[0], 'right': 'two'} + left = mapping['left'] + right = mapping['right'] + return left + + stree = dict_prog.to_schedule_tree() + + assert isinstance(stree.containers['mapping'], PythonDict) + assert stree.containers['mapping'].value_type.dtype == dtypes.pyobject() + assert isinstance(stree.containers['left'], dace.data.Scalar) + assert stree.containers['left'].dtype == dace.float64 + assert isinstance(stree.containers['right'], dace.data.Scalar) + assert stree.containers['right'].dtype == dace.string + assert isinstance(stree.children[1], tn.AssignNode) + assert stree.children[1].value.as_string == "mapping['left']" + assert isinstance(stree.children[2], tn.AssignNode) + assert stree.children[2].value.as_string == "mapping['right']" + + +def test_python_frontend_schedule_tree_dict_same_key_update_widens_value_type(): + + @dace.program + def dict_prog(A: dace.float64[2]): + mapping = {'left': A[0], 'right': A[1]} + mapping['left'] = 'two' + value = mapping['left'] + return value + + stree = dict_prog.to_schedule_tree() + + assert isinstance(stree.containers['mapping'], PythonDict) + assert stree.containers['mapping'].key_type.dtype == dace.string + assert stree.containers['mapping'].value_type.dtype == dtypes.pyobject() + assert isinstance(stree.containers['value'], dace.data.Scalar) + assert stree.containers['value'].dtype == dace.string + assert isinstance(stree.children[1], tn.StatementNode) + assert stree.children[1].code.as_string == "mapping['left'] = 'two'" + assert isinstance(stree.children[2], tn.AssignNode) 
+ assert stree.children[2].value.as_string == "mapping['left']" + + +def test_python_frontend_schedule_tree_free_iter_and_next_calls(): + + def reverse_range(sz): + cur = sz + for _ in range(sz): + yield float(cur) + cur -= 1 + + generator = reverse_range(3) + + @dace.program + def iter_prog(out: dace.float64[3]): + it = iter(generator) + out[0] = next(it) + out[1] = next(it) + out[2] = next(it) + + stree = iter_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.PythonCallbackNode) + assert stree.children[0].reason == 'pyobject call' + assert stree.children[0].code.as_string == 'it = iter(generator)' + assert stree.children[0].outlined_function_name.startswith('__stree_callback') + assert stree.children[0].outlined_function_code.as_string.startswith( + f'def {stree.children[0].outlined_function_name}():') + assert stree.children[0].outlined_call_code.as_string == f'it = {stree.children[0].outlined_function_name}()' + for index, child in enumerate(stree.children[1:]): + assert isinstance(child, tn.TaskletNode) + assert child.node.code.as_string == f'out[{index}] = next(it)' + assert len([node for node in stree.preorder_traversal() if isinstance(node, tn.PythonCallbackNode)]) == 1 + assert next(generator) == 3.0 + + +def test_python_frontend_schedule_tree_internal_generator_with_next_calls(): + + def reverse_range(sz): + cur = sz + for _ in range(sz): + yield float(cur) + cur -= 1 + + @dace.program + def iter_prog(out: dace.float64[3]): + gen = reverse_range(3) + out[0] = next(gen) + out[1] = next(gen) + out[2] = next(gen) + + stree = iter_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.PythonCallbackNode) + assert stree.children[0].reason == 'pyobject call' + assert stree.children[0].code.as_string == 'gen = reverse_range(3)' + for index, child in enumerate(stree.children[1:]): + assert isinstance(child, tn.TaskletNode) + assert child.node.code.as_string == f'out[{index}] = next(gen)' + assert len([node for node in 
stree.preorder_traversal() if isinstance(node, tn.PythonCallbackNode)]) == 1 + + +def test_python_frontend_schedule_tree_next_iter_dict_values(): + + @dace.program + def iter_prog(out: dace.int64[1]): + x = {1: 1, 2: 2, 3: 3} + out[0] = next(iter(x.values())) + + stree = iter_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.StatementNode) + assert stree.children[0].code.as_string == 'x = {1: 1, 2: 2, 3: 3}' + assert isinstance(stree.children[1], tn.PythonCallbackNode) + assert stree.children[1].reason == 'pyobject call' + assert stree.children[1].code.as_string == '__stree_tmp = x.values()' + assert isinstance(stree.children[2], tn.PythonCallbackNode) + assert stree.children[2].reason == 'pyobject call' + assert stree.children[2].code.as_string == '__stree_tmp1 = iter(__stree_tmp)' + assert isinstance(stree.children[3], tn.TaskletNode) + assert stree.children[3].node.code.as_string == 'out[0] = next(__stree_tmp1)' + assert len([node for node in stree.preorder_traversal() if isinstance(node, tn.PythonCallbackNode)]) == 2 + + +def test_python_frontend_schedule_tree_untyped_next_warns(): + + def reverse_range(sz): + cur = sz + for _ in range(sz): + yield float(cur) + cur -= 1 + + @dace.program + def iter_prog(out: dace.float64[1]): + gen = reverse_range(3) + val = next(gen) + out[0] = val + + with pytest.warns(UserWarning, + match=r'Could not infer the result type of iterator next\(\) in schedule-tree lowering'): + stree = iter_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.PythonCallbackNode) + assert stree.children[0].code.as_string == 'gen = reverse_range(3)' + assert isinstance(stree.children[1], tn.PythonCallbackNode) + assert stree.children[1].code.as_string == 'val = next(gen)' + assert isinstance(stree.children[2], tn.CopyNode) + + +def test_python_frontend_schedule_tree_annotated_next_assignment_is_typed(): + + def reverse_range(sz): + cur = sz + for _ in range(sz): + yield float(cur) + cur -= 1 + + @dace.program + def 
iter_prog(out: dace.float64[1]): + gen = reverse_range(3) + val: dace.float64 = next(gen) + out[0] = val + + with pytest.raises(pytest.fail.Exception, match='DID NOT WARN'): + with pytest.warns(UserWarning, + match=r'Could not infer the result type of iterator next\(\) in schedule-tree lowering'): + stree = iter_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.PythonCallbackNode) + assert stree.children[0].code.as_string == 'gen = reverse_range(3)' + assert isinstance(stree.children[1], tn.TaskletNode) + assert stree.children[1].node.code.as_string == 'val = next(gen)' + assert isinstance(stree.children[2], tn.CopyNode) + + +def test_python_frontend_schedule_tree_tuple_swap_statement(): + + @dace.program + def swap_prog(A: dace.float64[4], B: dace.float64[4]): + A, B = B, A + return A + + stree = swap_prog.to_schedule_tree() + + assert [type(child) + for child in stree.children] == [tn.CopyNode, tn.CopyNode, tn.CopyNode, tn.CopyNode, tn.ReturnNode] + assert [child.target for child in stree.children[:4]] == ['__stree_tuple_tmp_0', '__stree_tuple_tmp_1', 'A', 'B'] + assert [child.memlet.data + for child in stree.children[:4]] == ['B', 'A', '__stree_tuple_tmp_0', '__stree_tuple_tmp_1'] + assert not any(isinstance(node, tn.StatementNode) for node in stree.preorder_traversal()) + + +def test_python_frontend_schedule_tree_tuple_permutation_materializes_rhs(): + + @dace.program + def perm_prog(A: dace.float64[4], B: dace.float64[4], C: dace.float64[4]): + A, B, C = C, A, A + return A + + stree = perm_prog.to_schedule_tree() + + assert [type(child) for child in stree.children + ] == [tn.CopyNode, tn.CopyNode, tn.CopyNode, tn.CopyNode, tn.CopyNode, tn.CopyNode, tn.ReturnNode] + assert [child.target for child in stree.children[:6] + ] == ['__stree_tuple_tmp_0', '__stree_tuple_tmp_1', '__stree_tuple_tmp_2', 'A', 'B', 'C'] + assert [child.memlet.data for child in stree.children[:6] + ] == ['C', 'A', 'A', '__stree_tuple_tmp_0', '__stree_tuple_tmp_1', 
'__stree_tuple_tmp_2'] + assert not any(isinstance(node, tn.StatementNode) for node in stree.preorder_traversal()) + + +def test_python_frontend_schedule_tree_starred_unpacking_uses_analyzable_structure(): + + @dace.program + def starred_prog(A: dace.float64[4], B: dace.float64[4], C: dace.float64[4], out: dace.float64[4]): + head, *rest = (A, B, C) + out[:] = rest[1] + + stree = starred_prog.to_schedule_tree() + + assert isinstance(stree.children[0], (tn.RefSetNode, tn.ViewNode)) + assert stree.children[0].target == 'head' + assert isinstance(stree.children[1], tn.CopyNode) + assert stree.children[1].target == 'rest_0' + assert stree.children[1].memlet.data == 'B' + assert isinstance(stree.children[2], tn.CopyNode) + assert stree.children[2].target == 'rest_1' + assert stree.children[2].memlet.data == 'C' + assert isinstance(stree.children[3], tn.CopyNode) + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + + +def test_python_frontend_schedule_tree_star_call_expansion_is_resolved_statically(): + + def callee(a, b, c): + return a + b + c + + @dace.program + def star_call_prog(A: dace.float64[4], B: dace.float64[4], C: dace.float64[4]): + args = (A, B) + return callee(*args, c=C) + + stree = star_call_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.CopyNode) + assert stree.children[0].target == 'args_0' + assert stree.children[0].memlet.data == 'A' + assert isinstance(stree.children[1], tn.CopyNode) + assert stree.children[1].target == 'args_1' + assert stree.children[1].memlet.data == 'B' + assert isinstance(stree.children[2], tn.FunctionCallScope) + assert stree.children[2].call.callee_name == 'callee' + assert stree.children[2].call.arguments == {'a': 'A', 'b': 'B', 'c': 'C'} + assert isinstance(stree.children[3], tn.ReturnNode) + assert stree.children[3].values[0] == '__stree_retval' + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + + +def 
test_python_frontend_schedule_tree_double_star_call_expansion_is_resolved_statically(): + + def callee(a, b, c): + return a + b + c + + @dace.program + def dstar_call_prog(A: dace.float64[4], B: dace.float64[4], C: dace.float64[4]): + kwargs = {'c': C} + return callee(A, B, **kwargs) + + stree = dstar_call_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.StatementNode) + assert stree.children[0].code.as_string == "kwargs = {'c': C}" + assert isinstance(stree.children[1], tn.FunctionCallScope) + assert stree.children[1].call.callee_name == 'callee' + assert stree.children[1].call.arguments == {'a': 'A', 'b': 'B', 'c': 'C'} + assert isinstance(stree.children[2], tn.ReturnNode) + assert stree.children[2].values[0] == '__stree_retval' + assert not any(isinstance(node, tn.PythonCallbackNode) for node in stree.preorder_traversal()) + + +def test_python_frontend_schedule_tree_dynamic_star_call_expansion_uses_callback(): + + def callee(a, b, c): + return a + b + c + + @dace.program + def dynamic_star_call(flag: dace.bool_, A: dace.float64[4], B: dace.float64[4], C: dace.float64[4]): + return callee(*((A, B) if flag else (B, A)), c=C) + + stree = dynamic_star_call.to_schedule_tree() + + assert isinstance(stree.children[0], tn.PythonCallbackNode) + assert stree.children[0].reason == 'call expansion' + assert stree.children[0].code.as_string == '__stree_retval = callee(*((A, B) if flag else (B, A)), c=C)' + assert isinstance(stree.children[1], tn.ReturnNode) + assert stree.children[1].values[0] == '__stree_retval' + + +def test_python_frontend_schedule_tree_dynamic_expanded_sdfg_call_raises(): + + @dace.program + def inner(A: dace.float64[4], B: dace.float64[4]): + return A + B + + sdfg_obj = inner.to_sdfg(simplify=False) + + @dace.program + def dynamic_sdfg_call(flag: dace.bool_, A: dace.float64[4], B: dace.float64[4]): + return sdfg_obj(*((A, B) if flag else (B, A))) + + with pytest.raises(DaceSyntaxError, match='Dynamic argument expansion is unsupported 
for SDFG calls'): + dynamic_sdfg_call.to_schedule_tree() + + +# ------------------------------------------------------------------ # +# Phase 4 — Full Python Language Coverage Tests # +# ------------------------------------------------------------------ # + + +def test_try_except_produces_callback(): + + @dace.program + def try_prog(A: dace.float64[10]): + try: + A[0] = 1.0 + except Exception: + A[0] = 0.0 + return A + + stree = try_prog.to_schedule_tree() + + callbacks = [c for c in stree.children if isinstance(c, tn.PythonCallbackNode)] + assert len(callbacks) >= 1 + assert callbacks[0].reason == 'try/except' + + +def test_import_produces_callback(): + + @dace.program + def import_prog(A: dace.float64[10]): + import math + A[0] = math.pi + return A + + stree = import_prog.to_schedule_tree() + + callbacks = [c for c in stree.children if isinstance(c, tn.PythonCallbackNode)] + assert len(callbacks) >= 1 + assert any(c.reason == 'import' for c in callbacks) + + +def test_match_lowers_to_if_chain(): + + @dace.program + def match_prog(A: dace.int32[1]): + match A[0]: + case 0: + A[0] = 1 + case _: + A[0] = 2 + + stree = match_prog.to_schedule_tree() + + assert not any(isinstance(c, tn.PythonCallbackNode) for c in stree.children) + assert any(isinstance(c, tn.IfScope) for c in stree.children) + assert any(isinstance(c, tn.ElseScope) for c in stree.children) + + +def test_match_capture_guard_and_or_lower_natively(): + + @dace.program + def match_prog(A: dace.int32[1], B: dace.int32[1]): + match A[0]: + case 0 | 1: + B[0] = 7 + case x if x > 2: + B[0] = x + case _: + B[0] = -1 + + stree = match_prog.to_schedule_tree() + + assert not any(isinstance(c, tn.PythonCallbackNode) for c in stree.preorder_traversal()) + assert any(isinstance(c, tn.IfScope) for c in stree.children) + assert any(isinstance(c, tn.ElifScope) for c in stree.children) + assert any(isinstance(c, tn.ElseScope) for c in stree.children) + + +def test_match_fixed_length_sequence_lowers_natively(): + + 
@dace.program + def match_prog(A: dace.int32[2], B: dace.int32[1]): + match (A[0], A[1]): + case (0, x): + B[0] = x + case _: + B[0] = -1 + + stree = match_prog.to_schedule_tree() + + assert not any(isinstance(c, tn.PythonCallbackNode) for c in stree.preorder_traversal()) + assert any(isinstance(c, tn.IfScope) for c in stree.children) + assert any(isinstance(c, tn.ElseScope) for c in stree.children) + + +def test_match_sequence_guard_support_is_native(): + + @dace.program + def match_prog(A: dace.int32[2], B: dace.int32[1]): + match (A[0], A[1]): + case (x, y) if x < y and y > 0: + B[0] = x + y + case _: + B[0] = -1 + + stree = match_prog.to_schedule_tree() + + assert not any(isinstance(c, tn.PythonCallbackNode) for c in stree.preorder_traversal()) + if_scopes = [c for c in stree.children if isinstance(c, tn.IfScope)] + assert len(if_scopes) == 1 + assert 'len(' in if_scopes[0].condition.as_string + assert '__stree_tmp[0]' in if_scopes[0].condition.as_string + assert '__stree_tmp[1]' in if_scopes[0].condition.as_string + + +def test_match_mapping_case_forces_callback_for_whole_match(): + + @dace.program + def match_prog(A: dace.int32[2], B: dace.int32[1]): + match (A[0], A[1]): + case (0, x): + B[0] = x + case {'x': x}: + B[0] = x + case _: + B[0] = -1 + + stree = match_prog.to_schedule_tree() + + callbacks = [c for c in stree.children if isinstance(c, tn.PythonCallbackNode)] + assert len(callbacks) >= 1 + assert callbacks[0].reason == 'match/case' + + +def test_match_class_case_forces_callback_for_whole_match(): + + class Pair: + + def __init__(self, x: int, y: int): + self.x = x + self.y = y + + pair = Pair(1, 2) + + @dace.program + def match_prog(B: dace.int32[1]): + match pair: + case Pair(x, y): + B[0] = x + y + case _: + B[0] = -1 + + stree = match_prog.to_schedule_tree() + + callbacks = [c for c in stree.children if isinstance(c, tn.PythonCallbackNode)] + assert len(callbacks) >= 1 + assert callbacks[0].reason == 'match/case' + + +def 
test_class_def_is_rejected(): + + @dace.program + def classdef_prog(A: dace.float64[10]): + + class Foo: + x = 1 + + A[0] = Foo.x + return A + + with pytest.raises(DaceSyntaxError, match='Nested class definitions are unsupported'): + classdef_prog.to_schedule_tree() + + +def test_global_traces_container(): + """global x where x is a known global should bind, not callback.""" + some_global_array = np.zeros(10, dtype=np.float64) + + @dace.program + def global_prog(A: dace.float64[10]): + # global is typically used in nested scopes; test that it doesn't error + for i in range(10): + A[i] = some_global_array[i] + return A + + # Should not raise + stree = global_prog.to_schedule_tree() + assert isinstance(stree, tn.ScheduleTreeRoot) + + +def test_global_untraceable_callback(): + + @dace.program + def global_prog(A: dace.float64[10]): + global missing_name + A[0] = 1.0 + + stree = global_prog.to_schedule_tree() + + callbacks = [c for c in stree.children if isinstance(c, tn.PythonCallbackNode)] + assert len(callbacks) >= 1 + assert callbacks[0].reason == 'global scope' + + +def test_top_level_global_reassignment_emits_reassign_external(): + globals()['__schedule_tree_global_reassign'] = np.zeros(10, dtype=np.float64) + + try: + + @dace.program + def global_prog(A: dace.float64[10]): + global __schedule_tree_global_reassign + __schedule_tree_global_reassign = A + return A + + stree = global_prog.to_schedule_tree() + + finally: + del globals()['__schedule_tree_global_reassign'] + + reassigns = [node for node in stree.children if isinstance(node, tn.ReassignExternalNode)] + assert len(reassigns) == 1 + assert reassigns[0].scope == 'global' + assert reassigns[0].name == '__schedule_tree_global_reassign' + + +def test_top_level_nonlocal_reassignment_emits_reassign_external(): + + def make_prog(): + captured = np.zeros(10, dtype=np.float64) + + @dace.program + def nonlocal_prog(A: dace.float64[10]): + nonlocal captured + captured = A + return A + + return nonlocal_prog + + 
stree = make_prog().to_schedule_tree() + + reassigns = [node for node in stree.children if isinstance(node, tn.ReassignExternalNode)] + assert len(reassigns) == 1 + assert reassigns[0].scope == 'nonlocal' + assert reassigns[0].name == 'captured' + + +def test_decorated_nested_funcdef_produces_callback(): + + def passthrough(fn): + return fn + + @dace.program + def nested_prog(A: dace.float64[10]): + + @passthrough + def helper(x): + y = x + 1 + return y + + A[0] = helper(A[0]) + return A + + stree = nested_prog.to_schedule_tree() + + callbacks = [c for c in stree.children if isinstance(c, tn.PythonCallbackNode)] + assert len(callbacks) >= 1 + assert any(c.reason == 'nested function' for c in callbacks) + + +def test_async_function_produces_callback(): + + @dace.program + def nested_prog(A: dace.float64[10]): + + async def helper(x): + return x + + A[0] = 1.0 + return A + + stree = nested_prog.to_schedule_tree() + + callbacks = [c for c in stree.children if isinstance(c, tn.PythonCallbackNode)] + assert len(callbacks) >= 1 + assert any(c.reason == 'async function' for c in callbacks) + + +def test_async_dace_program_to_schedule_tree_is_rejected(): + + @dace.program + async def async_prog(A: dace.float64[10]): + return A + + with pytest.raises(SyntaxError, match='Async @dace.program functions are unsupported'): + async_prog.to_schedule_tree() + + +def test_delete_noop_for_arrays(): + """del of a known DaCe array should be a no-op (no node emitted).""" + + @dace.program + def del_prog(A: dace.float64[10]): + tmp = dace.define_local([10], dace.float64) + tmp[:] = A[:] + del tmp + return A + + stree = del_prog.to_schedule_tree() + + # No PythonCallbackNode for 'delete' should appear + callbacks = [c for c in stree.children if isinstance(c, tn.PythonCallbackNode) and c.reason == 'delete'] + assert len(callbacks) == 0 + + +def test_dynamic_context_manager_produces_callback(): + + @contextlib.contextmanager + def guard(value): + if value > 0: + yield + else: + yield + + 
@dace.program + def with_prog(A: dace.float64[10]): + with guard(A[0]): + A[1] = A[0] + return A + + stree = with_prog.to_schedule_tree() + + callbacks = [node for node in stree.preorder_traversal() if isinstance(node, tn.PythonCallbackNode)] + assert len(callbacks) == 1 + assert callbacks[0].reason == 'context manager' + assert callbacks[0].code.as_string.startswith('with guard(A[0]):') + + +def test_raise_produces_callback(): + + @dace.program + def raise_prog(A: dace.float64[10]): + raise ValueError("test") + + stree = raise_prog.to_schedule_tree() + + raise_nodes = [node for node in stree.children if isinstance(node, tn.RaiseNode)] + assert len(raise_nodes) == 1 + assert raise_nodes[0].exception_type is not None + assert raise_nodes[0].exception_type.as_string == 'ValueError' + assert [argument.as_string.strip("\"'") for argument in raise_nodes[0].args] == ['test'] + + +def test_dynamic_raise_produces_callback_when_supported(): + + @dace.program + def raise_prog(A: dace.float64[10]): + exc_type = ValueError + raise exc_type("test") + + with dace.config.set_temporary('frontend', 'raise_statements', value='support'): + stree = raise_prog.to_schedule_tree() + + callbacks = [node for node in stree.preorder_traversal() if isinstance(node, tn.PythonCallbackNode)] + assert len(callbacks) == 1 + assert callbacks[0].reason == 'raise' + + +def test_dynamic_raise_can_be_ignored(): + + @dace.program + def raise_prog(A: dace.float64[10]): + exc_type = ValueError + raise exc_type("test") + return A + + with dace.config.set_temporary('frontend', 'raise_statements', value='ignore_dynamic'): + stree = raise_prog.to_schedule_tree() + + assert not any(isinstance(node, tn.RaiseNode) for node in stree.preorder_traversal()) + assert not any( + isinstance(node, tn.PythonCallbackNode) and node.reason == 'raise' for node in stree.preorder_traversal()) + assert isinstance(stree.children[-1], tn.ReturnNode) + + +def test_raise_can_be_ignored_entirely(): + + @dace.program + def 
raise_prog(A: dace.float64[10]): + raise ValueError("test") + return A + + with dace.config.set_temporary('frontend', 'raise_statements', value='ignore_all'): + stree = raise_prog.to_schedule_tree() + + assert not any(isinstance(node, tn.RaiseNode) for node in stree.preorder_traversal()) + assert not any( + isinstance(node, tn.PythonCallbackNode) and node.reason == 'raise' for node in stree.preorder_traversal()) + assert isinstance(stree.children[-1], tn.ReturnNode) + + +def test_raise_from_is_rejected_before_policy_fallback(): + + @dace.program + def raise_prog(A: dace.float64[10]): + raise ValueError('outer') from ValueError('inner') + + with dace.config.set_temporary('frontend', 'raise_statements', value='ignore_all'): + with pytest.raises(DaceSyntaxError, match='raise from'): + raise_prog.to_schedule_tree() + + +def test_named_expr_desugared(): + """Walrus operator should be desugared before reaching schedule tree builder.""" + + @dace.program + def walrus_prog(A: dace.float64[10]): + if (x := A[0]) > 0: + A[1] = x + return A + + stree = walrus_prog.to_schedule_tree() + + # The schedule tree should have an assignment before the if, not a NamedExpr + # x = A[0] comes first, then if x > 0: ... 
+ assert isinstance(stree, tn.ScheduleTreeRoot) + # Should not crash — that's the main verification + + +def test_singleton_subscript_assignment_scalarized(): + + @dace.program + def scalar_prog(x: dace.int32[20]): + idx = x[0] + + stree = scalar_prog.to_schedule_tree() + + assert [type(child) for child in stree.children] == [tn.TaskletNode] + assert stree.children[0].node.code.as_string == 'idx = x[0]' + assert str(stree.children[0].in_memlets['in0'].subset) == '0' + assert str(stree.children[0].out_memlets['out'].subset) == '0' + assert isinstance(stree.containers['idx'], dace.data.Scalar) + + +def test_singleton_negative_subscript_assignment_canonicalized(): + n = dace.symbol('n') + + @dace.program + def scalar_prog(x: dace.int32[n]): + idx = x[-1] + + stree = scalar_prog.to_schedule_tree() + + assert [type(child) for child in stree.children] == [tn.TaskletNode] + assert stree.children[0].node.code.as_string == 'idx = x[(n - 1)]' + assert str(stree.children[0].in_memlets['in0'].subset) == 'n - 1' + assert str(stree.children[0].out_memlets['out'].subset) == '0' + assert isinstance(stree.containers['idx'], dace.data.Scalar) + + +def test_singleton_symbolic_negative_subscript_assignment_canonicalized(): + n = dace.symbol('n') + i = dace.symbol('i', integer=True, positive=True) + + @dace.program + def scalar_prog(x: dace.int32[n]): + idx = x[-i] + + stree = scalar_prog.to_schedule_tree() + + assert [type(child) for child in stree.children] == [tn.TaskletNode] + assert stree.children[0].node.code.as_string == 'idx = x[(n - i)]' + assert (dace.symbolic.pystr_to_symbolic(str( + stree.children[0].in_memlets['in0'].subset)) == dace.symbolic.pystr_to_symbolic('n - i')) + assert str(stree.children[0].out_memlets['out'].subset) == '0' + + +def test_singleton_runtime_negative_subscript_assignment_uses_pyindex_when_enabled(): + n = dace.symbol('n') + i = dace.symbol('i', integer=True) + + @dace.program + def scalar_prog(x: dace.int32[n]): + idx = x[i] + + with 
dace.config.set_temporary('frontend', 'runtime_negative_indices', value=True): + stree = scalar_prog.to_schedule_tree() + + assert [type(child) for child in stree.children] == [tn.TaskletNode] + assert stree.children[0].node.code.as_string == 'idx = x[i]' + assert (dace.symbolic.pystr_to_symbolic(str( + stree.children[0].in_memlets['in0'].subset)) == dace.symbolic.pystr_to_symbolic('pyindex(i, n)')) + assert str(stree.children[0].out_memlets['out'].subset) == '0' + + +def test_short_circuit_condition_keeps_nested_index_in_guard(): + + @dace.program + def guard_prog(A: dace.int32[20], b: dace.int32[20], out: dace.int32[1], flag: dace.bool_, i: dace.int32): + if flag and A[b[i]] == 0: + out[0] = 1 + + stree = guard_prog.to_schedule_tree() + + assert [type(child) for child in stree.children] == [tn.IfScope] + assert 'A[b[i]]' in stree.children[0].condition.as_string + assert '__stree_idx' not in stree.children[0].condition.as_string + assert isinstance(stree.children[0].children[0], tn.TaskletNode) + + +def test_while_with_hoisted_index_rewritten_to_guarded_loop(): + + @dace.program + def loop_prog(A: dace.int32[20], b: dace.int32[20], i: dace.int32): + while A[b[i]] == 0: + i += 1 + + stree = loop_prog.to_schedule_tree() + + assert [type(child) for child in stree.children] == [tn.LoopScope] + assert stree.children[0].loop.loop_condition.as_string == 'True' + assert isinstance(stree.children[0].children[0], tn.TaskletNode) + assert stree.children[0].children[0].node.code.as_string == '__stree_idx = b[i]' + assert isinstance(stree.children[0].children[1], tn.IfScope) + assert 'not (A[__stree_idx] == 0)' in stree.children[0].children[1].condition.as_string + assert isinstance(stree.children[0].children[1].children[0], tn.BreakNode) + + +def test_while_else_with_hoisted_index_falls_back_to_callback(): + + @dace.program + def loop_prog(A: dace.int32[20], b: dace.int32[20], out: dace.int32[1], i: dace.int32): + while A[b[i]] == 0: + i += 1 + else: + out[0] = 1 + + stree = 
loop_prog.to_schedule_tree() + + callbacks = [node for node in stree.preorder_traversal() if isinstance(node, tn.PythonCallbackNode)] + assert len(callbacks) == 1 + assert callbacks[0].reason == 'while loop test outlining with else' + assert callbacks[0].code.as_string.startswith('while (A[b[i]] == 0):') + + +def test_while_else_without_hoisting_lowers_natively(): + + @dace.program + def loop_prog(A: dace.int32[20], out: dace.int32[1], i: dace.int32): + while i < 3: + i += 1 + else: + out[0] = 1 + + stree = loop_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.LoopScope) + assert stree.children[0].loop.loop_condition.as_string == '(i < 3)' + assert isinstance(stree.children[1], tn.ElseScope) + assert isinstance(stree.children[1].children[0], tn.TaskletNode) + + +def test_for_else_lowers_natively(): + + @dace.program + def loop_prog(A: dace.int32[20], out: dace.int32[1]): + for i in range(3): + A[i] = A[i] + else: + out[0] = 1 + + stree = loop_prog.to_schedule_tree() + + assert isinstance(stree.children[0], tn.LoopScope) + assert stree.children[0].loop.loop_condition.as_string == '(i < 3)' + assert isinstance(stree.children[1], tn.ElseScope) + assert isinstance(stree.children[1].children[0], tn.TaskletNode) + + +def test_schedule_tree_distinguishes_list_and_tuple_indices(): + + @dace.program + def list_prog(A: dace.float64[5, 6]): + tmp = A[[1, 2]] + return tmp + + @dace.program + def tuple_prog(A: dace.float64[5, 6]): + tmp = A[(1, 2)] + return tmp + + list_stree = list_prog.to_schedule_tree() + tuple_stree = tuple_prog.to_schedule_tree() + + assert isinstance(list_stree.containers['tmp'], dace.data.Array) + assert tuple(list_stree.containers['tmp'].shape) == (2, 6) + assert isinstance(tuple_stree.containers['tmp'], dace.data.Scalar) + assert tuple(tuple_stree.containers['tmp'].shape) == (1, ) + assert isinstance(list_stree.children[0], tn.TaskletNode) + assert str(list_stree.children[0].out_memlets['out'].subset) == '0:2, 0:6' + assert 
isinstance(tuple_stree.children[0], tn.TaskletNode) + assert str(tuple_stree.children[0].out_memlets['out'].subset) == '0' + assert not any(isinstance(node, tn.PythonCallbackNode) for node in list_stree.preorder_traversal()) + assert not any(isinstance(node, tn.PythonCallbackNode) for node in tuple_stree.preorder_traversal()) + + +def test_schedule_tree_distinguishes_list_and_tuple_indices_with_symbolic_shape(): + n = dace.symbol('n') + + @dace.program + def list_prog(A: dace.float64[5, n]): + tmp = A[[1, 2]] + return tmp + + @dace.program + def tuple_prog(A: dace.float64[5, n]): + tmp = A[(1, 2)] + return tmp + + list_stree = list_prog.to_schedule_tree() + tuple_stree = tuple_prog.to_schedule_tree() + + assert tuple(list_stree.containers['tmp'].shape) == (2, n) + assert str(list_stree.children[0].out_memlets['out'].subset) == '0:2, 0:n' + assert isinstance(tuple_stree.containers['tmp'], dace.data.Scalar) + + +def test_schedule_tree_symbolic_static_slice_shape(): + n = dace.symbol('n') + + @dace.program + def slice_prog(A: dace.float64[n]): + tmp = A[1:n:2] + return tmp + + stree = slice_prog.to_schedule_tree() + + assert isinstance(stree.containers['tmp'], dace.data.Array) + assert str(stree.containers['tmp'].shape[0]) == 'ceiling(n/2 - 1/2)' + assert isinstance(stree.children[0], tn.ViewNode) + + +def test_comprehension_desugaring(): + """Comprehensions should be desugared to explicit loops.""" + + @dace.program + def comp_prog(A: dace.float64[8]): + tmp = [A[i] for i in range(4)] + return tmp + + stree = comp_prog.to_schedule_tree() + + # After desugaring, we should see loop constructs instead of a single TaskletNode + # Check that it at least doesn't crash and produces a valid tree + assert isinstance(stree, tn.ScheduleTreeRoot) + + +def test_generator_immediate_consumption_desugaring(): + + @dace.program + def gen_prog(A: dace.float64[8]): + total = sum(x for x in A) + return total + + stree = gen_prog.to_schedule_tree() + + assert isinstance(stree, 
tn.ScheduleTreeRoot) + assert not any(isinstance(child, tn.PythonCallbackNode) for child in stree.children) + assert any(isinstance(child, tn.LoopScope) for child in stree.children) + + +def test_generic_visit_warns(): + """generic_visit should emit a warning for truly unknown node types.""" + import warnings + from dace.frontend.python.schedule_tree_frontend import PythonScheduleTreeBuilder + + # Verify that generic_visit is invoked by the builder when it encounters + # an AST statement it doesn't have a visitor for, and that it wraps the + # result as a PythonCallbackNode. + # We test this indirectly: the hardened generic_visit emits a warning, + # which we verify is called via monkeypatching. + called = [] + original_generic_visit = PythonScheduleTreeBuilder.generic_visit + + def patched_generic_visit(self, node): + called.append(type(node).__name__) + return original_generic_visit(self, node) + + PythonScheduleTreeBuilder.generic_visit = patched_generic_visit + try: + + @dace.program + def simple(A: dace.float64[10]): + A[0] = 1.0 + + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + stree = simple.to_schedule_tree() + finally: + PythonScheduleTreeBuilder.generic_visit = original_generic_visit + + # If no unknown nodes were encountered, that's expected for simple code. + # The key test is: import the builder, confirm generic_visit exists and warns. 
+ assert hasattr(PythonScheduleTreeBuilder, 'generic_visit') + assert isinstance(stree, tn.ScheduleTreeRoot) + + +def test_comprehensive_coverage(): + """Verify every ast.stmt subclass has a visitor or preprocessing handler.""" + import sys + + # All statement node types in the current Python version + all_stmt_types = set() + for name in dir(ast): + cls = getattr(ast, name) + if isinstance(cls, type) and issubclass(cls, ast.stmt) and cls is not ast.stmt: + all_stmt_types.add(name) + + # Statement types handled by the schedule tree builder + from dace.frontend.python.schedule_tree_frontend import PythonScheduleTreeBuilder + builder_visitors = set() + for name in dir(PythonScheduleTreeBuilder): + if name.startswith('visit_'): + builder_visitors.add(name[6:]) + + # Statement types handled by preprocessing (desugared or removed) + preprocessing_handled = { + 'With', + 'AsyncWith', # ContextManagerInliner + 'Assert', # Removed/evaluated + 'AsyncFor', # Disallowed + 'TypeAlias', # TypeAliasResolver + } + + # All explicitly handled types + handled = builder_visitors | preprocessing_handled + + # Find unhandled statement types + unhandled = all_stmt_types - handled + + # Some types might not exist in all Python versions + expected_unhandled = set() + if sys.version_info < (3, 10): + expected_unhandled.add('Match') + if sys.version_info < (3, 11): + expected_unhandled.add('TryStar') + actual_unhandled = unhandled - expected_unhandled + assert not actual_unhandled, f'Unhandled AST statement types: {actual_unhandled}' + + +def test_type_alias_is_compile_time_only_in_schedule_tree(temp_python_module): + if sys.version_info < (3, 12): + pytest.skip('Type alias statements require Python 3.12+') + + with temp_python_module(''' +import dace + +@dace.program +def prog(A: dace.float32[4]): + type dtype = dace.float32[4] + tmp: dtype = A + return tmp +''', + module_name_prefix='dace_schedule_tree_typealias') as module: + stree = module.prog.to_schedule_tree() + + assert 'tmp' in 
stree.containers + assert isinstance(stree.containers['tmp'], dace.data.Array) + assert stree.containers['tmp'].dtype == dace.float32 + assert tuple(stree.containers['tmp'].shape) == (4, ) + assert not any( + isinstance(node, tn.PythonCallbackNode) and node.reason == 'unhandled TypeAlias' + for node in stree.preorder_traversal()) + + +def test_generic_type_alias_is_rejected_in_schedule_tree(temp_python_module): + if sys.version_info < (3, 12): + pytest.skip('Type alias statements require Python 3.12+') + + with temp_python_module(''' +import dace + +@dace.program +def prog(A: dace.float32[4]): + type dtype[T] = T + return A +''', + module_name_prefix='dace_schedule_tree_typealias') as module: + with pytest.raises(DaceSyntaxError, match='Generic type aliases'): + module.prog.to_schedule_tree() + + +def test_type_var_tuple_alias_is_rejected_in_schedule_tree(temp_python_module): + if sys.version_info < (3, 12): + pytest.skip('Type alias statements require Python 3.12+') + + with temp_python_module(''' +import dace + +@dace.program +def prog(A: dace.float32[4]): + type dtype[*Ts] = tuple[*Ts] + return A +''', + module_name_prefix='dace_schedule_tree_typealias') as module: + with pytest.raises(DaceSyntaxError, match='Generic type aliases'): + module.prog.to_schedule_tree() + + +if __name__ == '__main__': + # pytest.main([__file__]) + test_python_frontend_schedule_tree_structure_scalar_field_assignment_errors_to_use_pythonclass() diff --git a/tests/python_frontend/schedule_tree/registry_parity_test.py b/tests/python_frontend/schedule_tree/registry_parity_test.py new file mode 100644 index 0000000000..e4459ff77e --- /dev/null +++ b/tests/python_frontend/schedule_tree/registry_parity_test.py @@ -0,0 +1,472 @@ +import importlib +import pkgutil +import pytest + +import dace +import numpy as np + +from dace import dtypes +from dace.frontend.common import op_repository as oprepo +import dace.frontend.python.replacements as replacements_pkg + +_KNOWN_OPTIONAL_DEPENDENCIES = 
{'torch', 'onnx'} +_REPLACEMENTS_IMPORTED = False + + +def _import_replacement_modules() -> None: + global _REPLACEMENTS_IMPORTED + if _REPLACEMENTS_IMPORTED: + return + + for module in pkgutil.iter_modules(replacements_pkg.__path__): + try: + importlib.import_module(f'{replacements_pkg.__name__}.{module.name}') + except ModuleNotFoundError as exc: + missing = (exc.name or '').split('.')[0] + if missing not in _KNOWN_OPTIONAL_DEPENDENCIES: + raise + _REPLACEMENTS_IMPORTED = True + + +def _is_numpy_ufunc_name(name: str) -> bool: + parts = name.split('.') + if len(parts) < 2 or parts[0] != 'numpy': + return False + + value = np + for part in parts[1:]: + value = getattr(value, part, None) + if value is None: + return False + return isinstance(value, np.ufunc) + + +def _function_has_inference_coverage(name: str) -> bool: + if name in oprepo.Replacements._dtype_rep: + return True + if oprepo.Replacements.get_ufunc_descriptor_inference('ufunc') is None: + return False + return _is_numpy_ufunc_name(name) + + +def _method_has_inference_coverage(key) -> bool: + return key in oprepo.Replacements._dtype_method_rep + + +def _attribute_has_inference_coverage(key) -> bool: + return key in oprepo.Replacements._dtype_attr_rep + + +def _ufunc_has_inference_coverage(name: str) -> bool: + return name in oprepo.Replacements._dtype_ufunc_rep + + +def _operator_has_inference_coverage(key) -> bool: + left_class, right_class, optype = key + return (left_class, right_class, + optype) in oprepo.Replacements._dtype_op_rep or (None, None, optype) in oprepo.Replacements._dtype_op_rep + + +def test_ufunc_descriptor_registry_parity(): + _import_replacement_modules() + assert set(oprepo.Replacements._ufunc_rep) == set(oprepo.Replacements._dtype_ufunc_rep) + + +def test_ufunc_descriptor_inference_shapes(): + _import_replacement_modules() + infer_ufunc = oprepo.Replacements.get_ufunc_descriptor_inference('ufunc') + infer_reduce = oprepo.Replacements.get_ufunc_descriptor_inference('reduce') + 
infer_accumulate = oprepo.Replacements.get_ufunc_descriptor_inference('accumulate') + infer_outer = oprepo.Replacements.get_ufunc_descriptor_inference('outer') + + left = dace.data.Array(dace.float32, [4, 1], transient=True) + right = dace.data.Array(dace.float32, [1, 5], transient=True) + vector = dace.data.Array(dace.float32, [4, 5], transient=True) + + add_result = infer_ufunc({'A': left, 'B': right}, 'add', 'A', 'B') + assert isinstance(add_result, dace.data.Array) + assert tuple(add_result.shape) == (4, 5) + assert add_result.dtype == dace.float32 + + divmod_result = infer_ufunc({'A': vector, 'B': vector}, 'divmod', 'A', 'B') + assert isinstance(divmod_result, tuple) + assert len(divmod_result) == 2 + assert all(isinstance(result, dace.data.Array) for result in divmod_result) + assert all(tuple(result.shape) == (4, 5) for result in divmod_result) + + reduce_result = infer_reduce({'A': vector}, 'add', 'A') + assert isinstance(reduce_result, dace.data.Array) + assert tuple(reduce_result.shape) == (5, ) + + accumulate_result = infer_accumulate({'A': vector}, 'add', 'A') + assert isinstance(accumulate_result, dace.data.Array) + assert tuple(accumulate_result.shape) == (4, 5) + + outer_result = infer_outer({'A': left, 'B': right}, 'add', 'A', 'B') + assert isinstance(outer_result, dace.data.Array) + assert tuple(outer_result.shape) == (4, 1, 1, 5) + + +def test_operator_descriptor_dispatch_uses_operand_categories(): + _import_replacement_modules() + generic_matmul = oprepo.Replacements.get_operator_descriptor_inference( + 'MatMult', dace.data.Array(dace.float32, [4, 3], transient=True), + dace.data.Array(dace.float32, [3, 2], transient=True)) + storage_cast = oprepo.Replacements.get_operator_descriptor_inference( + 'MatMult', dace.data.Array(dace.float32, [4], transient=True), dtypes.StorageType.GPU_Global) + + assert generic_matmul is not None + assert storage_cast is not None + assert generic_matmul is not storage_cast + + source = dace.data.Array(dace.float32, 
[4], transient=True, storage=dtypes.StorageType.Default)
+    result = storage_cast(source, dtypes.StorageType.GPU_Global)
+    assert isinstance(result, dace.data.Array)
+    assert result.storage == dtypes.StorageType.GPU_Global
+    assert tuple(result.shape) == (4, )
+
+
+def test_recent_alias_and_method_inference_regressions():  # Checks inferred descriptor kind/dtype/shape per replacement.
+    _import_replacement_modules()
+
+    infer_conj = oprepo.Replacements.get_descriptor_inference('numpy.conj')  # function inferences are looked up by name
+    infer_exp = oprepo.Replacements.get_descriptor_inference('exp')
+    infer_floor = oprepo.Replacements.get_descriptor_inference('math.floor')
+    infer_max = oprepo.Replacements.get_descriptor_inference('max')
+    infer_min = oprepo.Replacements.get_descriptor_inference('min')
+    infer_float32 = oprepo.Replacements.get_descriptor_inference('float32')
+    infer_numpy_int16 = oprepo.Replacements.get_descriptor_inference('numpy.int16')
+    infer_dace_bool = oprepo.Replacements.get_descriptor_inference('dace.bool')
+    infer_cart_create = oprepo.Replacements.get_descriptor_inference('dace.comm.Cart_create')
+    infer_cart_sub = oprepo.Replacements.get_descriptor_inference('dace.comm.Cart_sub')
+    infer_clip = oprepo.Replacements.get_descriptor_inference('numpy.clip')
+    infer_bcast = oprepo.Replacements.get_descriptor_inference('dace.comm.Bcast')
+    infer_isend = oprepo.Replacements.get_descriptor_inference('dace.comm.Isend')
+    infer_irecv = oprepo.Replacements.get_descriptor_inference('dace.comm.Irecv')
+    infer_subarray = oprepo.Replacements.get_descriptor_inference('dace.comm.Subarray')
+    infer_bcscatter = oprepo.Replacements.get_descriptor_inference('dace.comm.BCScatter')
+    infer_distr_matmult = oprepo.Replacements.get_descriptor_inference('dace.distr.MatMult')
+    infer_fft = oprepo.Replacements.get_descriptor_inference('numpy.fft.fft')
+    infer_ifft = oprepo.Replacements.get_descriptor_inference('numpy.fft.ifft')
+    infer_dot = oprepo.Replacements.get_descriptor_inference('numpy.dot')
+    infer_einsum = oprepo.Replacements.get_descriptor_inference('numpy.einsum')
+    infer_inv = oprepo.Replacements.get_descriptor_inference('numpy.linalg.inv')
+    infer_rot90 = oprepo.Replacements.get_descriptor_inference('numpy.rot90')
+    infer_solve = oprepo.Replacements.get_descriptor_inference('numpy.linalg.solve')
+    infer_tensordot = oprepo.Replacements.get_descriptor_inference('numpy.tensordot')
+    infer_cholesky = oprepo.Replacements.get_descriptor_inference('numpy.linalg.cholesky')
+    infer_real = oprepo.Replacements.get_descriptor_inference('numpy.real')
+    infer_full_like = oprepo.Replacements.get_descriptor_inference('numpy.full_like')
+    infer_identity = oprepo.Replacements.get_descriptor_inference('numpy.identity')
+    infer_select = oprepo.Replacements.get_descriptor_inference('numpy.select')
+    infer_transpose = oprepo.Replacements.get_descriptor_inference('transpose')
+    infer_where = oprepo.Replacements.get_descriptor_inference('numpy.where')
+    infer_sum = oprepo.Replacements.get_descriptor_inference('sum')
+    infer_intracomm_create_cart = oprepo.Replacements.get_method_descriptor_inference('Intracomm', 'Create_cart')  # method inferences: (class name, method name)
+    infer_intracomm_allreduce = oprepo.Replacements.get_method_descriptor_inference('Intracomm', 'Allreduce')
+    infer_processgrid_sub = oprepo.Replacements.get_method_descriptor_inference('ProcessGrid', 'Sub')
+    infer_processgrid_isend = oprepo.Replacements.get_method_descriptor_inference('ProcessGrid', 'Isend')
+    infer_slice = oprepo.Replacements.get_descriptor_inference('slice')
+    infer_define_stream = oprepo.Replacements.get_descriptor_inference('dace.define_stream')
+    infer_define_streamarray = oprepo.Replacements.get_descriptor_inference('dace.define_streamarray')
+    infer_elementwise = oprepo.Replacements.get_descriptor_inference('dace.elementwise')
+    infer_reduce = oprepo.Replacements.get_descriptor_inference('dace.reduce')
+    infer_cupy_full = oprepo.Replacements.get_descriptor_inference('cupy.full')
+    infer_cupy_empty_like = oprepo.Replacements.get_descriptor_inference('cupy.empty_like')
+    infer_fill = oprepo.Replacements.get_method_descriptor_inference('Array', 'fill')
+    infer_view = oprepo.Replacements.get_method_descriptor_inference('Array', 'view')
+
+    complex_vector = dace.data.Array(dace.complex64, [4], transient=True)  # shared input descriptors for the checks below
+    cond = dace.data.Array(dace.bool_, [2, 1], transient=True)
+    matrix = dace.data.Array(dace.float32, [2, 3], transient=True)
+    square = dace.data.Array(dace.float64, [4, 4], transient=True)
+    rhs = dace.data.Array(dace.float64, [4], transient=True)
+
+    conj_result = infer_conj({'A': complex_vector}, 'A')  # function inferences take a name->descriptor dict, then the call arguments
+    assert isinstance(conj_result, dace.data.Array)
+    assert conj_result.dtype == dace.complex64
+    assert tuple(conj_result.shape) == (4, )
+
+    exp_result = infer_exp({'A': matrix}, 'A')
+    assert isinstance(exp_result, dace.data.Array)
+    assert exp_result.dtype == dace.float32
+    assert tuple(exp_result.shape) == (2, 3)
+
+    floor_result = infer_floor({'A': matrix}, 'A')
+    assert isinstance(floor_result, dace.data.Array)
+    assert floor_result.dtype == dtypes.typeclass(int)  # float32 input, but the typeclass of Python int is expected
+    assert tuple(floor_result.shape) == (2, 3)
+
+    float32_result = infer_float32({'A': matrix}, 'A')
+    assert isinstance(float32_result, dace.data.Array)
+    assert float32_result.dtype == dace.float32
+    assert tuple(float32_result.shape) == (2, 3)
+
+    numpy_int16_result = infer_numpy_int16({'A': matrix}, 'A')
+    assert isinstance(numpy_int16_result, dace.data.Array)
+    assert numpy_int16_result.dtype == dace.int16
+    assert tuple(numpy_int16_result.shape) == (2, 3)
+
+    cart_create_result = infer_cart_create({}, [2, 2])
+    assert isinstance(cart_create_result, dace.data.Scalar)
+    assert isinstance(cart_create_result.dtype, dtypes.pyobject)
+
+    cart_sub_result = infer_cart_sub({}, 'pgrid', [True, False])
+    assert isinstance(cart_sub_result, dace.data.Scalar)
+    assert isinstance(cart_sub_result.dtype, dtypes.pyobject)
+
+    dace_bool_result = infer_dace_bool({'A': matrix}, 'A')
+    assert isinstance(dace_bool_result, dace.data.Array)
+    assert dace_bool_result.dtype == dace.bool_
+    assert tuple(dace_bool_result.shape) == (2, 3)
+
+    assert infer_bcast({'A': matrix}, 'A') == ()  # presumably () means "no result descriptor" -- confirm against the registry contract
+
+    isend_result = infer_isend({'A': matrix}, 'A', 0, 0)
+    assert isinstance(isend_result, dace.data.Array)
+    assert tuple(isend_result.shape) == (1, )
+    assert isinstance(isend_result.dtype, dtypes.opaque)
+
+    assert infer_isend({'A': matrix, 'req': isend_result}, 'A', 0, 0, request='req') == ()  # explicit request= argument: () expected instead of a new array
+
+    irecv_result = infer_irecv({'A': matrix}, 'A', 0, 0)
+    assert isinstance(irecv_result, dace.data.Array)
+    assert tuple(irecv_result.shape) == (1, )
+    assert isinstance(irecv_result.dtype, dtypes.opaque)
+
+    subarray_result = infer_subarray({'A': square}, 'A', [2, 2])
+    assert isinstance(subarray_result, dace.data.Scalar)
+    assert isinstance(subarray_result.dtype, dtypes.pyobject)
+
+    bcscatter_result = infer_bcscatter({'A': square, 'B': square}, 'A', 'B', [2, 2])
+    assert isinstance(bcscatter_result, tuple)
+    assert len(bcscatter_result) == 2
+    assert all(isinstance(result, dace.data.Array) for result in bcscatter_result)
+    assert all(result.dtype == dace.int32 for result in bcscatter_result)
+    assert all(tuple(result.shape) == (9, ) for result in bcscatter_result)
+
+    distr_matmult_result = infer_distr_matmult({'A': square, 'B': square}, 'A', 'B', (4, 4, 4))
+    assert isinstance(distr_matmult_result, dace.data.Array)
+    assert distr_matmult_result.dtype == dace.float64
+    assert tuple(distr_matmult_result.shape) == (4, 4)
+
+    clip_result = infer_clip({'A': matrix}, 'A', 1.0, 3.0)
+    assert isinstance(clip_result, dace.data.Array)
+    assert clip_result.dtype == dace.float32
+    assert tuple(clip_result.shape) == (2, 3)
+
+    clip_max_only_result = infer_clip({'A': matrix}, 'A', None, 3.0)  # lower bound passed as None
+    assert isinstance(clip_max_only_result, dace.data.Array)
+    assert clip_max_only_result.dtype == dace.float32
+    assert tuple(clip_max_only_result.shape) == (2, 3)
+
+    fft_result = infer_fft({'A': matrix}, 'A')
+    assert isinstance(fft_result, dace.data.Array)
+    assert fft_result.dtype == dace.complex64  # float32 input -> complex64 expected
+    assert tuple(fft_result.shape) == (2, 3)
+
+    dot_result = infer_dot({'A': square, 'B': square}, 'A', 'B')
+    assert isinstance(dot_result, dace.data.Array)
+    assert dot_result.dtype == dace.float64
+    assert tuple(dot_result.shape) == (4, 4)
+
+    einsum_result = infer_einsum({'A': square, 'B': square}, 'ik,kj->ij', 'A', 'B')
+    assert isinstance(einsum_result, dace.data.Array)
+    assert einsum_result.dtype == dace.float64
+    assert tuple(einsum_result.shape) == (4, 4)
+
+    dim_a, dim_b, dim_c, dim_d, dim_e = (dace.symbol(name) for name in ('dim_a', 'dim_b', 'dim_c', 'dim_d', 'dim_e'))
+    multi_contract_left = dace.data.Array(dace.float64, [dim_a, dim_b, dim_c, dim_d], transient=True)
+    multi_contract_right = dace.data.Array(dace.float64, [dim_b, dim_d, dim_c, dim_e], transient=True)
+    multi_contract_result = infer_einsum({
+        'A': multi_contract_left,
+        'B': multi_contract_right
+    }, 'abcd,bdce->ae', 'A', 'B')  # symbolic extents; only indices a and e appear in the output
+    assert isinstance(multi_contract_result, dace.data.Array)
+    assert multi_contract_result.dtype == dace.float64
+    assert tuple(multi_contract_result.shape) == (dim_a, dim_e)
+
+    vec_extent = dace.symbol('vec_extent')
+    symbolic_vector = dace.data.Array(dace.float64, [vec_extent], transient=True)
+    repeated_index_result = infer_einsum({'A': symbolic_vector}, 'i->ii', 'A')  # output repeats the single input index
+    assert isinstance(repeated_index_result, dace.data.Array)
+    assert repeated_index_result.dtype == dace.float64
+    assert tuple(repeated_index_result.shape) == (vec_extent, vec_extent)
+
+    reduced_extent = dace.symbol('reduced_extent')
+    kept_extent = dace.symbol('kept_extent')
+    contracted_input_vector = dace.data.Array(dace.float64, [reduced_extent], transient=True)
+    retained_input_vector = dace.data.Array(dace.float64, [kept_extent], transient=True)
+    contracted_input_result = infer_einsum({
+        'A': contracted_input_vector,
+        'B': retained_input_vector
+    }, 'j,k->k', 'A', 'B')  # index j occurs only on the input side
+    assert isinstance(contracted_input_result, dace.data.Array)
+    assert contracted_input_result.dtype == dace.float64
+    assert tuple(contracted_input_result.shape) == (kept_extent, )
+
+    complex_matrix = dace.data.Array(dace.complex64, [2, 3], transient=True)
+    ifft_result = infer_ifft({'A': complex_matrix}, 'A')
+    assert isinstance(ifft_result, dace.data.Array)
+    assert ifft_result.dtype == dace.complex64
+    assert tuple(ifft_result.shape) == (2, 3)
+
+    rot90_result = infer_rot90({'A': matrix}, 'A')
+    assert isinstance(rot90_result, dace.data.Array)
+    assert rot90_result.dtype == dace.float32
+    assert tuple(rot90_result.shape) == (3, 2)  # (2, 3) input -> (3, 2) expected
+
+    inv_result = infer_inv({'A': square}, 'A')
+    assert isinstance(inv_result, dace.data.Array)
+    assert inv_result.dtype == dace.float64
+    assert tuple(inv_result.shape) == (4, 4)
+
+    solve_result = infer_solve({'A': square, 'B': rhs}, 'A', 'B')
+    assert isinstance(solve_result, dace.data.Array)
+    assert solve_result.dtype == dace.float64
+    assert tuple(solve_result.shape) == (4, )  # same shape as `rhs` in this test
+
+    tensordot_result = infer_tensordot({'A': square, 'B': square}, 'A', 'B', axes=1)
+    assert isinstance(tensordot_result, dace.data.Array)
+    assert tensordot_result.dtype == dace.float64
+    assert tuple(tensordot_result.shape) == (4, 4)
+
+    cholesky_result = infer_cholesky({'A': square}, 'A')
+    assert isinstance(cholesky_result, dace.data.Array)
+    assert cholesky_result.dtype == dace.float64
+    assert tuple(cholesky_result.shape) == (4, 4)
+
+    real_result = infer_real({'A': complex_vector}, 'A')
+    assert isinstance(real_result, dace.data.Array)
+    assert real_result.dtype == dace.float32  # complex64 input -> float32 expected
+    assert tuple(real_result.shape) == (4, )
+
+    full_like_result = infer_full_like({'A': matrix}, 'A', 1.0)
+    assert isinstance(full_like_result, dace.data.Array)
+    assert full_like_result.dtype == dace.float32
+    assert tuple(full_like_result.shape) == (2, 3)
+
+    identity_result = infer_identity({}, 5, dtype=np.float64)
+    assert isinstance(identity_result, dace.data.Array)
+    assert identity_result.dtype == dace.float64
+    assert tuple(identity_result.shape) == (5, 5)
+
+    transpose_result = infer_transpose({'A': matrix}, 'A')
+    assert isinstance(transpose_result, dace.data.Array)
+    assert tuple(transpose_result.shape) == (3, 2)
+
+    where_result = infer_where({'cond': cond, 'A': matrix}, 'cond', 'A', 1.0)
+    assert isinstance(where_result, dace.data.Array)
+    assert where_result.dtype == dace.float32
+    assert tuple(where_result.shape) == (2, 3)  # cond is (2, 1), A is (2, 3)
+
+    assert infer_where({'cond': dace.data.Scalar(dace.bool_, transient=True)}, 'cond', 1, 2) is None  # scalar cond with constant branches: None expected
+
+    select_result = infer_select({'cond': cond, 'A': matrix}, ['cond'], ['A'], default=1.0)
+    assert isinstance(select_result, dace.data.Array)
+    assert select_result.dtype == dace.float32
+    assert tuple(select_result.shape) == (2, 3)
+
+    pyobject_self = dace.data.Scalar(dtypes.pyobject(), transient=True)  # method inferences below receive a descriptor as first argument
+    intracomm_create_result = infer_intracomm_create_cart(pyobject_self, [2, 2])
+    assert isinstance(intracomm_create_result, dace.data.Scalar)
+    assert isinstance(intracomm_create_result.dtype, dtypes.pyobject)
+
+    assert infer_intracomm_allreduce(pyobject_self, None, 'A', 'MPI_SUM') == ()
+
+    processgrid_sub_result = infer_processgrid_sub(pyobject_self, [True, False])
+    assert isinstance(processgrid_sub_result, dace.data.Scalar)
+    assert isinstance(processgrid_sub_result.dtype, dtypes.pyobject)
+
+    processgrid_isend_result = infer_processgrid_isend(pyobject_self, 'A', 0, 0)
+    assert isinstance(processgrid_isend_result, dace.data.Array)
+    assert tuple(processgrid_isend_result.shape) == (1, )
+    assert isinstance(processgrid_isend_result.dtype, dtypes.opaque)
+
+    slice_result = infer_slice({}, 0, 5, 2)
+    assert isinstance(slice_result, tuple)  # a one-element tuple is expected here, not a bare descriptor
+    assert len(slice_result) == 1
+    assert isinstance(slice_result[0], dace.data.Scalar)
+    assert isinstance(slice_result[0].dtype, dtypes.pyobject)
+
+    define_stream_result = infer_define_stream({}, dace.float32, buffer_size=4)
+    assert isinstance(define_stream_result, dace.data.Stream)
+    assert define_stream_result.dtype == dace.float32
+    assert tuple(define_stream_result.shape) == (1, )
+    assert define_stream_result.buffer_size == 4
+
+    define_streamarray_result = infer_define_streamarray({}, [2, 3], dace.float64, buffer_size=8)
+    assert isinstance(define_streamarray_result, dace.data.Stream)
+    assert define_streamarray_result.dtype == dace.float64
+    assert tuple(define_streamarray_result.shape) == (2, 3)
+    assert define_streamarray_result.buffer_size == 8
+
+    elementwise_result = infer_elementwise({'A': matrix}, 'lambda x: x + 1', 'A')
+    assert isinstance(elementwise_result, dace.data.Array)
+    assert elementwise_result.dtype == dace.float32
+    assert tuple(elementwise_result.shape) == (2, 3)
+
+    reduce_result = infer_reduce({'A': matrix}, 'lambda x, y: x + y', 'A', axis=1)
+    assert isinstance(reduce_result, dace.data.Array)
+    assert reduce_result.dtype == dace.float32
+    assert tuple(reduce_result.shape) == (2, )  # (2, 3) input with axis=1
+
+    assert infer_reduce({
+        'A': matrix,
+        'B': dace.data.Array(dace.float32, [3], transient=True)
+    },
+                        'lambda x, y: x + y',
+                        'A',
+                        out_array='B',
+                        axis=0) == ()  # explicit out_array: () expected instead of a new array
+
+    cupy_full_result = infer_cupy_full({}, [4, 2], 3.0)
+    assert isinstance(cupy_full_result, dace.data.Array)
+    assert cupy_full_result.dtype == dace.float64
+    assert tuple(cupy_full_result.shape) == (4, 2)
+    assert cupy_full_result.storage == dtypes.StorageType.GPU_Global
+
+    cupy_empty_like_result = infer_cupy_empty_like({'A': matrix}, 'A', dtype=dace.float16, shape=[3, 4])  # dtype/shape keywords differ from A's
+    assert isinstance(cupy_empty_like_result, dace.data.Array)
+    assert cupy_empty_like_result.dtype == dace.float16
+    assert tuple(cupy_empty_like_result.shape) == (3, 4)
+    assert cupy_empty_like_result.storage == dtypes.StorageType.GPU_Global
+
+    sum_result = infer_sum({'A': matrix}, 'A')
+    assert isinstance(sum_result, dace.data.Array)
+    assert tuple(sum_result.shape) == (3, )  # (2, 3) input -> (3, ) expected
+
+    max_result = infer_max({}, 1, 2.0, np.float32(3.0))  # constant arguments only, no descriptors
+    assert isinstance(max_result, dace.data.Scalar)
+    assert max_result.dtype == dace.float32
+
+    min_result = infer_min({'x': dace.data.Scalar(dace.int32, transient=True)}, 'x', 5)
+    assert isinstance(min_result, dace.data.Scalar)
+    assert min_result.dtype == dace.int32
+
+    assert infer_fill(matrix, 7) == ()
+
+    view_result = infer_view(dace.data.Array(dace.float32, [4], transient=True), np.float16)
+    assert isinstance(view_result, dace.data.View)
+    assert view_result.dtype == dace.float16
+    assert tuple(view_result.shape) == (8, )  # float32[4] viewed as float16 -> 8 elements expected
+
+
+def test_runtime_registry_entries_have_inference_coverage_or_allowlisted_gap():  # Each key of the five registries must pass its coverage helper.
+    _import_replacement_modules()
+
+    missing_functions = [name for name in oprepo.Replacements._rep if not _function_has_inference_coverage(name)]  # helpers are defined outside this chunk; presumably they also honour an allowlist -- confirm
+    missing_methods = [key for key in oprepo.Replacements._method_rep if not _method_has_inference_coverage(key)]
+    missing_attributes = [key for key in oprepo.Replacements._attr_rep if not _attribute_has_inference_coverage(key)]
+    missing_ufuncs = [name for name in oprepo.Replacements._ufunc_rep if not _ufunc_has_inference_coverage(name)]
+    missing_operators = [key for key in oprepo.Replacements._oprep if not _operator_has_inference_coverage(key)]
+
+    assert missing_functions == [], f'uncovered function inference registrations: {missing_functions}'  # failure messages list the offending keys
+    assert missing_methods == [], f'uncovered method inference registrations: {missing_methods}'
+    assert missing_attributes == [], f'uncovered attribute inference registrations: {missing_attributes}'
+    assert missing_ufuncs == [], f'uncovered ufunc inference registrations: {missing_ufuncs}'
+    assert missing_operators == [], f'uncovered operator inference registrations: {missing_operators}'
+
+
+if __name__ == '__main__':
+    pytest.main([__file__])
diff --git a/tests/python_frontend/schedule_tree/structure_support_test.py b/tests/python_frontend/schedule_tree/structure_support_test.py
new file mode 100644
index 0000000000..8006a4aae7
--- /dev/null
+++ b/tests/python_frontend/schedule_tree/structure_support_test.py
@@ -0,0 +1,225 @@
+# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved.
+
+import ast
+from dataclasses import dataclass
+
+import dace
+from dace import data
+from dace.data.pydata import PythonClass, python_dataclass_descriptor
+from dace.frontend.python.schedule_tree.structure_support import bind_target_structure, descriptor_from_structure, \
+    resolve_member_access
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
+
+
+def test_descriptor_from_structure_preserves_python_container_kind():  # tuple input -> PythonTuple, list input -> PythonList
+    tuple_descriptor = descriptor_from_structure(
+        (data.Scalar(dace.float64, transient=True), data.Scalar(dace.float64, transient=True)))
+    list_descriptor = descriptor_from_structure([data.Scalar(dace.float64, transient=True)])
+
+    assert isinstance(tuple_descriptor, dace.data.PythonTuple)
+    assert tuple_descriptor.dtype == dace.float64
+    assert tuple(tuple_descriptor.shape) == (2, )  # two elements in the input tuple
+    assert isinstance(list_descriptor, dace.data.PythonList)
+    assert list_descriptor.dtype == dace.float64
+    assert tuple(list_descriptor.shape) == (1, )  # one element in the input list
+
+
+def test_bind_target_structure_visits_starred_targets():  # `*tail` is expected to be bound to the middle entries as a list
+    target = ast.parse('head, *tail, last = value').body[0].targets[0]  # the tuple target of the parsed assignment
+    seen = {}
+
+    def _bind(name, structure):  # callback handed to bind_target_structure; records each binding
+        seen[name] = structure
+
+    matched = bind_target_structure(target, ('A', 'B', 'C', 'D'), _bind)
+
+    assert matched is True
+    assert seen == {'head': 'A', 'tail': ['B', 'C'], 'last': 'D'}
+
+
+def test_python_dataclass_descriptor_preserves_structure_vs_python_class_split():  # by_value=True -> Structure, by_value=False -> PythonClass
+
+    @dataclass
+    class Inner:
+        x: dace.int32
+
+    @dataclass
+    class Outer:
+        inner: Inner
+        y: dace.float64
+
+    by_value = python_dataclass_descriptor(Outer, by_value=True)
+    python_object = python_dataclass_descriptor(Outer, by_value=False)
+
+    assert isinstance(by_value, data.Structure)
+    assert by_value.name == 'Outer'
+    assert isinstance(by_value.members['inner'], data.Structure)
+
+    assert isinstance(python_object, PythonClass)
+    assert python_object.name == 'Outer'
+    assert isinstance(python_object.members['inner'], data.Structure)  # the nested member is asserted only against data.Structure
+    assert isinstance(python_object.members['y'], data.Scalar)
+
+
+def test_plain_class_descriptor_preserves_structure_vs_python_class_split():  # same checks for undecorated annotated classes via from_class
+
+    class Inner:
+        x: dace.int32
+
+    class Outer:
+        inner: Inner
+        y: dace.float64
+
+    by_value = data.Structure.from_class(Outer)
+    python_object = PythonClass.from_class(Outer)
+
+    assert isinstance(by_value, data.Structure)
+    assert by_value.name == 'Outer'
+    assert isinstance(by_value.members['inner'], data.Structure)
+
+    assert isinstance(python_object, PythonClass)
+    assert python_object.name == 'Outer'
+    assert isinstance(python_object.members['inner'], data.Structure)
+    assert isinstance(python_object.members['y'], data.Scalar)
+
+
+def test_resolve_member_access_returns_named_member_path():  # expects data_name 'bundle.data' and the member's Array descriptor
+    Bundle = dace.data.Structure({'data': dace.float64[4]}, name='Bundle')
+
+    access = resolve_member_access('bundle', Bundle, 'data')
+
+    assert access is not None
+    assert access.data_name == 'bundle.data'
+    assert isinstance(access.descriptor, data.Array)
+
+
+def test_schedule_tree_supports_structure_member_copy():  # full-slice member read should become one CopyNode from 'bundle.data'
+    Bundle = dace.data.Structure({'data': dace.float64[4]}, name='Bundle')
+
+    @dace.program
+    def copy_member(bundle: Bundle, out: dace.float64[4]):
+        out[:] = bundle.data[:]
+
+    stree = copy_member.to_schedule_tree()
+
+    assert isinstance(stree.children[0], tn.CopyNode)
+    assert stree.children[0].target == 'out'
+    assert stree.children[0].memlet.data == 'bundle.data'
+
+
+def test_schedule_tree_supports_structure_member_index_read():  # single-element member read: CopyNode with subset '1'
+    Bundle = dace.data.Structure({'data': dace.float64[4]}, name='Bundle')
+
+    @dace.program
+    def copy_member_index(bundle: Bundle, out: dace.float64[1]):
+        out[0] = bundle.data[1]
+
+    stree = copy_member_index.to_schedule_tree()
+
+    assert isinstance(stree.children[0], tn.CopyNode)
+    assert stree.children[0].target == 'out'
+    assert stree.children[0].memlet.data == 'bundle.data'
+    assert str(stree.children[0].memlet.subset) == '1'
+
+
+def test_schedule_tree_supports_nested_structure_member_copy():  # two-level member path 'bundle.inner.data' as copy source
+    Outer = dace.data.Structure({'inner': dace.data.Structure({'data': dace.float64[4]}, name='Inner')}, name='Outer')
+
+    @dace.program
+    def copy_member(bundle: Outer, out: dace.float64[4]):
+        out[:] = bundle.inner.data[:]
+
+    stree = copy_member.to_schedule_tree()
+
+    assert isinstance(stree.children[0], tn.CopyNode)
+    assert stree.children[0].target == 'out'
+    assert stree.children[0].memlet.data == 'bundle.inner.data'
+
+
+def test_schedule_tree_supports_nested_structure_member_index_read():  # two-level member path with a single-element subset
+    Outer = dace.data.Structure({'inner': dace.data.Structure({'data': dace.float64[4]}, name='Inner')}, name='Outer')
+
+    @dace.program
+    def copy_member_index(bundle: Outer, out: dace.float64[1]):
+        out[0] = bundle.inner.data[1]
+
+    stree = copy_member_index.to_schedule_tree()
+
+    assert isinstance(stree.children[0], tn.CopyNode)
+    assert stree.children[0].target == 'out'
+    assert stree.children[0].memlet.data == 'bundle.inner.data'
+    assert str(stree.children[0].memlet.subset) == '1'
+
+
+def test_schedule_tree_supports_structure_member_to_member_copy():  # member paths on both sides: target 'dst.data', source 'src.data'
+    Bundle = dace.data.Structure({'data': dace.float64[4]}, name='Bundle')
+
+    @dace.program
+    def copy_member(dst: Bundle, src: Bundle):
+        dst.data[:] = src.data[:]
+
+    stree = copy_member.to_schedule_tree()
+
+    assert isinstance(stree.children[0], tn.CopyNode)
+    assert stree.children[0].target == 'dst.data'
+    assert stree.children[0].memlet.data == 'src.data'
+
+
+def test_schedule_tree_supports_structure_member_array_map_bounds():  # inner map bounds are read from a structure member array
+    CSR = dace.data.Structure({
+        'indptr': dace.int32[5],
+        'indices': dace.int32[8],
+        'data': dace.float64[8],
+    },
+                              name='CSR')
+
+    @dace.program
+    def spmv_shape(A: CSR, out: dace.float64[8]):
+        for row in dace.map[0:4]:
+            for idx in dace.map[A.indptr[row]:A.indptr[row + 1]]:
+                out[idx] = A.data[idx]
+
+    stree = spmv_shape.to_schedule_tree()
+
+    assert isinstance(stree.children[0], tn.MapScope)
+    assert stree.children[0].node.params == ['row']
+    assert stree.children[0].node.ranges == [('0', '4', '1')]
+    assert isinstance(stree.children[0].children[0], tn.DynScopeCopyNode)  # first bound, A.indptr[row], copied to '__stree_idx'
+    assert stree.children[0].children[0].target == '__stree_idx'
+    assert stree.children[0].children[0].memlet.data == 'A.indptr'
+    assert str(stree.children[0].children[0].memlet.subset) == 'row'
+    assert isinstance(stree.children[0].children[1], tn.DynScopeCopyNode)  # second bound, A.indptr[row + 1], copied to '__stree_idx1'
+    assert stree.children[0].children[1].target == '__stree_idx1'
+    assert stree.children[0].children[1].memlet.data == 'A.indptr'
+    assert str(stree.children[0].children[1].memlet.subset) == 'row + 1'
+    inner_map = stree.children[0].children[2]  # the inner map follows the two bound copies
+    assert isinstance(inner_map, tn.MapScope)
+    assert inner_map.node.params == ['idx']
+    assert inner_map.node.ranges == [('__stree_idx', '__stree_idx1', '1')]
+    assert isinstance(inner_map.children[0], tn.CopyNode)
+    assert inner_map.children[0].target == 'out'
+    assert inner_map.children[0].memlet.data == 'A.data'
+    assert str(inner_map.children[0].memlet.subset) == 'idx'
+
+
+def test_schedule_tree_supports_python_class_member_copy():  # nested member copy with a by_value=False dataclass descriptor as annotation
+
+    @dataclass
+    class Inner:
+        data: dace.float64[4]
+
+    @dataclass
+    class Outer:
+        inner: Inner
+
+    PyOuter = python_dataclass_descriptor(Outer, by_value=False)
+
+    @dace.program
+    def copy_member(bundle: PyOuter, out: dace.float64[4]):
+        out[:] = bundle.inner.data[:]
+
+    stree = copy_member.to_schedule_tree()
+
+    assert isinstance(stree.children[0], tn.CopyNode)
+    assert stree.children[0].target == 'out'
+    assert stree.children[0].memlet.data == 'bundle.inner.data'
diff --git a/tests/python_frontend/schedule_tree/torch_autodiff_inference_test.py b/tests/python_frontend/schedule_tree/torch_autodiff_inference_test.py
new file mode 100644
index 0000000000..9704409cab
--- /dev/null
+++ b/tests/python_frontend/schedule_tree/torch_autodiff_inference_test.py
@@ -0,0 +1,78 @@
+import ast
+
+import pytest
+
+torch = pytest.importorskip('torch', reason='PyTorch not installed. Please install with: pip install dace[ml]')  # skips the whole module when torch cannot be imported
+
+from dace.frontend.common import op_repository as oprepo
+from dace.frontend.python.schedule_tree.type_inference import ScheduleTreeTypeInference
+from dace.data.ml import ParameterArray
+
+pytestmark = [pytest.mark.torch, pytest.mark.autodiff]  # module-level markers applied to every test in this file
+
+
+def _function_def(source: str) -> ast.FunctionDef:  # parse `source` and return its first statement, which must be a FunctionDef
+    module = ast.parse(source)
+    node = module.body[0]
+    assert isinstance(node, ast.FunctionDef)
+    return node
+
+
+def test_torch_autodiff_registry_entries_have_descriptor_inference():  # every looked-up inference entry must be non-None
+    import dace.frontend.python.replacements.torch_autodiff  # noqa: F401
+
+    assert oprepo.Replacements.get_descriptor_inference('torch.autograd.backward') is not None
+    assert oprepo.Replacements.get_method_descriptor_inference('Array', 'requires_grad_') is not None
+    assert oprepo.Replacements.get_method_self_descriptor_inference('Array', 'requires_grad_') is not None
+    assert oprepo.Replacements.get_method_descriptor_inference('Array', 'backward') is not None
+    assert oprepo.Replacements.get_attribute_descriptor_inference('ParameterArray', 'grad') is not None
+
+
+def test_requires_grad_side_effect_enables_grad_attribute_inference():  # after `A.requires_grad_()`, `A.grad` should infer as a plain Array
+    import dace.frontend.python.replacements.torch_autodiff  # noqa: F401
+
+    program = _function_def('def prog(A):\n'
+                            ' A.requires_grad_()\n'
+                            ' grad = A.grad\n'
+                            ' return grad\n')
+
+    inferred = ScheduleTreeTypeInference({'torch': torch}, {'A': dace.data.Array(dace.float32, [4, 5])}).infer(program)  # `dace` is bound by the function-level import above; this file has no module-level `import dace`
+
+    grad_binding = inferred.get('grad')
+    assert grad_binding is not None
+    assert isinstance(grad_binding.descriptor, dace.data.Array)
+    assert not isinstance(grad_binding.descriptor, ParameterArray)
+    assert grad_binding.descriptor.dtype == dace.float32
+    assert tuple(grad_binding.descriptor.shape) == (4, 5)
+
+
+def test_torch_autodiff_descriptor_contracts():  # calls the inference callables directly and checks what each returns
+    import dace.frontend.python.replacements.torch_autodiff  # noqa: F401
+
+    infer_backward = oprepo.Replacements.get_descriptor_inference('torch.autograd.backward')
+    infer_requires_grad = oprepo.Replacements.get_method_descriptor_inference('Array', 'requires_grad_')
+    infer_requires_grad_self = oprepo.Replacements.get_method_self_descriptor_inference('Array', 'requires_grad_')
+    infer_backward_method = oprepo.Replacements.get_method_descriptor_inference('Array', 'backward')
+    infer_grad = oprepo.Replacements.get_attribute_descriptor_inference('ParameterArray', 'grad')
+
+    array_desc = dace.data.Array(dace.float64, [3, 2], transient=True)
+
+    assert infer_backward({'A': array_desc}, 'A') == ()
+    assert infer_requires_grad(array_desc) == ()
+
+    parameter_desc = infer_requires_grad_self(array_desc)  # the "self" inference is expected to turn the Array into a ParameterArray
+    assert isinstance(parameter_desc, ParameterArray)
+    assert parameter_desc.dtype == dace.float64
+    assert tuple(parameter_desc.shape) == (3, 2)
+
+    grad_desc = infer_grad(parameter_desc)
+    assert isinstance(grad_desc, dace.data.Array)
+    assert not isinstance(grad_desc, ParameterArray)
+    assert grad_desc.dtype == dace.float64
+    assert tuple(grad_desc.shape) == (3, 2)
+
+    assert infer_backward_method(array_desc) == ()
+
+
+if __name__ == '__main__':
+    pytest.main([__file__])
diff --git a/tests/python_frontend/structures/structure_python_test.py b/tests/python_frontend/structures/structure_python_test.py
index af317be7d8..d93089e30a 100644
--- a/tests/python_frontend/structures/structure_python_test.py
+++ b/tests/python_frontend/structures/structure_python_test.py
@@ -449,6 +449,42 @@ def struct_recursive(A: Struct, B: Struct):
     assert np.allclose(B.y, A.y)
 
 
+def test_struct_recursive_from_plain_class_annotation():  # plain annotated classes used directly as @dace.program annotations
+
+    class Inner:
+        a: dace.float32[20]
+        b: dace.int32
+
+        def __init__(self, a, b):
+            self.a = a
+            self.b = b
+
+    class Outer:
+        x: Inner
+        y: dace.float64[10, 10]
+
+        def __init__(self, x, y):
+            self.x = x
+            self.y = y
+
+    Struct = dace.data.Structure.from_class(Outer)
+
+    @dace.program
+    def struct_recursive(A: Outer, B: Outer):
+        B.x.a[:] = A.x.a[:]
+        B.x.b = A.x.b
+        B.y[:] = A.y[:]
+
+    A = Outer(x=Inner(a=np.random.rand(20).astype(np.float32), b=42), y=np.random.rand(10, 10).astype(np.float64))
+    B = Outer(x=Inner(a=np.zeros(20, dtype=np.float32), b=0), y=np.zeros((10, 10), dtype=np.float64))
+
+    struct_recursive(Struct.make_argument_from_object(A), Struct.make_argument_from_object(B))  # the program receives converted arguments, not A and B themselves
+
+    assert np.allclose(B.x.a, A.x.a)
+    assert not np.allclose(B.x.b, A.x.b)  # NOTE(review): scalar `b` on the Python object is asserted NOT to change, unlike the arrays -- presumably copied by value; confirm intended
+    assert np.allclose(B.y, A.y)
+
+
 if __name__ == '__main__':
     test_read_structure()
     test_write_structure()
@@ -462,3 +498,4 @@ def struct_recursive(A: Struct, B: Struct):
     test_struct_interface()
     test_struct_recursive()
     test_struct_recursive_from_dataclass()
+    test_struct_recursive_from_plain_class_annotation()