diff --git a/docs/docstrings/dsl/dace/stree/optimizations/axis_merge.md b/docs/docstrings/dsl/dace/stree/optimizations/axis_merge.md
index ab6794f7..53cf06df 100644
--- a/docs/docstrings/dsl/dace/stree/optimizations/axis_merge.md
+++ b/docs/docstrings/dsl/dace/stree/optimizations/axis_merge.md
@@ -1,6 +1,6 @@
# axis_merge
-::: dsl.dace.stree.optimizations
+::: dsl.dace.stree.optimizations.axis_merge
diff --git a/docs/docstrings/dsl/dace/stree/optimizations/kernelize_maps.md b/docs/docstrings/dsl/dace/stree/optimizations/kernelize_maps.md
new file mode 100644
index 00000000..974eb637
--- /dev/null
+++ b/docs/docstrings/dsl/dace/stree/optimizations/kernelize_maps.md
@@ -0,0 +1,12 @@
+# kernelize_maps
+
+::: dsl.dace.stree.optimizations.kernelize_maps
+
+
diff --git a/docs/docstrings/dsl/dace/stree/optimizations/local_optimizations.md b/docs/docstrings/dsl/dace/stree/optimizations/local_optimizations.md
new file mode 100644
index 00000000..540b9eba
--- /dev/null
+++ b/docs/docstrings/dsl/dace/stree/optimizations/local_optimizations.md
@@ -0,0 +1,12 @@
+# local_optimizations
+
+::: dsl.dace.stree.optimizations.local_optimizations
+
+
diff --git a/docs/docstrings/dsl/dace/stree/optimizations/offgrid_conditionals.md b/docs/docstrings/dsl/dace/stree/optimizations/offgrid_conditionals.md
new file mode 100644
index 00000000..e5ab9ecf
--- /dev/null
+++ b/docs/docstrings/dsl/dace/stree/optimizations/offgrid_conditionals.md
@@ -0,0 +1,12 @@
+# offgrid_conditionals
+
+::: dsl.dace.stree.optimizations.offgrid_conditionals
+
+
diff --git a/docs/docstrings/dsl/dace/stree/optimizations/refine_transients.md b/docs/docstrings/dsl/dace/stree/optimizations/refine_transients.md
index e207bb71..4d11671e 100644
--- a/docs/docstrings/dsl/dace/stree/optimizations/refine_transients.md
+++ b/docs/docstrings/dsl/dace/stree/optimizations/refine_transients.md
@@ -1,6 +1,6 @@
# refine_transients
-::: dsl.dace.stree.optimizations
+::: dsl.dace.stree.optimizations.refine_transients
diff --git a/docs/docstrings/dsl/dace/stree/optimizations/replace_axis_symbol.md b/docs/docstrings/dsl/dace/stree/optimizations/replace_axis_symbol.md
new file mode 100644
index 00000000..ccff4aa4
--- /dev/null
+++ b/docs/docstrings/dsl/dace/stree/optimizations/replace_axis_symbol.md
@@ -0,0 +1,12 @@
+# replace_axis_symbol
+
+::: dsl.dace.stree.optimizations.replace_axis_symbol
+
+
diff --git a/docs/docstrings/dsl/dace/stree/optimizations/statistics.md b/docs/docstrings/dsl/dace/stree/optimizations/statistics.md
new file mode 100644
index 00000000..c218d368
--- /dev/null
+++ b/docs/docstrings/dsl/dace/stree/optimizations/statistics.md
@@ -0,0 +1,12 @@
+# statistics
+
+::: dsl.dace.stree.optimizations.statistics
+
+
diff --git a/external/dace b/external/dace
index d5fbadb6..7c526886 160000
--- a/external/dace
+++ b/external/dace
@@ -1 +1 @@
-Subproject commit d5fbadb626389e425fac5ed93d2a880811eca41f
+Subproject commit 7c526886bfadeb9808a06a66fcbca1dbfa6b8ad4
diff --git a/external/gt4py b/external/gt4py
index eef3c0ee..331c7bba 160000
--- a/external/gt4py
+++ b/external/gt4py
@@ -1 +1 @@
-Subproject commit eef3c0ee9de9c4eb8f57650b64abf7863c05fc83
+Subproject commit 331c7bba9161b96cf94f6d5d9bda06161703db28
diff --git a/ndsl/__init__.py b/ndsl/__init__.py
index d468d58f..72ea237f 100644
--- a/ndsl/__init__.py
+++ b/ndsl/__init__.py
@@ -10,6 +10,7 @@
from .constants import ConstantVersions
from .dsl.caches.codepath import FV3CodePath
from .quantity import Quantity
+from .dsl.optimization_config import OptimizationConfig
from .dsl.ndsl_runtime import NDSLRuntime
from .dsl.stencil import FrozenStencil, GridIndexing, StencilFactory, TimingCollector
from .dsl.stencil_config import CompilationConfig, RunMode, StencilConfig
@@ -90,6 +91,7 @@
"MetaEnumStr",
"State",
"LocalState",
+ "OptimizationConfig",
"NDSLRuntime",
"Local",
"DiagManagerMonitor",
diff --git a/ndsl/config/backend.py b/ndsl/config/backend.py
index fd6d7e9e..adae0b6c 100644
--- a/ndsl/config/backend.py
+++ b/ndsl/config/backend.py
@@ -52,6 +52,8 @@ class BackendLoopOrder(Enum):
"orch:dace:cpu:KJI": "dace:cpu_KJI",
"st:dace:gpu:KJI": "dace:gpu",
"orch:dace:gpu:KJI": "dace:gpu",
+ "st:dace:gpu:IJK": "dace:gpu_IJK",
+ "orch:dace:gpu:IJK": "dace:gpu_IJK",
}
"""Internal: match the NDSL backend names with the GT4Py names"""
diff --git a/ndsl/dsl/caches/cache_location.py b/ndsl/dsl/caches/cache_location.py
index 87d608dd..d4313815 100644
--- a/ndsl/dsl/caches/cache_location.py
+++ b/ndsl/dsl/caches/cache_location.py
@@ -7,46 +7,48 @@ def identify_code_path(
partitioner: Partitioner,
single_code_path: bool,
) -> FV3CodePath:
- """Determine which code path your rank will hit.
+ """
+ Determine which code path your rank will hit.
- If single_code_path is True, single_code_path is True,
- only one code path exists (case of doubly periodic grid).
+ If single_code_path is True, only one code path exists,
+ e.g. in case of a doubly periodic grid.
If single_code_path is False, we are in the case of the
- cube-sphere and we will look at our position on the tile."""
+ cube-sphere and we will look at our position on the tile.
+ """
# Doubly-periodic or single tile grid
- if single_code_path:
+ if single_code_path or partitioner.layout == (1, 1):
return FV3CodePath.All
# Cube-sphere
- if partitioner.layout == (1, 1):
- return FV3CodePath.All
- elif partitioner.layout[0] == 1 or partitioner.layout[1] == 1:
+ if partitioner.layout[0] <= 1 or partitioner.layout[1] <= 1:
raise NotImplementedError(
- f"Build for layout {partitioner.layout} is not handled"
+ f"Build for layout {partitioner.layout} is not handled."
)
- else:
- if partitioner.tile.on_tile_bottom(rank):
- if partitioner.tile.on_tile_left(rank):
- return FV3CodePath.BottomLeft
- if partitioner.tile.on_tile_right(rank):
- return FV3CodePath.BottomRight
- else:
- return FV3CodePath.Bottom
- if partitioner.tile.on_tile_top(rank):
- if partitioner.tile.on_tile_left(rank):
- return FV3CodePath.TopLeft
- if partitioner.tile.on_tile_right(rank):
- return FV3CodePath.TopRight
- else:
- return FV3CodePath.Top
- else:
- if partitioner.tile.on_tile_left(rank):
- return FV3CodePath.Left
- if partitioner.tile.on_tile_right(rank):
- return FV3CodePath.Right
- else:
- return FV3CodePath.Center
+
+ # Bottom row
+ if partitioner.tile.on_tile_bottom(rank):
+ if partitioner.tile.on_tile_left(rank):
+ return FV3CodePath.BottomLeft
+ if partitioner.tile.on_tile_right(rank):
+ return FV3CodePath.BottomRight
+ return FV3CodePath.Bottom
+
+ # Top row
+ if partitioner.tile.on_tile_top(rank):
+ if partitioner.tile.on_tile_left(rank):
+ return FV3CodePath.TopLeft
+ if partitioner.tile.on_tile_right(rank):
+ return FV3CodePath.TopRight
+ return FV3CodePath.Top
+
+ # Left & right column with corners already handled
+ if partitioner.tile.on_tile_left(rank):
+ return FV3CodePath.Left
+ if partitioner.tile.on_tile_right(rank):
+ return FV3CodePath.Right
+
+ return FV3CodePath.Center
def get_cache_fullpath(code_path: FV3CodePath) -> str:
diff --git a/ndsl/dsl/dace/dace_config.py b/ndsl/dsl/dace/dace_config.py
index 1aedb028..67250738 100644
--- a/ndsl/dsl/dace/dace_config.py
+++ b/ndsl/dsl/dace/dace_config.py
@@ -10,14 +10,20 @@
from gt4py.cartesian.utils.compiler import cxx_compiler_defaults, gpu_configuration
from ndsl import LocalComm
+from ndsl.comm import Comm
from ndsl.comm.communicator import Communicator
from ndsl.comm.partitioner import Partitioner
from ndsl.config import Backend
from ndsl.dsl import NDSL_COMPILER_SILENCE, NDSL_GLOBAL_PRECISION
from ndsl.dsl.caches.cache_location import identify_code_path
from ndsl.dsl.caches.codepath import FV3CodePath
+from ndsl.dsl.dace.hardware_config import get_gpu_hardware_defaults
from ndsl.optional_imports import cupy as cp
-from ndsl.performance.collector import NullPerformanceCollector, PerformanceCollector
+from ndsl.performance.collector import (
+ AbstractPerformanceCollector,
+ NullPerformanceCollector,
+ PerformanceCollector,
+)
if TYPE_CHECKING:
@@ -166,8 +172,8 @@ def __init__(
Args:
communicator: used for setting the distributed caches
backend: string for the backend
- tile_nx: x/y domain size for a single time
- tile_nz: z domain size for a single time
+ tile_nx: x/y domain size for a single tile
+ tile_nz: z domain size for a single tile
orchestration: orchestration mode from DaCeOrchestration
time: trigger performance collection, available to user with
`performance_collector`
@@ -181,16 +187,12 @@ def __init__(
# ToDo: DaceConfig becomes a bit more than a read-only config
# with this. Should be refactored into a DaceExecutor carrying a config
self.loaded_dace_executables: DaceExecutables = {}
- self.performance_collector = (
- PerformanceCollector(
- "InternalOrchestrationTimer",
- comm=(
- LocalComm(0, 6, {}) if communicator is None else communicator.comm
- ),
+ if not time:
+ self.performance_collector: AbstractPerformanceCollector = (
+ NullPerformanceCollector()
)
- if time
- else NullPerformanceCollector()
- )
+ else:
+ self.set_timer(communicator.comm if communicator else None)
# Temporary. This is a bit too out of the ordinary for the common user.
# We should refactor the architecture to allow for a `gtc:orchestrated:dace:X`
@@ -265,21 +267,29 @@ def __init__(
march_option = "-mcpu=native" if is_arm_neoverse else "-march=native"
# Removed --fast-math
gpu_config = gpu_configuration(GT4PY_COMPILE_OPT_LEVEL)
+ gpu_cflags = " ".join(gpu_config.gpu_compile_flags).strip()
dace.config.Config.set(
"compiler",
"cuda",
"args",
- value=f"-std=c++14 {warnings_policy} -Xcompiler -fPIC -O{optimization_level} -Xcompiler {march_option} {gpu_config.gpu_compile_flags}",
+ value=f"-std=c++14 {warnings_policy} -Xcompiler -fPIC -O{optimization_level} -Xcompiler {march_option} {gpu_cflags}",
)
- cuda_sm = cp.cuda.Device(0).compute_capability if cp else 60
- dace.config.Config.set("compiler", "cuda", "cuda_arch", value=f"{cuda_sm}")
- # Block size/thread count is defaulted to an average value for recent
- # hardware (Pascal and upward). The problem of setting an optimized
- # block/thread is both hardware and problem dependant. Fine tuners
- # available in DaCe should be relied on for further tuning of this value.
+ # Target compilation for hardware micro-code capacities
+ gpu_defaults = get_gpu_hardware_defaults()
dace.config.Config.set(
- "compiler", "cuda", "default_block_size", value="64,8,1"
+ "compiler",
+ "cuda",
+ "cuda_arch",
+ value=f"{gpu_defaults.compute_capability}",
+ )
+
+ # Default block size for kernels launch
+ dace.config.Config.set(
+ "compiler",
+ "cuda",
+ "default_block_size",
+ value=str(gpu_defaults.block_size)[1:-1],
)
# Potentially buggy - deactivate
dace.config.Config.set(
@@ -346,6 +356,9 @@ def __init__(
value="c",
)
+ # Debug lineinfo is incorrect anyway for the stencils
+ dace.config.Config.set("compiler", "lineinfo", value="none")
+
# Attempt to kill the dace.conf to avoid confusion
dace_conf_to_kill = dace.config.Config.cfg_filename()
if dace_conf_to_kill is not None:
@@ -413,4 +426,20 @@ def from_dict(cls, data: dict) -> Self:
config.rank_size = data["rank_size"]
config.layout = data["layout"]
config.tile_resolution = data["tile_resolution"]
- return config
+ # TODO
+ # Computed properties like `self.code_path` and `self.do_compile`
+ # aren't updated.
+ # We also don't `set_distributed_caches()` based on that updated
+ # information.
+ raise NotImplementedError(
+ "Implementation of `DaceConfig.from_dict()` is incomplete."
+ )
+
+    def set_timer(self, comm: Comm | None) -> None:
+        """Install a `PerformanceCollector`, using a `LocalComm` when `comm` is None."""
+        # TODO: this absolutely should not be on a Configuration object
+        #       and even less setup outside. Madness, we have lost our ways...
+        timer_comm = comm if comm is not None else LocalComm(0, 6, {})
+        self.performance_collector = PerformanceCollector(
+            "InternalOrchestrationTimer", comm=timer_comm
+        )
diff --git a/ndsl/dsl/dace/hardware_config.py b/ndsl/dsl/dace/hardware_config.py
new file mode 100644
index 00000000..bbd367dc
--- /dev/null
+++ b/ndsl/dsl/dace/hardware_config.py
@@ -0,0 +1,126 @@
+import dataclasses
+import sys
+from pathlib import Path
+from typing import Literal
+
+from ndsl import ndsl_log
+from ndsl.optional_imports import cupy as cp
+
+
+GPUVendor = Literal["Nvidia"] | Literal["AMD"] | Literal["Intel"] | Literal["Unknown"]
+
+# PCI vendor ids mapped to vendor names, taken from https://pcisig.com/membership/member-companies
+_VENDOR_PCI_SIGNATURES: dict[int, GPUVendor] = {
+    0x10DE: "Nvidia",
+    0x1002: "AMD",
+    0x8086: "Intel",
+    0x0: "Unknown",
+}
+
+# Cached result of `get_gpu_hardware_defaults()`, stays `None` until the first call
+_GPU_HARDWARE_DEFAULTS = None
+
+
+def _get_vendor() -> GPUVendor:
+    """Retrieve vendor using the current device PCI id to query the PCI vendor
+    from sysfs (`/sys/bus/pci/devices/<bus id>/vendor`).
+
+    ⚠️ Only works on Linux - kicks back to "Unknown" in other cases.
+    """
+    if not sys.platform.startswith("linux"):
+        ndsl_log.info("GPU hardware detection only possible on Linux system.")
+        return "Unknown"
+
+    # CUDA prints the bus id with upper-case hex digits, sysfs uses lower-case
+    pci_device_id = cp.cuda.runtime.deviceGetPCIBusId(0).lower()
+    dev_path = Path("/sys", "bus", "pci", "devices", f"{pci_device_id}")
+    if not dev_path.exists():
+        ndsl_log.info(f"GPU detection: PCI device not found at {dev_path}.")
+        return "Unknown"
+
+    with open(dev_path / "vendor", "r") as f:
+        vendor_id = int(f.read().strip(), 16)  # base 16 accepts the "0x" prefix
+
+    if vendor_id not in _VENDOR_PCI_SIGNATURES:
+        ndsl_log.error(f"Unknown GPU vendor with PCI-SIG ID of {vendor_id:#X}.")
+        return "Unknown"
+
+    return _VENDOR_PCI_SIGNATURES[vendor_id]
+
+
+@dataclasses.dataclass
+class GPUHardwareDefaults:
+    """Default GPU compute settings picked for the detected hardware."""
+
+    vendor: GPUVendor  # detected vendor name, "Unknown" when detection fails
+    block_size: list[int] = dataclasses.field(default_factory=list)  # kernel block size
+    compute_capability: int = -1  # Nvidia specific, stays -1 for other vendors
+
+
+def get_gpu_hardware_defaults() -> GPUHardwareDefaults:
+    """Retrieve default values for GPU computation configuration.
+
+    The hardware is probed on the first call only; the result is cached at
+    module level and returned as-is by every later call. Without a usable
+    cupy/CUDA device, vendor-agnostic defaults are returned.
+    """
+    global _GPU_HARDWARE_DEFAULTS
+    if _GPU_HARDWARE_DEFAULTS is not None:
+        return _GPU_HARDWARE_DEFAULTS  # type: ignore[unreachable]
+
+    if cp is None or not cp.cuda.is_available():
+        ndsl_log.warning("No cupy - defaulting for GPU hardware")
+        # Smallest common denominator of massively parallel hardware
+        _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults(
+            vendor="Unknown",
+            block_size=[8, 1, 1],
+        )
+        return _GPU_HARDWARE_DEFAULTS
+
+    # Who goes there
+    vendor = _get_vendor()
+    if vendor == "Nvidia":
+        compute_capability = int(cp.cuda.Device(0).compute_capability)
+        # Default block size based on compute capability. Lower bounds are
+        # inclusive: 80 (e.g. A100) is Ampere and 60 (e.g. P100) is Pascal.
+        if compute_capability >= 80:
+            # Covers:
+            # - Blackwell (100+)
+            # - Hopper (90-100)
+            # - Ampere (80-90)
+            block_sizes = [128, 1, 1]
+        elif compute_capability >= 60:
+            # Covers:
+            # - Volta (70-80)
+            # - Pascal (60-70)
+            block_sizes = [64, 8, 1]
+        else:
+            # For older hardware - we default to the safe warp-size since
+            # the dawn of GPGPU on Nvidia hardware
+            block_sizes = [32, 1, 1]
+
+        _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults(
+            vendor=vendor,
+            block_size=block_sizes,
+            compute_capability=compute_capability,
+        )
+    elif vendor == "AMD":
+        _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults(
+            vendor=vendor,
+            block_size=[64, 1, 1],  # Wave64 is the native width of GCN/CDNA
+        )
+    elif vendor == "Intel":
+        _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults(
+            vendor=vendor,
+            block_size=[32, 1, 1],  # Intel can run 8, 16 or 32 - but SIMD betters in 32
+        )
+    else:
+        # Smallest common denominator of massively parallel hardware
+        _GPU_HARDWARE_DEFAULTS = GPUHardwareDefaults(
+            vendor=vendor,
+            block_size=[8, 1, 1],
+        )
+
+    ndsl_log.info(f"GPU vendor detected: {_GPU_HARDWARE_DEFAULTS.vendor}")
+
+    return _GPU_HARDWARE_DEFAULTS
diff --git a/ndsl/dsl/dace/labeler.py b/ndsl/dsl/dace/labeler.py
index 08398ca0..b2bb102f 100644
--- a/ndsl/dsl/dace/labeler.py
+++ b/ndsl/dsl/dace/labeler.py
@@ -1,11 +1,11 @@
-from __future__ import annotations
-
from typing import Any
-import dace.properties
+import dace
from dace import library, nodes
from dace.transformation import transformation as xf
+from ndsl import OptimizationConfig
+
@library.node
class _Labeler(nodes.LibraryNode):
@@ -13,9 +13,28 @@ class _Labeler(nodes.LibraryNode):
default_implementation = "pure"
unique_name = dace.properties.Property(dtype=str, desc="Unique name")
- def __init__(self, unique_name: str, **kwargs: dict[str, Any]) -> None:
+ def __init__(
+ self,
+ unique_name: str,
+ local_optimization: OptimizationConfig | None,
+ **kwargs: dict[str, Any],
+ ) -> None:
super().__init__(name="NDSLRuntime_Label", **kwargs)
+ # HACK to avoid state fusion of labeler states
+        # MPI WaitAll blocks state fusion, so we just pretend to be one 🐉.
+ # Keeping the labeler states non-fused is important to keep code flow consistent until we
+ # get to the schedule tree.
+ self.label = "_Waitall_"
+
self._unique_name = unique_name
+ self._local_optimizations = local_optimization
+
+    def has_side_effects(self) -> bool:
+        """Always report side effects so that simplify leaves this node alone."""
+        # HACK: LibraryNodes with side effects aren't touched by simplify. This
+        # keeps the library nodes alive until we get to the schedule tree
+        # where we can use the information.
+        return True
@library.register_expansion(_Labeler, "pure")
@@ -32,7 +51,10 @@ def expansion(
def set_label(
- sdfg: dace.SDFG | dace.CompiledSDFG, qualname: str, is_top_sdfg: bool
+ sdfg: dace.SDFG | dace.CompiledSDFG,
+ qualname: str,
+ is_top_sdfg: bool,
+ local_optimizations: OptimizationConfig | None,
) -> None:
"""Surround the SDFG with two state/library node combo labelling
the code for future reference in further optimization.
@@ -50,19 +72,29 @@ def set_label(
# With the topmost SDFG we have to skip over the
# "init" state
if is_top_sdfg:
- state = sdfg.add_state_after(
+ label_state = sdfg.add_state_after(
state,
label=f"__Label_Enter__{qualname}",
)
else:
- state = sdfg.add_state_before(
+ label_state = sdfg.add_state_before(
state,
label=f"__Label_Enter__{qualname}",
)
- state.add_node(_Labeler(unique_name=f"Enter__{qualname}"))
+ label_state.add_node(
+ _Labeler(
+ unique_name=f"Enter__{qualname}",
+ local_optimization=local_optimizations,
+ )
+ )
if sdfg.out_edges(state) == []:
- state = sdfg.add_state_after(
+ label_state = sdfg.add_state_after(
state,
label=f"__Label_Exit__{qualname}",
)
- state.add_node(_Labeler(unique_name=f"Exit__{qualname}"))
+ label_state.add_node(
+ _Labeler(
+ unique_name=f"Exit__{qualname}",
+ local_optimization=local_optimizations,
+ )
+ )
diff --git a/ndsl/dsl/dace/orchestration.py b/ndsl/dsl/dace/orchestration.py
index 4da02544..0179a734 100644
--- a/ndsl/dsl/dace/orchestration.py
+++ b/ndsl/dsl/dace/orchestration.py
@@ -4,28 +4,30 @@
import os
from collections.abc import Callable, Sequence
from pathlib import Path
+from pprint import pformat
from typing import Any
-from dace import SDFG, CompiledSDFG
+from dace import SDFG, CompiledSDFG, DeviceType
from dace import compiletime as DaceCompiletime
from dace import dtypes
from dace import method as dace_method
from dace import nodes
from dace import program as dace_program
from dace.dtypes import DeviceType as DaceDeviceType
+from dace.dtypes import ScheduleType
from dace.dtypes import StorageType as DaceStorageType
from dace.frontend.python.common import SDFGConvertible
from dace.frontend.python.parser import DaceProgram
from dace.sdfg.analysis.schedule_tree import treenodes as tn
from dace.transformation.auto.auto_optimize import make_transients_persistent
-from dace.transformation.dataflow import MapExpansion
+from dace.transformation.dataflow import MapCollapse, MapExpansion
+from dace.transformation.dataflow.add_threadblock_map import AddThreadBlockMap
from dace.transformation.helpers import get_parent_map
from gt4py import storage as gt_storage
import ndsl.dsl.dace.replacements # noqa # We load in the DaCe replacements
-from ndsl import ndsl_log
+from ndsl import Backend, OptimizationConfig, ndsl_log
from ndsl.comm.mpi import MPI
-from ndsl.config import BackendLoopOrder
from ndsl.dsl.dace.build import get_sdfg_path, write_build_info
from ndsl.dsl.dace.dace_config import (
DEACTIVATE_DISTRIBUTED_DACE_COMPILE,
@@ -33,19 +35,15 @@
DaCeOrchestration,
)
from ndsl.dsl.dace.dace_executable import DaceExecutable
+from ndsl.dsl.dace.hardware_config import get_gpu_hardware_defaults
from ndsl.dsl.dace.labeler import set_label
from ndsl.dsl.dace.sdfg_debug_passes import (
negative_delp_checker,
negative_qtracers_checker,
sdfg_nan_checker,
)
-from ndsl.dsl.dace.stree import CPUPipeline
-from ndsl.dsl.dace.stree.optimizations import (
- AxisIterator,
- CartesianAxisMerge,
- CartesianRefineTransients,
- CleanUpScheduleTree,
-)
+from ndsl.dsl.dace.stree import CPUPipeline, GPUPipeline
+from ndsl.dsl.dace.stree.pipeline import StreePipeline
from ndsl.dsl.dace.utils import (
DaCeProgress,
memory_static_analysis,
@@ -55,10 +53,7 @@
from ndsl.quantity import Quantity, State
-_INTERNAL__SCHEDULE_TREE_OPTIMIZATION: bool = (
- os.environ.get("NDSL_STREE_OPT", "False") == "True"
-)
-"""INTERNAL: Developer flag to turn the untested schedule tree roundtrip optimizer."""
+_INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES: list[tn.ScheduleNodeVisitor] | None = None
def dace_inhibitor(func: Callable) -> Callable:
@@ -149,8 +144,36 @@ def _tree_as_sdfg(stree: tn.ScheduleTreeRoot) -> SDFG:
return stree.as_sdfg(skip={"ScalarToSymbolPromotion", "ControlFlowRaising"})
+def _optimization_pipeline(
+    config: OptimizationConfig,
+    device_type: DeviceType,
+    backend: Backend,
+    *,
+    passes: list[tn.ScheduleNodeVisitor] | None = None,
+    cache_directory: Path | None = None,
+) -> StreePipeline:
+    """Return the schedule tree pipeline for `device_type`, or raise `ValueError`."""
+    match device_type:
+        case DeviceType.CPU:
+            return CPUPipeline(
+                config, backend, passes=passes, cache_directory=cache_directory
+            )
+        case DeviceType.GPU:
+            return GPUPipeline(
+                config, backend, passes=passes, cache_directory=cache_directory
+            )
+    raise ValueError(
+        f"Unknown device type `{device_type}`, expected {DeviceType.CPU} or {DeviceType.GPU}."
+    )
+
+
def _build_sdfg(
- dace_program: DaceProgram, sdfg: SDFG, config: DaceConfig, args: Any, kwargs: Any
+ dace_program: DaceProgram,
+ sdfg: SDFG,
+ config: DaceConfig,
+ optimization_config: OptimizationConfig | None,
+ args: Any,
+ kwargs: Any,
) -> None:
"""Build the .so out of the SDFG on the top tile ranks only."""
is_compiling = True if DEACTIVATE_DISTRIBUTED_DACE_COMPILE else config.do_compile
@@ -158,6 +181,11 @@ def _build_sdfg(
backend_name = config.get_backend()
if is_compiling:
+ if optimization_config is None:
+ ndsl_log.debug(f"Using default optimization config for {sdfg.label}.")
+ optimization_config = OptimizationConfig()
+
+ ndsl_log.debug(f"Compiling config:\n{pformat(optimization_config, indent=2)}")
# Fully specialize all known symbols and then propagate these changes in the simplify
# pass that follows. This is not only a smart idea in general, but also simplifies (haha)
# the schedule tree (optimization) roundtrip.
@@ -170,27 +198,59 @@ def _build_sdfg(
repl_dict[sym] = val
my_sdfg.replace_dict(repl_dict)
- if config.verbose_orchestration:
+ if config.verbose_orchestration:
+ ndsl_log.debug("saving 00-combined_from_stencils.sdfgz")
+ sdfg.save(
+ os.path.abspath(
+ f"{sdfg.build_folder}/00-combined_from_stencils.sdfgz"
+ ),
+ compress=True,
+ )
+
+ if config.is_gpu_backend():
+ with DaCeProgress(config, "Configure maps to run on GPU"):
+ for this_sdfg in sdfg.all_sdfgs_recursive():
+ for state in this_sdfg.states():
+ for node in state.nodes():
+ if (
+ isinstance(node, nodes.EntryNode)
+ and node.schedule != ScheduleType.Sequential
+ ):
+ node.schedule = ScheduleType.GPU_Device
+
+ ndsl_log.debug("saving 00-gpu-maps.sdfgz")
sdfg.save(
- os.path.abspath(f"{sdfg.build_folder}/00-combined_from_stencils.sdfgz"),
+ os.path.abspath(f"{sdfg.build_folder}/00-gpu-maps.sdfgz"),
compress=True,
)
with DaCeProgress(config, "Simplify (1)"):
_simplify(sdfg)
if config.verbose_orchestration:
+ ndsl_log.debug("saving 01-simplify.sdfgz")
sdfg.save(
os.path.abspath(f"{sdfg.build_folder}/01-simplify_1.sdfgz"),
compress=True,
)
- if _INTERNAL__SCHEDULE_TREE_OPTIMIZATION:
+ if optimization_config.stree.enabled:
# Here be 🐉 - but tests exists in test_optimization.py
with DaCeProgress(config, "Schedule Tree: generate from SDFG"):
# Break all loops into uni-dimensional loops to simplify optimizations
- sdfg.apply_transformations_repeated(MapExpansion, validate=True)
+ sdfg.apply_transformations_repeated(
+ MapExpansion,
+ options={
+ "inner_schedule": (
+ ScheduleType.GPU_Device
+ if device_type is DeviceType.GPU
+ else ScheduleType.Default
+ )
+ },
+ validate=True,
+ )
stree = sdfg.as_schedule_tree()
if config.verbose_orchestration:
+ ndsl_log.debug("saving 02-pre_opt.stree.txt")
with open(
os.path.abspath(f"{sdfg.build_folder}/02-pre_opt.stree.txt"),
"w+",
@@ -198,45 +258,16 @@ def _build_sdfg(
f.write(stree.as_string())
with DaCeProgress(config, "Schedule Tree: optimization"):
- passes = []
- if backend_name.loop_order == BackendLoopOrder.IJK:
- passes.extend(
- [
- CleanUpScheduleTree(),
- CartesianAxisMerge(AxisIterator._I),
- CartesianAxisMerge(AxisIterator._J),
- CartesianAxisMerge(AxisIterator._K),
- CartesianRefineTransients(backend_name),
- ]
- )
- elif backend_name.loop_order == BackendLoopOrder.KJI:
- passes.extend(
- [
- CleanUpScheduleTree(),
- CartesianAxisMerge(AxisIterator._K),
- CartesianAxisMerge(AxisIterator._J),
- CartesianAxisMerge(AxisIterator._I),
- CartesianRefineTransients(backend_name),
- ]
- )
- elif backend_name.loop_order == BackendLoopOrder.KIJ:
- passes.extend(
- [
- CleanUpScheduleTree(),
- CartesianAxisMerge(AxisIterator._K),
- CartesianAxisMerge(AxisIterator._I),
- CartesianAxisMerge(AxisIterator._J),
- CartesianRefineTransients(backend_name),
- ]
- )
- else:
- raise NotImplementedError(
- f"Loop order {backend_name.loop_order} has no schedule tree pipeline"
- )
- CPUPipeline(passes=passes, cache_directory=Path(sdfg.build_folder)).run(
- stree, verbose=config.verbose_schedule_tree_optimizations
+ pipeline = _optimization_pipeline(
+ optimization_config,
+ device_type,
+ backend_name,
+ cache_directory=Path(sdfg.build_folder),
+ passes=_INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES,
)
+ pipeline.run(stree, verbose=config.verbose_schedule_tree_optimizations)
if config.verbose_orchestration:
+ ndsl_log.debug("saving 03-post_opt.stree.txt")
with open(
os.path.abspath(f"{sdfg.build_folder}/03-post_opt.stree.txt"),
"w+",
@@ -246,48 +277,92 @@ def _build_sdfg(
with DaCeProgress(config, "Schedule Tree: go back to SDFG"):
sdfg = _tree_as_sdfg(stree)
if config.verbose_orchestration:
+ ndsl_log.debug("saving 04-from_stree.sdfgz")
sdfg.save(
os.path.abspath(f"{sdfg.build_folder}/04-from_stree.sdfgz"),
compress=True,
)
- # Make the transients array persistents
- if config.is_gpu_backend():
- # TODO
- # The following should happen on the stree level
- _to_gpu(sdfg)
+        # We want all maps properly collapsed to make sure the codegen will see nD parallel
+        # axes as a single kernelizable map
+ with DaCeProgress(config, "Collapse maps"):
+ # allow `MapCollapse` to collapse maps with different schedules
+ sdfg.apply_transformations_repeated(MapCollapse, permissive=True)
- sdfg.apply_gpu_transformations()
+ with DaCeProgress(config, "Make transient persistents"):
+            # Make the transient arrays persistent
+ if config.is_gpu_backend():
+ # TODO
+ # The following should happen on the stree level
+ _to_gpu(sdfg)
+ make_transients_persistent(sdfg=sdfg, device=device_type)
- make_transients_persistent(sdfg=sdfg, device=device_type)
+ # Upload args to device
+ _upload_to_device(list(args) + list(kwargs.values()))
+ else:
+ # TODO
+ # The following should happen on the stree level
+ for _sd, _aname, arr in sdfg.arrays_recursive():
+ if arr.shape == (1,):
+ arr.storage = DaceStorageType.Register
+ make_transients_persistent(sdfg=sdfg, device=device_type)
- # Upload args to device
- _upload_to_device(list(args) + list(kwargs.values()))
- else:
- # TODO
- # The following should happen on the stree level
- for _sd, _aname, arr in sdfg.arrays_recursive():
- if arr.shape == (1,):
- arr.storage = DaceStorageType.Register
- make_transients_persistent(sdfg=sdfg, device=device_type)
-
- # Build non-constants & non-transients from the sdfg_kwargs
- sdfg_kwargs = dace_program._create_sdfg_args(sdfg, args, kwargs)
- for k in dace_program.constant_args:
- if k in sdfg_kwargs:
- del sdfg_kwargs[k]
- sdfg_kwargs = {k: v for k, v in sdfg_kwargs.items() if v is not None}
- for k, tup in dace_program.resolver.closure_arrays.items():
- if k in sdfg_kwargs and tup[1].transient:
- del sdfg_kwargs[k]
-
- with DaCeProgress(config, "Simplify (2)"):
- _simplify(sdfg)
- if config.verbose_orchestration:
- sdfg.save(
- os.path.abspath(f"{sdfg.build_folder}/05-simplify_2.sdfgz"),
- compress=True,
+ if config.is_gpu_backend():
+ with DaCeProgress(config, "Apply GPU transformations"):
+ # Set block size on GPU maps and collect callback
+ # tasklets to exclude next
+ gpu_defaults = get_gpu_hardware_defaults()
+ exclude_taskslets_list = []
+
+ for me, _state in sdfg.all_nodes_recursive():
+ if (
+ isinstance(me, nodes.MapEntry)
+ and me.map.schedule == ScheduleType.GPU_Device
+ ):
+ if me.map.gpu_block_size is None:
+ me.map.gpu_block_size = gpu_defaults.block_size
+
+ if isinstance(me, nodes.Tasklet) and "callback_" in me.label:
+ exclude_taskslets_list.append(me.label)
+
+ sdfg.apply_transformations_repeated(
+ AddThreadBlockMap, print_report=False
)
+
+ if optimization_config.gpu.common_gpu_xforms:
+ with DaCeProgress(config, "Apply common GPU xforms"):
+ # Apply common GPU transforms (includes a simplify)
+ # while making sure tasklet remain on the host
+ from dace.transformation.interstate import GPUTransformSDFG
+
+ sdfg.apply_transformations(
+ GPUTransformSDFG,
+ options={
+ "exclude_tasklets": ",".join(exclude_taskslets_list),
+ "host_data": ["__pystate"],
+ },
+ )
+ else:
+ with DaCeProgress(config, "GPU simplify"):
+ _simplify(sdfg)
+
+ if config.verbose_orchestration:
+ ndsl_log.debug("saving 05-apply_gpu_xforms.sdfgz")
+ sdfg.save(
+ os.path.abspath(
+ f"{sdfg.build_folder}/05-apply_gpu_xforms.sdfgz"
+ ),
+ compress=True,
+ )
+ else:
+ with DaCeProgress(config, "Simplify (2)"):
+ _simplify(sdfg)
+ if config.verbose_orchestration:
+ ndsl_log.debug("saving 05-simplify_2.sdfgz")
+ sdfg.save(
+ os.path.abspath(f"{sdfg.build_folder}/05-simplify_2.sdfgz"),
+ compress=True,
+ )
# Move all memory that can be into a pool to lower memory pressure for GPU
# We skip this memory optimization for CPU because we don't have a memory
# pool available yet (DaCe v1)
@@ -316,7 +391,12 @@ def _build_sdfg(
# Compile
with DaCeProgress(config, "Codegen & compile"):
- sdfg.compile()
+ compiled_sdfg = sdfg.compile()
+ config.loaded_dace_executables[dace_program] = DaceExecutable(
+ compiled_sdfg=compiled_sdfg,
+ arguments={},
+ arguments_hash=0,
+ )
# Printing analysis of the compiled SDFG
with DaCeProgress(config, "Build finished. Running memory static analysis"):
@@ -355,22 +435,30 @@ def _build_sdfg(
)
MPI.COMM_WORLD.Barrier()
- with DaCeProgress(config, "Loading"):
- sdfg_path = get_sdfg_path(dace_program.name, config, override_run_only=True)
- if sdfg_path is None:
- raise ValueError("Couldn't load SDFG post build")
- compiledSDFG, _ = dace_program.load_precompiled_sdfg(
- sdfg_path, *args, **kwargs
- )
- config.loaded_dace_executables[dace_program] = DaceExecutable(
- compiled_sdfg=compiledSDFG,
- arguments={},
- arguments_hash=0,
- )
+ if not is_compiling:
+ with DaCeProgress(config, "Loading"):
+ sdfg_path = get_sdfg_path(
+ dace_program.name, config, override_run_only=True
+ )
+ if sdfg_path is None:
+ raise ValueError("Couldn't load SDFG post build")
+ compiledSDFG, _ = dace_program.load_precompiled_sdfg(
+ sdfg_path, *args, **kwargs
+ )
+ config.loaded_dace_executables[dace_program] = DaceExecutable(
+ compiled_sdfg=compiledSDFG,
+ arguments={},
+ arguments_hash=0,
+ )
def _call_sdfg(
- dace_program: DaceProgram, sdfg: SDFG, config: DaceConfig, args: Any, kwargs: Any
+ dace_program: DaceProgram,
+ sdfg: SDFG,
+ config: DaceConfig,
+ optimization_config: OptimizationConfig | None,
+ args: Any,
+ kwargs: Any,
) -> list | None:
"""Dispatch to either SDFG execution and/or build."""
@@ -381,8 +469,7 @@ def _call_sdfg(
mode in [DaCeOrchestration.Build, DaCeOrchestration.BuildAndRun]
and dace_program not in config.loaded_dace_executables # already cached
):
- ndsl_log.info("Building DaCe orchestration")
- _build_sdfg(dace_program, sdfg, config, args, kwargs)
+ _build_sdfg(dace_program, sdfg, config, optimization_config, args, kwargs)
if mode not in [DaCeOrchestration.BuildAndRun, DaCeOrchestration.Run]:
raise ValueError(f"Unexpected DaceOrchestration mode `{mode}`.")
@@ -427,6 +514,7 @@ def _call_sdfg(
def _parse_sdfg(
dace_program: DaceProgram,
config: DaceConfig,
+ optimization: OptimizationConfig | None,
*args: Any,
**kwargs: Any,
) -> SDFG | CompiledSDFG | None:
@@ -441,6 +529,8 @@ def _parse_sdfg(
if dace_program in config.loaded_dace_executables:
return config.loaded_dace_executables[dace_program].compiled_sdfg
+ ndsl_log.info(f"Building DaCe orchestration for {dace_program.f.__qualname__}")
+
# Build expected path
sdfg_path = get_sdfg_path(dace_program.name, config)
if sdfg_path is None:
@@ -462,6 +552,16 @@ def _parse_sdfg(
simplify=False,
validate=False, # TODO: should we have a "debug flag" to turn this on?
)
+
+ # Label the code (this is the topmost code)
+ if sdfg is not None and optimization is not None and optimization.stree.enabled:
+ set_label(
+ sdfg,
+ dace_program.f.__qualname__,
+ is_top_sdfg=True,
+ local_optimizations=optimization,
+ )
+
return sdfg
if os.path.isfile(sdfg_path):
@@ -489,9 +589,15 @@ class _LazyComputepathFunction(SDFGConvertible):
that will be compiled but not regenerated.
"""
- def __init__(self, func: Callable, config: DaceConfig) -> None:
+ def __init__(
+ self,
+ func: Callable,
+ config: DaceConfig,
+ optimization_config: OptimizationConfig | None,
+ ) -> None:
self.func = func
self.config = config
+ self.optimization_config = optimization_config
self.daceprog: DaceProgram = dace_program(self.func)
self._sdfg = None
@@ -500,6 +606,7 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def]
sdfg = _parse_sdfg(
self.daceprog,
self.config,
+ self.optimization_config,
*args,
**kwargs,
)
@@ -507,6 +614,7 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def]
self.daceprog,
sdfg,
self.config,
+ self.optimization_config,
args,
kwargs,
)
@@ -520,7 +628,9 @@ def global_vars(self, value): # type: ignore[no-untyped-def]
self.daceprog.global_vars = value
def __sdfg__(self, *args, **kwargs): # type: ignore[no-untyped-def]
- return _parse_sdfg(self.daceprog, self.config, *args, **kwargs)
+ return _parse_sdfg(
+ self.daceprog, self.config, self.optimization_config, *args, **kwargs
+ )
def __sdfg_closure__(self, *args, **kwargs): # type: ignore[no-untyped-def]
return self.daceprog.__sdfg_closure__(*args, **kwargs)
@@ -567,25 +677,39 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def]
sdfg = _parse_sdfg(
self.daceprog,
self.lazy_method.config,
+ self.lazy_method.optimization_config,
*args,
**kwargs,
)
- # Label the code (this is the topmost code)
- if sdfg is not None and _INTERNAL__SCHEDULE_TREE_OPTIMIZATION:
- set_label(sdfg, type(self.obj_to_bind).__qualname__, is_top_sdfg=True)
return _call_sdfg(
self.daceprog,
sdfg,
self.lazy_method.config,
+ self.lazy_method.optimization_config,
args,
kwargs,
)
def __sdfg__(self, *args, **kwargs): # type: ignore[no-untyped-def]
- sdfg = _parse_sdfg(self.daceprog, self.lazy_method.config, *args, **kwargs)
+ sdfg = _parse_sdfg(
+ self.daceprog,
+ self.lazy_method.config,
+ self.lazy_method.optimization_config,
+ *args,
+ **kwargs,
+ )
# Label the code
- if sdfg is not None and _INTERNAL__SCHEDULE_TREE_OPTIMIZATION:
- set_label(sdfg, type(self.obj_to_bind).__qualname__, is_top_sdfg=False)
+ if (
+ sdfg is not None
+ and self.lazy_method.optimization_config is not None
+ and self.lazy_method.optimization_config.stree.enabled
+ ):
+ set_label(
+ sdfg,
+ type(self.obj_to_bind).__qualname__,
+ is_top_sdfg=False,
+ local_optimizations=self.lazy_method.optimization_config,
+ )
return sdfg
def __sdfg_closure__(self, reevaluate=None): # type: ignore[no-untyped-def]
@@ -599,9 +723,15 @@ def closure_resolver(self, constant_args, given_args, parent_closure=None): # t
constant_args, given_args, parent_closure
)
- def __init__(self, func: Callable, config: DaceConfig):
+ def __init__(
+ self,
+ func: Callable,
+ config: DaceConfig,
+ optimization_config: OptimizationConfig | None,
+ ) -> None:
self.func = func
self.config = config
+ self.optimization_config = optimization_config
def __get__(self, obj: object, objtype: Any = None) -> SDFGEnabledCallable:
"""Return SDFGEnabledCallable wrapping original obj.method from cache.
@@ -620,6 +750,7 @@ def orchestrate(
config: DaceConfig,
method_to_orchestrate: str = "__call__",
dace_compiletime_args: Sequence[str] | None = None,
+ optimization_config: OptimizationConfig | None = None,
) -> None:
"""
Orchestrate a method of an object with DaCe.
@@ -672,7 +803,7 @@ def orchestrate(
# Build DaCe orchestrated wrapper
# This is a JIT object, e.g. DaCe compilation will happen on call
- wrapped = _LazyComputepathMethod(func, config).__get__(obj)
+ wrapped = _LazyComputepathMethod(func, config, optimization_config).__get__(obj)
if method_to_orchestrate == "__call__":
# Grab the function from the type of the child class
@@ -724,6 +855,7 @@ def closure_resolver(self, constant_args, given_args, parent_closure=None): # t
def orchestrate_function(
config: DaceConfig,
dace_compiletime_args: Sequence[str] | None = None,
+ optimization_config: OptimizationConfig | None = None,
) -> Callable[..., Any] | _LazyComputepathFunction:
"""
Decorator orchestrating a method of an object with DaCe.
@@ -742,7 +874,7 @@ def _decorator(func: Callable[..., Any]): # type: ignore[no-untyped-def]
def _wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
for argument in dace_compiletime_args:
func.__annotations__[argument] = DaceCompiletime
- return _LazyComputepathFunction(func, config)
+ return _LazyComputepathFunction(func, config, optimization_config)
return _wrapper(func) if config.is_dace_orchestrated() else func
diff --git a/ndsl/dsl/dace/stree/optimizations/__init__.py b/ndsl/dsl/dace/stree/optimizations/__init__.py
index 73497f93..210f7978 100644
--- a/ndsl/dsl/dace/stree/optimizations/__init__.py
+++ b/ndsl/dsl/dace/stree/optimizations/__init__.py
@@ -1,11 +1,28 @@
-from .axis_merge import AxisIterator, CartesianAxisMerge
+from .axis_merge import CartesianAxisMerge
+from .cartesian_merge import CartesianMerge
from .clean_tree import CleanUpScheduleTree
+from .kernelize_maps import KernelizeMaps
+from .local_optimizations import LocalOptimizations
+from .offgrid_conditionals import (
+ ExtractOffgridConditionals,
+ InlineOffgridConditionals,
+ MergeConditionals,
+)
from .refine_transients import CartesianRefineTransients
+from .remove_loops import InlineVertical2DWrite
+from .statistics import TreeOptimizationStatistics
__all__ = [
- "AxisIterator",
"CartesianAxisMerge",
- "CartesianRefineTransients",
+ "CartesianMerge",
"CleanUpScheduleTree",
+ "KernelizeMaps",
+ "LocalOptimizations",
+ "ExtractOffgridConditionals",
+ "InlineOffgridConditionals",
+ "MergeConditionals",
+ "CartesianRefineTransients",
+ "InlineVertical2DWrite",
+ "TreeOptimizationStatistics",
]
diff --git a/ndsl/dsl/dace/stree/optimizations/axis_merge.py b/ndsl/dsl/dace/stree/optimizations/axis_merge.py
index c042badf..196d3d0e 100644
--- a/ndsl/dsl/dace/stree/optimizations/axis_merge.py
+++ b/ndsl/dsl/dace/stree/optimizations/axis_merge.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
import copy
import dace
@@ -7,29 +5,18 @@
from dace.sdfg.analysis.schedule_tree import treenodes as tn
from ndsl import ndsl_log
-from ndsl.dsl.dace.stree.optimizations.memlet_helpers import (
+from ndsl.dsl.dace.stree.optimizations.common import (
AxisIterator,
- no_data_dependencies_on_cartesian_axis,
-)
-from ndsl.dsl.dace.stree.optimizations.tree_common_op import (
detect_cycle,
+ get_next_node,
+ is_axis_for,
+ is_axis_map,
+ last_node,
list_index,
+ no_data_dependencies_on_cartesian_axis,
swap_node_position_in_tree,
)
-
-
-# Buggy passes that should work
-PUSH_IFSCOPE_DOWNWARD = False # Crashing the overall stree - bad algorithmics
-
-
-def _is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool:
- """Returns true if node is a map over the given axis."""
- map_parameter = node.node.map.params
- return len(map_parameter) == 1 and map_parameter[0].startswith(axis.as_str())
-
-
-def _is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool:
- return node.loop.loop_variable.startswith(axis.as_str())
+from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ReplaceAxisSymbol
def _both_same_single_axis_maps(
@@ -39,18 +26,17 @@ def _both_same_single_axis_maps(
(
len(first.node.map.params) == 1 and len(second.node.map.params) == 1
) # Single axis
- and _is_axis_map(first, axis) # Correct axis in first map
- and _is_axis_map(second, axis) # Correct axis in second map
+ and is_axis_map(first, axis) # Correct axis in first map
+ and is_axis_map(second, axis) # Correct axis in second map
)
def _can_merge_axis_maps(
first: tn.MapScope, second: tn.MapScope, axis: AxisIterator
) -> bool:
- if _both_same_single_axis_maps(first, second, axis):
- if no_data_dependencies_on_cartesian_axis(first, second, axis):
- return True
- return False
+ return _both_same_single_axis_maps(
+ first, second, axis
+ ) and no_data_dependencies_on_cartesian_axis(first, second, axis)
class InsertOvercomputationGuard(tn.ScheduleNodeTransformer):
@@ -82,89 +68,45 @@ def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope:
all_children_are_maps = all(
[isinstance(child, tn.MapScope) for child in node.children]
)
- if not all_children_are_maps:
- if self._merged_range != self._original_range:
- if_scope = tn.IfScope(
- condition=self._execution_condition(),
- children=node.children,
- parent=node,
- )
- # Re-parent to IF
- for child in node.children:
- child.parent = if_scope
- node.children = [if_scope]
+ if all_children_are_maps:
+ node.children = self.visit(node.children)
return node
- node.children = self.visit(node.children)
+ if self._merged_range != self._original_range:
+ if_scope = tn.IfScope(
+ condition=self._execution_condition(),
+ children=node.children,
+ parent=node,
+ )
+ # Re-parent to IF
+ for child in node.children:
+ child.parent = if_scope
+ node.children = [if_scope]
return node
-def _get_next_node(
- nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode
-) -> tn.ScheduleTreeNode:
- return nodes[list_index(nodes, node) + 1]
-
-
-def _last_node(nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> bool:
- return list_index(nodes, node) >= len(nodes) - 1
-
-
-class ReplaceAxisSymbol(tn.ScheduleNodeVisitor):
- def __init__(self, axis: AxisIterator) -> None:
- self._axis = axis
-
- def visit_MapScope(
- self,
- map_scope: tn.MapScope,
- axis_replacements: dict[str, str] | None = None,
- ) -> None:
- if axis_replacements is None:
- axis_replacements = {}
-
- for index, param in enumerate(map_scope.node.params):
- if param in axis_replacements:
- map_scope.node.params[index] = axis_replacements[param]
-
- # visit children
- for child in map_scope.children:
- self.visit(child, axis_replacements=axis_replacements)
-
- def visit_TaskletNode(
- self,
- node: tn.TaskletNode,
- axis_replacements: dict[str, str] | None = None,
- ) -> None:
- if not axis_replacements:
- # Noop if there are no replacements to do.
- return
-
- for memlets in node.in_memlets.values():
- memlets.replace(axis_replacements)
- for memlets in node.out_memlets.values():
- memlets.replace(axis_replacements)
-
-
class CartesianAxisMerge(tn.ScheduleNodeTransformer):
"""Merge a cartesian axis if they are contiguous in code-flow.
Can do:
- merge a given axis with the next maps at the same recursion level
- - can overcompute (eager) to allow for more merging at the cost of an if
+ - can overcompute to allow for more merging at the cost of an if
It expects:
- All Maps and ForLoop are on a single axis - but doesn't check for it.
Args:
axis: AxisIterator to be merged
- eager: overcompute with a conditional guard
+ overcompute: merge at the cost of an if statement.
"""
- def __init__(self, axis: AxisIterator, *, eager: bool = True) -> None:
+ def __init__(self, axis: AxisIterator, *, overcompute: bool) -> None:
self.axis = axis
- self.eager = eager
+ self.overcompute = overcompute
def __str__(self) -> str:
- return f"CartesianAxisMerge_{self.axis.name}_{'eager' if self.eager else ''}"
+ suffix = "_overcompute" if self.overcompute else ""
+ return f"CartesianAxisMerge_{self.axis.name}{suffix}"
def _merge_node(
self, node: tn.ScheduleTreeNode, nodes: list[tn.ScheduleTreeNode]
@@ -179,9 +121,6 @@ def _merge_node(
if isinstance(node, tn.MapScope):
return self._map_overcompute_merge(node, nodes)
- if PUSH_IFSCOPE_DOWNWARD and isinstance(node, tn.IfScope):
- return self._push_ifelse_down(node, nodes)
-
if isinstance(node, tn.ForScope):
return self._for_merge(node)
@@ -197,7 +136,7 @@ def _merge_node(
def _for_merge(self, the_for_scope: tn.ForScope) -> int:
merged = 0
- if _is_axis_for(the_for_scope, self.axis):
+ if is_axis_for(the_for_scope, AxisIterator._K):
# TODO: if the for scope is on a cartesian axis it can be
# merged with other for scope going in the same direction
pass
@@ -206,7 +145,7 @@ def _for_merge(self, the_for_scope: tn.ForScope) -> int:
if (
len(the_for_scope.children) == 1
and isinstance(the_for_scope.children[0], tn.MapScope)
- and _is_axis_map(the_for_scope.children[0], self.axis)
+ and is_axis_map(the_for_scope.children[0], self.axis)
):
swap_node_position_in_tree(the_for_scope, the_for_scope.children[0])
merged += 1
@@ -248,92 +187,19 @@ def _push_tasklet_down(
return merged
- def _push_ifelse_down(
- self, the_if: tn.IfScope, nodes: list[tn.ScheduleTreeNode]
- ) -> int:
- merged = 0
-
- # Recurse down if/else/elif
- if_index = list_index(nodes, the_if)
- if len(the_if.children) != 0:
- merged += self._merge_node(the_if.children[0], the_if.children)
- for else_index in range(if_index + 1, len(nodes)):
- else_node = nodes[else_index]
- if else_index < len(nodes) and (
- isinstance(else_node, tn.ElseScope)
- or isinstance(else_node, tn.ElifScope)
- ):
- merged += self._merge_node(else_node, else_node.children)
- else:
- break
-
- # Look at swapping if/else/elif first map w/ control flow
-
- # Gather all first maps - if they do not exists, get out
- all_maps = []
- if isinstance(the_if.children[0], tn.MapScope):
- all_maps.append(the_if.children[0])
- else:
- return merged
- for else_index in range(if_index + 1, len(nodes)):
- else_node = nodes[else_index]
- if else_index < len(nodes) and (
- isinstance(else_node, tn.ElseScope)
- or isinstance(else_node, tn.ElifScope)
- ):
- if isinstance(else_node.children[0], tn.MapScope):
- all_maps.append(else_node.children[0])
- else:
- return merged
-
- else:
- break
-
- # Check for mergeability
- if len(all_maps) > 1:
- the_map = all_maps[0]
- for _map in all_maps[1:]:
- if not _can_merge_axis_maps(the_map, _map, self.axis):
- return merged
-
- # We are good to go - swap it all
- inner_if_map = the_if.children[0]
-
- # Swap IF & maps
- if_index = list_index(nodes, the_if)
- swap_node_position_in_tree(the_if, inner_if_map)
-
- # Swap ELIF/ELSE & maps
- for else_index in range(if_index + 1, len(nodes)):
- if else_index < len(nodes) and (
- isinstance(nodes[else_index], tn.ElseScope)
- or isinstance(nodes[else_index], tn.ElifScope)
- ):
- swap_node_position_in_tree(
- nodes[else_index], nodes[else_index].children[0]
- )
- else:
- break
-
- # Merge the Maps
- assert isinstance(nodes[if_index], tn.MapScope)
- merged += self._map_overcompute_merge(nodes[if_index], nodes)
-
- return merged
-
def _map_overcompute_merge(
self, the_map: tn.MapScope, nodes: list[tn.ScheduleTreeNode]
) -> int:
# End of nodes OR
# Not the right axis
# --> recurse
- if _last_node(nodes, the_map) or not _is_axis_map(the_map, self.axis):
+ if last_node(nodes, the_map) or not is_axis_map(the_map, self.axis):
merged = 0
for child in the_map.children:
merged += self._merge_node(child, the_map.children)
return merged
- next_node = _get_next_node(nodes, the_map)
+ next_node = get_next_node(nodes, the_map)
# Next node is not a MapScope - no merge
if not isinstance(next_node, tn.MapScope):
@@ -345,7 +211,6 @@ def _map_overcompute_merge(
# Over compute to merge:
# - force-merge by expanding the ranges
- # - then, guard children to only run in their respective range
first_range = the_map.node.map.range
second_range = next_node.node.map.range
merged_range = dace.subsets.Range(
@@ -358,8 +223,15 @@ def _map_overcompute_merge(
]
)
- # push IfScope down if children are just maps
- axis_as_str = the_map.node.params[0]
+ # only overcompute if configured - otherwise no merge
+ if not self.overcompute and (
+ first_range != merged_range or second_range != merged_range
+ ):
+ return 0
+
+ # - then, guard children to only run in their respective range
+ axis_as_str = the_map.node.map.params[0]
+ assert isinstance(axis_as_str, str)
first_map = InsertOvercomputationGuard(
axis_as_str, merged_range=merged_range, original_range=first_range
).visit(the_map)
@@ -368,7 +240,9 @@ def _map_overcompute_merge(
merged_range=merged_range,
original_range=second_range,
).visit(next_node)
- merged_children: list[tn.MapScope] = [
+ assert isinstance(first_map, tn.MapScope)
+ assert isinstance(second_map, tn.MapScope)
+ merged_children: list[tn.ScheduleTreeNode] = [
*first_map.children,
*second_map.children,
]
@@ -384,11 +258,13 @@ def _map_overcompute_merge(
# K-maps use unique iterators (i.e. every k-map iterates over `k__[0-9]*`).
# After merge, we need to replace the axis symbols of the second map's children
# with the axis symbol of the first map.
- if next_node.node.map.params[0] != the_map.node.map.params[0]:
- replacements = {next_node.node.map.params[0]: the_map.node.map.params[0]}
- ReplaceAxisSymbol(self.axis).visit(
- first_map, axis_replacements=replacements
- )
+ if second_map.node.map.params[0] != first_map.node.map.params[0]:
+ replacements = {
+ dace.symbol(second_map.node.map.params[0]): dace.symbol(
+ first_map.node.map.params[0]
+ )
+ }
+ ReplaceAxisSymbol(replacements).visit(first_map)
# delete now-merged second_map
del nodes[list_index(nodes, next_node)]
diff --git a/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py
new file mode 100644
index 00000000..067eb5f2
--- /dev/null
+++ b/ndsl/dsl/dace/stree/optimizations/cartesian_merge.py
@@ -0,0 +1,107 @@
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
+
+from ndsl.config import Backend, BackendLoopOrder
+from ndsl.dsl.dace.stree.optimizations.axis_merge import CartesianAxisMerge
+from ndsl.dsl.dace.stree.optimizations.common import AxisIterator
+from ndsl.dsl.dace.stree.optimizations.offgrid_conditionals import (
+ ExtractOffgridConditionals,
+ InlineOffgridConditionals,
+ MergeConditionals,
+)
+
+
+class CartesianMerge(tn.ScheduleNodeTransformer):
+ """Merge Cartesian computation blocks.
+
+ Args:
+ backend: The loop order influences the merge order.
+ overcompute: Whether to merge at the cost of an if statement. Defaults to True.
+ """
+
+ def __init__(
+ self,
+ backend: Backend,
+ *,
+ overcompute: bool = True,
+ merge_order: str = "default",
+ ) -> None:
+ super().__init__()
+ self._backend = backend
+ self._overcompute = overcompute
+ self._merge_order = merge_order
+
+ if self._merge_order not in (
+ "default",
+ "IJK",
+ "IKJ",
+ "JIK",
+ "JKI",
+ "KIJ",
+ "KJI",
+ ):
+ raise ValueError(f"Unexpected merge order {self._merge_order}.")
+
+ def __str__(self) -> str:
+ return "CartesianMerge"
+
+ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None:
+ axis_merge_order = self._axis_merge_order()
+ for axis in axis_merge_order:
+ InlineOffgridConditionals(axis).visit(node)
+ MergeConditionals().visit(node)
+
+ for axis in axis_merge_order:
+ CartesianAxisMerge(
+ axis, overcompute=self._overcompute
+ ).visit_ScheduleTreeRoot(node)
+
+ ExtractOffgridConditionals().visit(node)
+ MergeConditionals().visit(node)
+
+ def _axis_merge_order(self) -> tuple[AxisIterator, AxisIterator, AxisIterator]:
+ if self._merge_order == "default":
+ return self._axis_merge_order_default()
+
+ return self._axis_merge_order_custom()
+
+ def _axis_merge_order_default(
+ self,
+ ) -> tuple[AxisIterator, AxisIterator, AxisIterator]:
+ if self._backend.loop_order == BackendLoopOrder.IJK:
+ return (AxisIterator._I, AxisIterator._J, AxisIterator._K)
+
+ if self._backend.loop_order == BackendLoopOrder.IKJ:
+ return (AxisIterator._I, AxisIterator._K, AxisIterator._J)
+
+ if self._backend.loop_order == BackendLoopOrder.JIK:
+ return (AxisIterator._J, AxisIterator._I, AxisIterator._K)
+
+ if self._backend.loop_order == BackendLoopOrder.JKI:
+ return (AxisIterator._J, AxisIterator._K, AxisIterator._I)
+
+ if self._backend.loop_order == BackendLoopOrder.KIJ:
+ return (AxisIterator._K, AxisIterator._I, AxisIterator._J)
+
+ assert self._backend.loop_order == BackendLoopOrder.KJI
+ return (AxisIterator._K, AxisIterator._J, AxisIterator._I)
+
+ def _axis_merge_order_custom(
+ self,
+ ) -> tuple[AxisIterator, AxisIterator, AxisIterator]:
+ if self._merge_order == "IJK":
+ return (AxisIterator._I, AxisIterator._J, AxisIterator._K)
+
+ if self._merge_order == "IKJ":
+ return (AxisIterator._I, AxisIterator._K, AxisIterator._J)
+
+ if self._merge_order == "JIK":
+ return (AxisIterator._J, AxisIterator._I, AxisIterator._K)
+
+ if self._merge_order == "JKI":
+ return (AxisIterator._J, AxisIterator._K, AxisIterator._I)
+
+ if self._merge_order == "KIJ":
+ return (AxisIterator._K, AxisIterator._I, AxisIterator._J)
+
+ assert self._merge_order == "KJI"
+ return (AxisIterator._K, AxisIterator._J, AxisIterator._I)
diff --git a/ndsl/dsl/dace/stree/optimizations/clean_tree.py b/ndsl/dsl/dace/stree/optimizations/clean_tree.py
index 5e9ab522..acd7bd79 100644
--- a/ndsl/dsl/dace/stree/optimizations/clean_tree.py
+++ b/ndsl/dsl/dace/stree/optimizations/clean_tree.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
from dace.sdfg.analysis.schedule_tree import treenodes as tn
from ndsl import ndsl_log
@@ -9,6 +7,7 @@ class CleanUpScheduleTree(tn.ScheduleNodeTransformer):
"""Remove `StateBoundary` nodes from children of ScheduleTreeScopes."""
def __init__(self) -> None:
+ super().__init__()
self._removed_state_boundaries = 0
def __str__(self) -> str:
@@ -24,43 +23,48 @@ def _remove_state_boundaries_from_children(
self._removed_state_boundaries += 1
node.children.remove(boundary)
+ def visit_LibraryCall(self, node: tn.LibraryCall) -> tn.LibraryCall | None:
+ # Filter duplicate labeled regions
+ # TODO: this shouldn't be necessary and needs to be cleaned up.
+ if node.node.unique_name.endswith("_patched"):
+ return None
+
+ return node
+
def visit_WhileScope(self, node: tn.WhileScope) -> tn.WhileScope:
self._remove_state_boundaries_from_children(node)
- for child in node.children:
- self.visit(child)
+ self.generic_visit(node)
return node
def visit_ForScope(self, node: tn.ForScope) -> tn.ForScope:
self._remove_state_boundaries_from_children(node)
- for child in node.children:
- self.visit(child)
+ self.generic_visit(node)
return node
def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope:
self._remove_state_boundaries_from_children(node)
- for child in node.children:
- self.visit(child)
+ self.generic_visit(node)
return node
def visit_IfScope(self, node: tn.IfScope) -> tn.IfScope:
self._remove_state_boundaries_from_children(node)
- for child in node.children:
- self.visit(child)
+
+ self.generic_visit(node)
return node
- def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None:
+ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> tn.ScheduleTreeRoot:
self._removed_state_boundaries = 0
self._remove_state_boundaries_from_children(node)
- for child in node.children:
- self.visit(child)
+ self.generic_visit(node)
ndsl_log.debug(f"{self}: removed {self._removed_state_boundaries} nodes")
+ return node
diff --git a/ndsl/dsl/dace/stree/optimizations/common/__init__.py b/ndsl/dsl/dace/stree/optimizations/common/__init__.py
new file mode 100644
index 00000000..2e342912
--- /dev/null
+++ b/ndsl/dsl/dace/stree/optimizations/common/__init__.py
@@ -0,0 +1,25 @@
+from .memlet import AxisIterator, no_data_dependencies_on_cartesian_axis # isort: skip
+from .loops import is_axis_for, is_axis_map, is_cartesian_axis
+from .topology import (
+ detect_cycle,
+ get_next_node,
+ last_node,
+ list_index,
+ reparent_scope_node,
+ swap_node_position_in_tree,
+)
+
+
+__all__ = [
+ "AxisIterator",
+ "no_data_dependencies_on_cartesian_axis",
+ "is_axis_map",
+ "is_cartesian_axis",
+ "is_axis_for",
+ "get_next_node",
+ "last_node",
+ "swap_node_position_in_tree",
+ "detect_cycle",
+ "list_index",
+ "reparent_scope_node",
+]
diff --git a/ndsl/dsl/dace/stree/optimizations/common/loops.py b/ndsl/dsl/dace/stree/optimizations/common/loops.py
new file mode 100644
index 00000000..1f057954
--- /dev/null
+++ b/ndsl/dsl/dace/stree/optimizations/common/loops.py
@@ -0,0 +1,29 @@
+import dace.sdfg.analysis.schedule_tree.treenodes as tn
+
+from ndsl.dsl.dace.stree.optimizations.common import AxisIterator
+
+
+def is_axis_map(node: tn.MapScope, axis: AxisIterator) -> bool:
+ """Returns true if node is a Map over the given axis."""
+ if len(node.node.map.params) != 1:
+ return False
+
+ param = node.node.map.params[0]
+ assert isinstance(param, str)
+ return axis.is_equal(param)
+
+
+def is_cartesian_axis(node: tn.MapScope | tn.ForScope) -> bool:
+ """Returns true if the given node is a Map or a For over any cartesian axis."""
+ for axis in AxisIterator:
+ if (isinstance(node, tn.MapScope) and is_axis_map(node, axis)) or (
+ isinstance(node, tn.ForScope) and is_axis_for(node, axis)
+ ):
+ return True
+
+ return False
+
+
+def is_axis_for(node: tn.ForScope, axis: AxisIterator) -> bool:
+ """Returns true if node is a For over the given axis."""
+ return axis.is_equal(node.loop.loop_variable)
diff --git a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py b/ndsl/dsl/dace/stree/optimizations/common/memlet.py
similarity index 53%
rename from ndsl/dsl/dace/stree/optimizations/memlet_helpers.py
rename to ndsl/dsl/dace/stree/optimizations/common/memlet.py
index 0626133e..d52ca9f5 100644
--- a/ndsl/dsl/dace/stree/optimizations/memlet_helpers.py
+++ b/ndsl/dsl/dace/stree/optimizations/common/memlet.py
@@ -1,7 +1,9 @@
from enum import Enum
+from numbers import Number
-import dace.sdfg.analysis.schedule_tree.treenodes as stree
from dace.memlet import Memlet
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
+from dace.symbolic import symbol
from ndsl import ndsl_log
@@ -17,81 +19,89 @@ def as_str(self) -> str:
def as_cartesian_index(self) -> int:
return self.value[1]
+ def is_equal(self, other: str) -> bool:
+ if self == AxisIterator._K:
+ return other.startswith(self.as_str())
+
+ return other == self.as_str()
+
+
+def normalize_cartesian_indexation(
+ index: Number | symbol, axis: AxisIterator
+) -> symbol:
+ """Return a normalized indexation symbol for cartesian indexation."""
+ if isinstance(index, Number):
+ # Special case for refined cartesian indices, i.e. when `index` is 0.
+ return index
+
+ rename_maps = {}
+ for symb in index.free_symbols:
+ if symb.name.startswith(axis.as_str()):
+ rename_maps[symb] = symbol(axis.as_str())
+ return index.subs(rename_maps)
+
def no_data_dependencies_on_cartesian_axis(
- first: stree.MapScope,
- second: stree.MapScope,
+ first: tn.MapScope,
+ second: tn.MapScope,
axis: AxisIterator,
) -> bool:
- """Check for read after write. Allow when indexation on the axis
- is not offset."""
+ """Check for read after write and write after write with different offsets."""
write_collector = MemletCollector(collect_reads=False)
write_collector.visit(first)
+ other_writes = MemletCollector(collect_reads=False)
+ other_writes.visit(second)
read_collector = MemletCollector(collect_writes=False)
read_collector.visit(second)
+
for write in write_collector.out_memlets:
# TODO: this can be optimized to allow non-overlapping intervals and such in the future
- if write.subset.dims() <= axis.as_cartesian_index():
+ axis_index = axis.as_cartesian_index()
+
+ if write.subset.dims() <= axis_index:
# Dimension does not exist
continue
- previous_axis_index = write.subset[axis.as_cartesian_index()][0]
+ previous_axis_index = normalize_cartesian_indexation(
+ write.subset[axis_index][0], axis
+ )
+
+ # Write-after-write with an offset case
+ for other_write in other_writes.out_memlets:
+ if write.data == other_write.data:
+ if previous_axis_index != normalize_cartesian_indexation(
+ other_write.subset[axis_index][0], axis
+ ):
+ ndsl_log.debug(
+ f"[{axis.name} Merge] Found write after write conflict "
+ f"for {write.data} "
+ f"w/ different offset to {axis.name} ("
+ f"first write at {previous_axis_index}, "
+ f"second write at {other_write.subset[axis_index][0]})"
+ )
+ return False
+
+ # Read-after-write with an offset case
for read in read_collector.in_memlets:
if write.data == read.data:
- if previous_axis_index != read.subset[axis.as_cartesian_index()][0]:
+ if previous_axis_index != normalize_cartesian_indexation(
+ read.subset[axis_index][0], axis
+ ):
ndsl_log.debug(
f"[{axis.name} Merge] Found read after write conflict "
f"for {write.data} "
f"w/ different offset to {axis.name} ("
- f"write at {write.subset[axis.as_cartesian_index()][0]}, "
- f"read at {read.subset[axis.as_cartesian_index()][0]})"
+ f"write at {write.subset[axis_index][0]}, "
+ f"read at {read.subset[axis_index][0]})"
)
return False
- return True
-
-def no_data_dependencies(
- first: stree.MapScope,
- second: stree.MapScope,
- restrict_check_to_k: bool = False,
-) -> bool:
- write_collector = MemletCollector(collect_reads=False)
- write_collector.visit(first)
- read_collector = MemletCollector(collect_writes=False)
- read_collector.visit(second)
- for write in write_collector.out_memlets:
- # Make sure we don't have read after write conditions.
- # TODO: this can be optimized to allow non-overlapping intervals and such in the future
- if restrict_check_to_k:
- if write.subset.dims() < 3:
- # Case of 2D write - no K dependency
- continue
-
- previous_k_index = write.subset[2][0]
- for read in read_collector.in_memlets:
- if write.data == read.data:
- if previous_k_index != read.subset[2][0]:
- print(
- "[K Merge] Found read after write conflict "
- f"for {write.data} "
- "w/ different offset to K ("
- f"write at {write.subset[2][0]}, "
- f"read at {read.subset[2][0]})"
- )
- return False
-
- else:
- if write.data in [read.data for read in read_collector.in_memlets]:
- print(
- f"[All dims merge] Found potential read after write conflict for {write.data}"
- )
- return False
return True
-class MemletCollector(stree.ScheduleNodeVisitor):
+class MemletCollector(tn.ScheduleNodeVisitor):
"""Gathers in_memlets and out_memlets of TaskNodes and LibraryCalls."""
in_memlets: list[Memlet]
@@ -106,13 +116,13 @@ def __init__(
self.in_memlets = []
self.out_memlets = []
- def visit_TaskletNode(self, node: stree.TaskletNode) -> None:
+ def visit_TaskletNode(self, node: tn.TaskletNode) -> None:
if self._collect_reads:
self.in_memlets.extend([memlet for memlet in node.in_memlets.values()])
if self._collect_writes:
self.out_memlets.extend([memlet for memlet in node.out_memlets.values()])
- def visit_LibraryCall(self, node: stree.LibraryCall) -> None:
+ def visit_LibraryCall(self, node: tn.LibraryCall) -> None:
if self._collect_reads:
if isinstance(node.in_memlets, set):
self.in_memlets.extend(node.in_memlets)
@@ -130,7 +140,7 @@ def visit_LibraryCall(self, node: stree.LibraryCall) -> None:
)
-def has_dynamic_memlets(first: stree.MapScope, second: stree.MapScope) -> bool:
+def has_dynamic_memlets(first: tn.MapScope, second: tn.MapScope) -> bool:
first_collector = MemletCollector()
second_collector = MemletCollector()
first_collector.visit(first)
diff --git a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py b/ndsl/dsl/dace/stree/optimizations/common/topology.py
similarity index 65%
rename from ndsl/dsl/dace/stree/optimizations/tree_common_op.py
rename to ndsl/dsl/dace/stree/optimizations/common/topology.py
index 1253ba81..fa06f3db 100644
--- a/ndsl/dsl/dace/stree/optimizations/tree_common_op.py
+++ b/ndsl/dsl/dace/stree/optimizations/common/topology.py
@@ -3,12 +3,30 @@
import dace.sdfg.analysis.schedule_tree.treenodes as tn
+def reparent_scope_node(
+    original_parent: tn.ScheduleTreeScope,
+    new_parent: tn.ScheduleTreeScope,
+    *,
+    prepend: bool = True,
+) -> None:
+    """Hand the children of `original_parent` over to `new_parent`.
+
+    Each child gets `new_parent` as its `parent`. The moved children are placed
+    in front of the children `new_parent` already has (`prepend=True`, the
+    default) or behind them (`prepend=False`). The `children` list of
+    `original_parent` is not modified by this function.
+    """
+    moved = list(original_parent.children)
+    kept = list(new_parent.children)
+
+    for node in moved:
+        node.parent = new_parent
+
+    new_parent.children = moved + kept if prepend else kept + moved
+
+
def swap_node_position_in_tree(
top_node: tn.ScheduleTreeScope, child_node: tn.ScheduleTreeScope
) -> None:
"""Top node becomes child, child becomes top node."""
# Ensure parent/children relationship is valid
tn.validate_children_and_parents_align(top_node)
+ assert top_node.parent is not None
# Take refs before swap
top_children = top_node.parent.children
@@ -51,3 +69,15 @@ def list_index(
"""Check if node is in list with "is" operator."""
# compare with "is" to get memory comparison. ".index()" uses value comparison
return next(index for index, element in enumerate(collection) if element is node)
+
+
+def get_next_node(
+    nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode
+) -> tn.ScheduleTreeNode:
+    """Return the element of `nodes` that directly follows `node`.
+
+    `node` is located via `list_index`. No bounds check is done here, so callers
+    have to rule out the last element first (e.g. with `last_node`).
+    """
+    successor_index = list_index(nodes, node) + 1
+    return nodes[successor_index]
+
+
+def last_node(nodes: list[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> bool:
+    """Tell whether `node` sits at the final position of `nodes`."""
+    final_index = len(nodes) - 1
+    return list_index(nodes, node) >= final_index
diff --git a/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py b/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py
new file mode 100644
index 00000000..03edb878
--- /dev/null
+++ b/ndsl/dsl/dace/stree/optimizations/kernelize_maps.py
@@ -0,0 +1,101 @@
+from copy import deepcopy
+
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
+
+from ndsl import Backend
+from ndsl.config import BackendLoopOrder
+from ndsl.dsl.dace.stree.optimizations.common import (
+ AxisIterator,
+ is_axis_map,
+ is_cartesian_axis,
+)
+
+
+class _KernelizeMap(tn.ScheduleNodeTransformer):
+    """Split a map over one axis into several maps, one per cartesian child loop.
+
+    A `MapScope` over `axis` whose direct children contain two or more cartesian
+    map/for scopes is replaced by a list of new `MapScope`s. Each new map carries
+    a deep copy of the original map entry and holds exactly one cartesian child
+    loop, together with the non-loop children that precede it.
+    """
+
+    def __init__(self, axis: AxisIterator) -> None:
+        super().__init__()
+        self._axis = axis
+
+    def __str__(self) -> str:
+        return f"KernelizeMap_{self._axis}"
+
+    @staticmethod
+    def _is_cartesian_loop(node: tn.ScheduleTreeNode) -> bool:
+        """True for map/for scopes that `is_cartesian_axis` accepts."""
+        return isinstance(node, (tn.MapScope, tn.ForScope)) and is_cartesian_axis(
+            node
+        )
+
+    def _count_cartesian_children(self, node: tn.ScheduleTreeScope) -> int:
+        """Count the direct children of `node` that are cartesian map/for scopes."""
+        return sum(1 for child in node.children if self._is_cartesian_loop(child))
+
+    def _new_kernel(
+        self, template: tn.MapScope, children: list[tn.ScheduleTreeNode]
+    ) -> tn.MapScope:
+        """Build a copy of `template` that owns `children`."""
+        kernel = tn.MapScope(
+            node=deepcopy(template.node),
+            children=children,
+            parent=template.parent,
+            state=template.state,
+        )
+        # Re-parent explicitly (as the other passes do after creating a scope),
+        # otherwise the children keep pointing at the map that is being replaced.
+        for child in children:
+            child.parent = kernel
+        return kernel
+
+    def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope | list[tn.MapScope]:
+        """Return the kernelized replacement maps, or recurse if nothing applies."""
+        # Only maps on the requested axis with two or more cartesian child loops
+        # are split; everything else is traversed as usual.
+        if not (
+            is_axis_map(node, self._axis) and self._count_cartesian_children(node) > 1
+        ):
+            return self.generic_visit(node)
+
+        kernelized_maps: list[tn.MapScope] = []
+        pending: list[tn.ScheduleTreeNode] = []
+
+        for child in node.children:
+            pending.append(child)
+            if self._is_cartesian_loop(child):
+                kernelized_maps.append(self._new_kernel(node, pending))
+                pending = []
+
+        # Children that come after the last cartesian loop must not be dropped.
+        # Keep them (in order) at the end of the last kernel.
+        if pending:
+            last_kernel = kernelized_maps[-1]
+            for child in pending:
+                child.parent = last_kernel
+            last_kernel.children.extend(pending)
+
+        return kernelized_maps
+
+
+class KernelizeMaps(tn.ScheduleNodeVisitor):
+    """Run `_KernelizeMap` over a schedule tree, one axis after the other.
+
+    The axes (and their order) come from the backend's loop order when
+    `apply_order` is "default", otherwise from the custom `apply_order` string.
+    Construction fails with `ValueError` for backends that are not GPU backends.
+    """
+
+    def __init__(self, backend: Backend, *, apply_order: str = "default") -> None:
+        super().__init__()
+        self._backend = backend
+        self._apply_order = apply_order
+
+        runs_on_gpu = self._backend.is_gpu_backend()
+        if not runs_on_gpu:
+            raise ValueError(
+                "The transformation `KernelizeMaps` is only intended to run on GPUs."
+            )
+
+    def __str__(self) -> str:
+        return "KernelizeMaps"
+
+    def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None:
+        """Apply one `_KernelizeMap` pass per configured axis."""
+        for axis in self._axis_order():
+            kernelizer = _KernelizeMap(axis)
+            kernelizer.visit(node)
+
+    def _axis_order(self) -> list[AxisIterator]:
+        """Pick the backend order for "default", a custom order otherwise."""
+        follow_backend = self._apply_order == "default"
+        if follow_backend:
+            return self._axis_order_backend()
+
+        # Custom orders are used e.g. by local optimizations.
+        return self._axis_order_custom()
+
+    def _axis_order_backend(self) -> list[AxisIterator]:
+        """Axes to kernelize for the loop order of the backend."""
+        match self._backend.loop_order:
+            case BackendLoopOrder.IJK:
+                return [AxisIterator._J, AxisIterator._I]
+            case BackendLoopOrder.KJI:
+                return [AxisIterator._J, AxisIterator._K]
+            case _:
+                raise NotImplementedError(
+                    f"KernelizeMaps is not configured for loop order {self._backend.loop_order}."
+                )
+
+    def _axis_order_custom(self) -> list[AxisIterator]:
+        """Axes to kernelize for an explicitly requested apply order."""
+        match self._apply_order:
+            case "JI":
+                return [AxisIterator._J, AxisIterator._I]
+            case "JK":
+                return [AxisIterator._J, AxisIterator._K]
+            case _:
+                raise NotImplementedError(
+                    f"KernelizeMaps is not configured for custom apply order {self._apply_order}."
+                )
diff --git a/ndsl/dsl/dace/stree/optimizations/local_optimizations.py b/ndsl/dsl/dace/stree/optimizations/local_optimizations.py
new file mode 100644
index 00000000..b9796613
--- /dev/null
+++ b/ndsl/dsl/dace/stree/optimizations/local_optimizations.py
@@ -0,0 +1,398 @@
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
+
+from ndsl import Backend, OptimizationConfig, ndsl_log
+from ndsl.dsl.dace.stree.optimizations.cartesian_merge import CartesianMerge
+from ndsl.dsl.dace.stree.optimizations.kernelize_maps import KernelizeMaps
+from ndsl.dsl.dace.stree.optimizations.remove_loops import InlineVertical2DWrite
+
+
+class ScheduleTreeScopeTransformer(tn.ScheduleNodeTransformer):
+    """Transformer base class with two hooks around the traversal of scope nodes.
+
+    For every scope type handled below, `_breadth_first_callback` is called
+    before the scope's children are visited and `_depth_first_callback` is called
+    afterwards. Both hooks do nothing here; subclasses override what they need.
+    Each `visit_*` method returns the node it was given.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def _breadth_first_callback(self, node: tn.ScheduleTreeScope) -> None:
+        """Hook called on a scope before its children are visited (no-op)."""
+
+    def _depth_first_callback(self, node: tn.ScheduleTreeScope) -> None:
+        """Hook called on a scope after its children were visited (no-op)."""
+
+    def _visit_scope(self, scope: tn.ScheduleTreeScope) -> None:
+        """Shared traversal: pre-hook, recurse into children, post-hook."""
+        self._breadth_first_callback(scope)
+        self.generic_visit(scope)
+        self._depth_first_callback(scope)
+
+    def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> tn.ScheduleTreeRoot:
+        self._visit_scope(node)
+        return node
+
+    def visit_GBlock(self, node: tn.GBlock) -> tn.GBlock:
+        self._visit_scope(node)
+        return node
+
+    def visit_LoopScope(self, node: tn.LoopScope) -> tn.LoopScope:
+        self._visit_scope(node)
+        return node
+
+    def visit_ForScope(self, node: tn.ForScope) -> tn.ForScope:
+        self._visit_scope(node)
+        return node
+
+    def visit_WhileScope(self, node: tn.WhileScope) -> tn.WhileScope:
+        self._visit_scope(node)
+        return node
+
+    def visit_DoWhileScope(self, node: tn.DoWhileScope) -> tn.DoWhileScope:
+        self._visit_scope(node)
+        return node
+
+    def visit_IfScope(self, node: tn.IfScope) -> tn.IfScope:
+        self._visit_scope(node)
+        return node
+
+    def visit_StateIfScope(self, node: tn.StateIfScope) -> tn.StateIfScope:
+        self._visit_scope(node)
+        return node
+
+    def visit_ElifScope(self, node: tn.ElifScope) -> tn.ElifScope:
+        self._visit_scope(node)
+        return node
+
+    def visit_ElseScope(self, node: tn.ElseScope) -> tn.ElseScope:
+        self._visit_scope(node)
+        return node
+
+    def visit_MapScope(self, node: tn.MapScope) -> tn.MapScope:
+        self._visit_scope(node)
+        return node
+
+    def visit_ConsumeScope(self, node: tn.ConsumeScope) -> tn.ConsumeScope:
+        self._visit_scope(node)
+        return node
+
+
+class _LabeledSection(tn.ScheduleTreeScope):
+    """Scope node that groups children under a label.
+
+    Besides the usual children/parent, it stores the section's `label` and the
+    `OptimizationConfig` (`optimizations`) attached to that section.
+    """
+
+    def __init__(
+        self,
+        *,
+        children: list[tn.ScheduleTreeNode],
+        parent: tn.ScheduleTreeScope,
+        label: str,
+        optimizations: OptimizationConfig,
+    ) -> None:
+        super().__init__(children=children, parent=parent)
+        self.label = label
+        self.optimizations = optimizations
+
+    def as_string(self, indent: int = 0) -> str:
+        """Render a `section '<label>':` header line followed by the scope body."""
+        header = indent * tn.INDENTATION + f"section '{self.label}':\n"
+        body = super().as_string(indent)
+        return header + body
+
+
+class _LabelSections(ScheduleTreeScopeTransformer):
+    """
+    Transform entry/exit label nodes into a `_LabeledSection` for easier later
+    handling in case of local optimizations. Handles nested labeled sections.
+
+    Before
+
+    ```none
+    # program before
+
+    library_node("entry my_stencil")
+    map i in [...]
+        map j in [...]
+            map k in [...]
+                # contents of "my_stencil"
+    library_node("exit my_stencil")
+
+    # program continues
+    ```
+
+    After
+
+    ```none
+    # program before
+
+    labeled_section "my_stencil":
+        map i in [...]
+            map j in [...]
+                map k in [...]
+                    # contents of "my_stencil"
+
+    # program continues
+    ```
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        # Created here so the counter exists no matter how the callback is reached.
+        # It is reset at the start of every root visit.
+        self._labeled_sections = 0
+
+    def __str__(self) -> str:
+        return "_LabelSections"
+
+    def _depth_first_callback(self, scope: tn.ScheduleTreeScope) -> None:
+        """
+        Regroup the children of `scope` into labeled sections.
+
+        Goes over the children of the given schedule tree scope once and wraps
+        everything between a matching pair of `NDSLRuntime_Label` entry/exit nodes
+        into a `_LabeledSection`. The label nodes themselves are dropped.
+
+        Raises:
+            RuntimeError: For a label that is neither entry nor exit, for an exit
+                label without an open entry label, and for entry labels that are
+                never closed within `scope`.
+        """
+        # The stack of entry nodes. They pop when the matching exit node is reached.
+        # Using a stack adds support for nested labeled sections.
+        entry_nodes_stack: list[tn.LibraryCall] = []
+
+        # The stack of children. Every new entry node opens a new list of children.
+        # This allows one pass to gather nested children. The bottom entry of the
+        # stack collects the children of the current scope.
+        children_stack: list[list[tn.ScheduleTreeNode]] = [[]]
+
+        for child in scope.children:
+            # Anything that is not a `tn.LibraryCall` is kept as-is.
+            if not isinstance(child, tn.LibraryCall):
+                children_stack[-1].append(child)
+                continue
+
+            if not child.node.name == "NDSLRuntime_Label":
+                # Leave other library call nodes alone.
+                children_stack[-1].append(child)
+                continue
+
+            if child.node.unique_name.startswith("Enter__"):
+                # Keep tabs on where we start and open a new list of children.
+                entry_nodes_stack.append(child)
+                children_stack.append([])
+                continue
+
+            # Expect an exit node now (matching the entry node on top of the stack).
+            if not child.node.unique_name.startswith("Exit__"):
+                raise RuntimeError(
+                    f"Unexpected `NDSLRuntime_Label` '{child.node.unique_name}'."
+                )
+
+            # Fail with a clear message instead of an `IndexError` from `pop()`.
+            if not entry_nodes_stack:
+                raise RuntimeError(
+                    f"`NDSLRuntime_Label` '{child.node.unique_name}' has no "
+                    "matching entry label in this scope."
+                )
+
+            # For exit nodes, find the matching entry node and the new children.
+            section_start = entry_nodes_stack.pop()
+            new_children = children_stack.pop()
+
+            # sanity checks
+            # - ensure we have the right section (if not, something is screwed up)
+            name = section_start.node.unique_name.removeprefix("Enter__")
+            assert name == child.node.unique_name.removeprefix("Exit__")
+            # - ensure we have the same parent (if not, something is screwed up)
+            parent = section_start.parent
+            assert parent == child.parent
+            # - ensure the stack of children is not empty (it will at least
+            #   contain the list for the current scope)
+            assert len(children_stack) > 0
+
+            # Put all the new children in a `_LabeledSection` and push that into
+            # the enclosing list of children.
+            new_node = _LabeledSection(
+                children=new_children,
+                parent=parent,
+                label=name,
+                optimizations=section_start.node._local_optimizations,
+            )
+            # re-parent new children to new node
+            for c in new_node.children:
+                c.parent = new_node
+            # push new node into enclosing stack of children
+            children_stack[-1].append(new_node)
+
+            # and - of course - the final book keeping
+            self._labeled_sections += 1
+
+        # An entry label that was never closed would otherwise silently discard
+        # the outer children below. Refuse before touching `scope.children`.
+        if entry_nodes_stack:
+            unclosed = ", ".join(
+                f"'{entry.node.unique_name}'" for entry in entry_nodes_stack
+            )
+            raise RuntimeError(
+                f"`NDSLRuntime_Label` entry label(s) {unclosed} are never closed "
+                "in this scope."
+            )
+
+        # set the new children on the current scope
+        scope.children = children_stack.pop()
+
+        # some sanity checks
+        assert len(children_stack) == 0  # expect empty stack
+        for child in scope.children:
+            assert child.parent == scope  # expect correct parent
+
+    def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> tn.ScheduleTreeRoot:
+        """Label all sections below `node` and log how many were created."""
+        self._labeled_sections = 0
+
+        # recurse down first to label sections "leaf first"
+        self.generic_visit(node)
+        self._depth_first_callback(node)
+
+        ndsl_log.debug(f"{self}: labeled {self._labeled_sections} sections.")
+        return node
+
+
+class _ApplyLocalOptimizations(ScheduleTreeScopeTransformer):
+    """
+    Applies local optimizations in `_LabeledSection`s in a "leaf first" approach.
+
+    This works inline and replaces each `_LabeledSection` with the results of the
+    local optimizations as configured in the `OptimizationConfig` of that section.
+    """
+
+    def __init__(self, backend: Backend) -> None:
+        super().__init__()
+        self._backend = backend
+
+    def __str__(self) -> str:
+        # Was "_LabelSections" (copy/paste from the class above).
+        return "_ApplyLocalOptimizations"
+
+    def visit__LabeledSection(self, node: _LabeledSection) -> _LabeledSection:
+        # Recurse into labeled sections to support nested labeled sections.
+        self._breadth_first_callback(node)
+        self.generic_visit(node)
+        self._depth_first_callback(node)
+
+        return node
+
+    def _optimize_section(self, section: _LabeledSection) -> None:
+        """Run the passes enabled in `section.optimizations` on `section`.
+
+        Pass order: inline size-one K loops, cartesian merge, kernelize maps
+        (GPU backends only). Transient refinement is never run locally; a
+        warning is logged if it is requested.
+        """
+        # TODO
+        # The code below is basically an `StreePipeline`. It is duplicated here
+        # because we need some clever engineering to not get into a hell of
+        # dependency circles (where the local optimizations are a pipeline pass
+        # and in itself depend on the pipeline).
+        config = section.optimizations
+        assert config.stree.enabled
+
+        # HACK
+        # Below, we are calling `visit_ScheduleTreeRoot` with a `_LabeledSection`.
+        # This works because python uses duck-typing.
+        # TODO
+        # Clean up pipeline passes and the pipeline itself such that they can work
+        # on any subtree (i.e. any `ScheduleTreeScope`).
+        on_gpu = self._backend.is_gpu_backend()
+
+        if config.stree.inline_K_loops_size_one:
+            inliner = InlineVertical2DWrite()
+            inliner.visit_ScheduleTreeRoot(section)
+
+        if config.stree.merger.enabled:
+            merger = CartesianMerge(
+                self._backend,
+                overcompute=config.stree.merger.overcompute,
+                merge_order=config.stree.merger.order,
+            )
+            merger.visit_ScheduleTreeRoot(section)
+
+        # Kernelization only exists for GPU backends.
+        if on_gpu and config.stree.kernelize:
+            if config.stree.merger.order not in ("IJK", "KJI"):
+                ndsl_log.warning(
+                    "Can't locally kernelize maps. Unknown apply oder. Skipping this pass."
+                )
+            else:
+                # Follow the merge-order for kernelization
+                kernelizer = KernelizeMaps(
+                    self._backend,
+                    apply_order=("JI" if config.stree.merger.order == "IJK" else "JK"),
+                )
+                kernelizer.visit_ScheduleTreeRoot(section)
+
+        if config.stree.refine_transients:
+            # We can't know if transients are local to the scope that we are
+            # working in. If they are not, transient refinement can generate wrong
+            # results and refine too eagerly. Global transient refinement will
+            # also work in this section.
+            ndsl_log.warning(
+                "[Local-Opt]: Transient refinement can't e applied on a local scale "
+                "because it needs the global information on where/how transient data "
+                "is used. Please enable transient refinement on your global optimization "
+                "config and disable it here. No transients will be refined on the local "
+                "scale even if this option is turned on."
+            )
+
+    def _depth_first_callback(self, scope: tn.ScheduleTreeScope) -> None:
+        """Optimize every `_LabeledSection` child of `scope` and splice it inline."""
+        new_children: list[tn.ScheduleTreeNode] = []
+
+        for child in scope.children:
+            # Any child that isn't a `_LabeledSection` is kept unchanged.
+            if not isinstance(child, _LabeledSection):
+                new_children.append(child)
+                continue
+
+            # For labeled sections, apply the local optimizations to the section,
+            # then keep only its (possibly transformed) children.
+            self._optimize_section(child)
+
+            for c in child.children:
+                # be sure to re-parent the children of this node to the new parent
+                c.parent = child.parent
+                new_children.append(c)
+
+        scope.children = new_children
+
+        # sanity checks
+        for child in scope.children:
+            assert child.parent == scope  # expect correct parent
+            assert not isinstance(
+                child, _LabeledSection
+            )  # no labeled sections should be left at this point
+
+
+class LocalOptimizations(tn.ScheduleNodeVisitor):
+    """Entry point for local optimizations: label sections, then optimize them."""
+
+    def __init__(self, backend: Backend) -> None:
+        super().__init__()
+        self._backend = backend
+
+    def __str__(self) -> str:
+        return "LocalOptimizations"
+
+    def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None:
+        """Run the two sub-passes on `node`, in order."""
+        stages = (
+            # 1. parse enter/exit labels into `_LabeledSection`s
+            _LabelSections(),
+            # 2. apply local optimizations on the children of those sections
+            _ApplyLocalOptimizations(self._backend),
+        )
+        for stage in stages:
+            stage.visit(node)
diff --git a/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py
new file mode 100644
index 00000000..de4c21a2
--- /dev/null
+++ b/ndsl/dsl/dace/stree/optimizations/offgrid_conditionals.py
@@ -0,0 +1,129 @@
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
+
+from ndsl import ndsl_log
+from ndsl.dsl.dace.stree.optimizations.common import (
+ AxisIterator,
+ get_next_node,
+ is_axis_map,
+ last_node,
+ list_index,
+)
+
+
+class InlineOffgridConditionals(tn.ScheduleNodeVisitor):
+    """
+    Push offgrid conditional inside their cartesian block, duplicating the
+    conditional if needed.
+
+    Turning:
+    ```
+    if a_flag == 0:
+        map i, j, k:
+            ...
+        map i, j, k:
+            ...
+    ```
+    into
+    ```
+    map i, j, k:
+        if a_flag == 0:
+            ...
+    map i, j, k:
+        if a_flag == 0:
+            ...
+    ```
+
+    The transformation is skipped when the `if` is directly followed by an
+    `elif`/`else` scope, or when not all children of the `if` are maps over the
+    configured axis.
+    """
+
+    _axis: AxisIterator
+
+    def __init__(self, axis: AxisIterator) -> None:
+        super().__init__()
+        self._axis = axis
+
+    def __str__(self) -> str:
+        return f"InlineOffgridConditionals_{self._axis}"
+
+    def visit_IfScope(self, node: tn.IfScope) -> None:
+        """Replace `node` in its parent by its maps, each wrapping a copy of the `if`."""
+        assert node.parent is not None  # just to keep pyright happy
+        siblings = node.parent.children
+
+        # For now, skip in case there's an `elif` or `else` following.
+        if not last_node(siblings, node):
+            next_node = get_next_node(siblings, node)
+            if isinstance(next_node, (tn.ElifScope, tn.ElseScope)):
+                ndsl_log.debug(
+                    "Can't handle conditionals with `elif` and `else` blocks yet :("
+                )
+                return
+
+        if not all(
+            isinstance(child, tn.MapScope) and is_axis_map(child, self._axis)
+            for child in node.children
+        ):
+            return
+
+        # If all children are maps over the correct axis, move the if inside.
+        new_nodes: list[tn.MapScope] = []
+
+        for child in node.children:
+            assert isinstance(
+                child, tn.MapScope
+            )  # otherwise the condition above is wrong
+
+            # NOTE(review): all new if-scopes share the same `node.condition`
+            # object - confirm that no later pass mutates it per-scope.
+            if_scope = tn.IfScope(
+                condition=node.condition, children=child.children, parent=child
+            )
+
+            for map_child in child.children:
+                map_child.parent = if_scope  # re-parent to new if_scope
+
+            child.children = [if_scope]
+            child.parent = node.parent  # re-parent to parent of old if_scope
+            new_nodes.append(child)
+
+        # Swap the old if-scope for the new maps in one step. The position comes
+        # from `list_index` (identity); `list.remove()` would compare by value and
+        # could take out a different, equal-looking sibling.
+        position = list_index(siblings, node)
+        siblings[position : position + 1] = new_nodes
+
+
+class ExtractOffgridConditionals(tn.ScheduleNodeTransformer):
+ """
+ Push offgrid conditional outside of their cartesian block.
+
+ This is the inverse transform of InlineOffgridConditionals.
+
+ NOTE: This class body only defines `__str__`. No `visit_*` method is
+ implemented here, so the transform described above is not performed by
+ the code in this class.
+ """
+
+ def __str__(self) -> str:
+ return "ExtractOffgridConditionals"
+
+
+class MergeConditionals(tn.ScheduleNodeTransformer):
+ """
+ Merge consecutive and equal conditionals.
+
+ Turning:
+ ```
+ if a_flag == 0:
+ map i, j, k:
+ ...
+ if a_flag == 0:
+ map i, j, k:
+ ...
+ ```
+ into
+ ```
+ if a_flag == 0:
+ map i, j, k:
+ ...
+ map i, j, k:
+ ...
+ ```
+
+ Outside of user code, combination of ExtractOffgridConditionals,
+ InlineOffgridConditionals and CartesianMapMerge can lead to this
+ pattern.
+
+ NOTE: This class body only defines `__str__`. No `visit_*` method is
+ implemented here, so the merge described above is not performed by the
+ code in this class.
+ """
+
+ def __str__(self) -> str:
+ return "MergeConditionals"
diff --git a/ndsl/dsl/dace/stree/optimizations/refine_transients.py b/ndsl/dsl/dace/stree/optimizations/refine_transients.py
index 19a425bd..ca0178f8 100644
--- a/ndsl/dsl/dace/stree/optimizations/refine_transients.py
+++ b/ndsl/dsl/dace/stree/optimizations/refine_transients.py
@@ -1,11 +1,10 @@
import warnings
import dace.data
-import dace.sdfg.analysis.schedule_tree.treenodes as stree
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
from ndsl import ndsl_log
-from ndsl.config import Backend, BackendFramework
-from ndsl.dsl.dace.stree.optimizations.memlet_helpers import AxisIterator
+from ndsl.dsl.dace.stree.optimizations.common import AxisIterator
def _change_index_of_tuple(
@@ -28,21 +27,16 @@ def _reduce_cartesian_axis_size_to_1(
transient_map_reads: dace.subsets.Range | None,
transient_map_writes: dace.subsets.Range | None,
transient_data: dace.data.Data,
- layout_map: tuple[int, ...],
) -> bool:
"""Reduce dimension size of transient to 1 if all access (reads and writes)
are atomic"""
# Dev Note: Better dataflow analysis would look at exactly
- # what's goin on here!
+ # what's going on here!
# Assume 3D cartesian!
if len(transient_data.shape) < 3:
- warnings.warn(
- f"Potential non-3D array: {transient_data}, skipping.",
- UserWarning,
- stacklevel=2,
- )
+ ndsl_log.debug(f"Potential non-3D array: {transient_data}, skipping.")
return False
read_write_range: dace.subsets.Range = dace.subsets.union(
@@ -59,26 +53,17 @@ def _reduce_cartesian_axis_size_to_1(
# therefore this dimension can be removed. BUT we are not truly
# removing it, we are reducing it to 1 to not have to deal
# with different slicing.
- transient_data.shape = _change_index_of_tuple(
+ new_shape = _change_index_of_tuple(
transient_data.shape,
axis.as_cartesian_index(),
value=1,
)
-
- if len(transient_data.shape) == 3:
- layout = [*layout_map]
- else:
- data_dim_count = len(transient_data.shape) - 3
- layout = [dim + data_dim_count for dim in layout_map] + [
- i - 1 for i in range(data_dim_count, 0, -1)
- ]
-
- transient_data.set_strides_from_layout(*layout)
+ transient_data.set_shape(new_shape)
transient_data.lifetime = dace.dtypes.AllocationLifetime.State
return True
-class CollectTransientRangeAccess(stree.ScheduleNodeVisitor):
+class CollectTransientRangeAccess(tn.ScheduleNodeVisitor):
"""Unionize all transient arrays access into a single Range."""
def __init__(self) -> None:
@@ -95,18 +80,16 @@ def __init__(self) -> None:
self.transients_range_writes: dict[str, dace.subsets.Range | None] = {}
self.transients_range_reads: dict[str, dace.subsets.Range | None] = {}
- def __str__(self) -> str:
- return "CartesianCollectMaps"
-
def _find_first_map_or_loop(
self,
- node: stree.TaskletNode,
+ node: tn.TaskletNode,
axis: AxisIterator,
) -> dace.nodes.MapEntry | None:
parent = node.parent
while parent is not None:
- if isinstance(parent, stree.MapScope):
- for p in parent.node.params:
+ if isinstance(parent, tn.MapScope):
+ for p in parent.node.map.params:
+ assert isinstance(p, str)
if p.startswith(axis.as_str()):
return parent.node
@@ -115,8 +98,8 @@ def _find_first_map_or_loop(
def _record_access(
self,
- node: stree.TaskletNode,
- memlets: stree.MemletSet,
+ node: tn.TaskletNode,
+ memlets: tn.MemletSet,
recording_set: dict[str, dace.subsets.Range | None],
) -> None:
for memlet in memlets:
@@ -149,11 +132,11 @@ def _record_access(
AxisIterator._K.as_cartesian_index()
].add(map_entry)
- def visit_TaskletNode(self, node: stree.TaskletNode) -> None:
- self._record_access(node, node.input_memlets(), self.transients_range_writes)
- self._record_access(node, node.output_memlets(), self.transients_range_reads)
+ def visit_TaskletNode(self, node: tn.TaskletNode) -> None:
+ self._record_access(node, node.input_memlets(), self.transients_range_reads)
+ self._record_access(node, node.output_memlets(), self.transients_range_writes)
- def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None:
+ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None:
self.containers = node.containers
for name, data in self.containers.items():
if data.transient and isinstance(data, dace.data.Array):
@@ -161,23 +144,20 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None:
self.transients_range_writes[name] = None
self.transients_range_reads[name] = None
- for child in node.children:
- self.visit(child)
+ self.generic_visit(node)
-class RebuildMemletsFromContainers(stree.ScheduleNodeVisitor):
- """Rebuild memlets from containers to ensure they are scope to the right size."""
+class RebuildMemletsFromContainers(tn.ScheduleNodeVisitor):
+ """Rebuild memlets from containers to ensure they are scoped to the right size."""
def __init__(self, refined_arrays: set[str]) -> None:
self._refined_arrays = refined_arrays
- def __str__(self) -> str:
- return "RefineTransientAxis"
-
- def visit_TaskletNode(self, node: stree.TaskletNode) -> None:
+ def visit_TaskletNode(self, node: tn.TaskletNode) -> None:
for memlet in [*node.output_memlets(), *node.input_memlets()]:
if memlet.data not in self._refined_arrays:
continue
+
array = self.containers[memlet.data]
if array.transient:
if not isinstance(memlet.subset, dace.subsets.Range):
@@ -191,13 +171,12 @@ def visit_TaskletNode(self, node: stree.TaskletNode) -> None:
if array.shape[index] == 1:
memlet.subset.ranges[index] = (0, 0, 1)
- def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None:
+ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None:
self.containers = node.containers
- for child in node.children:
- self.visit(child)
+ self.generic_visit(node)
-class CartesianRefineTransients(stree.ScheduleNodeTransformer):
+class CartesianRefineTransients(tn.ScheduleNodeTransformer):
"""Refine (reduce dimensionality) of transients based on their true use in
the cartesian dimensions.
@@ -210,7 +189,7 @@ class CartesianRefineTransients(stree.ScheduleNodeTransformer):
cartesian axis) it will reduce that axis to 1 if all access are atomic
(exactly _one_ element of the array is ever worked on in a single loop)
- It will refuse to merge if the transient is used in multiple loops of for
- a given axis - irrigardless of it's access pattern (e.g. even if it could be
+ a given axis - regardless of its access pattern (e.g. even if it could be
refined because it's always written first.)
It should but cannot do or will produce bugs if:
@@ -240,7 +219,7 @@ class CartesianRefineTransients(stree.ScheduleNodeTransformer):
memory (e.g. halo) for the `RebuildMemletsFromContainers`!
"""
- def __init__(self, backend: Backend) -> None:
+ def __init__(self) -> None:
warnings.warn(
"CartesianRefineTransients is a WIP. It's usage is *severely* limited "
"and will most likely lead to bad numerics. Check the docs, check utest.",
@@ -248,25 +227,19 @@ def __init__(self, backend: Backend) -> None:
stacklevel=2,
)
- if not backend.is_orchestrated() or backend.framework != BackendFramework.DACE:
- raise NotImplementedError(
- f"[Schedule Tree Opt] CartesianRefineTransient not implemented for backend {backend}"
- )
- self.layout_map = backend.as_layout_map()
- self.refined_array: set[str] = set()
-
def __str__(self) -> str:
return "CartesianRefineTransients"
- def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None:
+ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None:
collect_map = CollectTransientRangeAccess()
collect_map.visit(node)
# Remove Axis
- refined_transient = 0
+ refined_arrays: set[str] = set()
for name, data in node.containers.items():
if not (data.transient and isinstance(data, dace.data.Array)):
continue
+
refined = False
for axis in AxisIterator:
# We do not refine multi-map transients
@@ -279,18 +252,18 @@ def visit_ScheduleTreeRoot(self, node: stree.ScheduleTreeRoot) -> None:
> 1
):
continue
+
# Refine axis down to 1
refined |= _reduce_cartesian_axis_size_to_1(
axis,
collect_map.transients_range_reads[name],
collect_map.transients_range_writes[name],
data,
- self.layout_map,
)
- refined_transient += 1 if refined else 0
- self.refined_array.add(name)
+ if refined:
+ refined_arrays.add(name)
- RebuildMemletsFromContainers(self.refined_array).visit(node)
+ RebuildMemletsFromContainers(refined_arrays).visit(node)
- ndsl_log.debug(f"🚀 {refined_transient} Transient refined")
+ ndsl_log.debug(f"🚀 {len(refined_arrays)} Transient refined")
diff --git a/ndsl/dsl/dace/stree/optimizations/remove_loops.py b/ndsl/dsl/dace/stree/optimizations/remove_loops.py
new file mode 100644
index 00000000..89716404
--- /dev/null
+++ b/ndsl/dsl/dace/stree/optimizations/remove_loops.py
@@ -0,0 +1,87 @@
+import ast
+from typing import Any
+
+import dace
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
+
+from ndsl import ndsl_log
+from ndsl.dsl.dace.stree.optimizations.common import (
+ AxisIterator,
+ is_axis_for,
+ list_index,
+)
+from ndsl.dsl.dace.stree.optimizations.replace_axis_symbol import ReplaceAxisSymbol
+
+
+class InlineVertical2DWrite(tn.ScheduleNodeVisitor):
+    """Inline K index value for 2D write vertical while removing for loop.
+
+    Transforming:
+    ```
+    for __k = 0; __k < 1; __k = __k + 1:
+        map __j, __i:
+            field[__i, __j] = tasklet(field_in[__i, __j, __k])
+    ```
+
+    Into
+    ```
+    map __j, __i:
+        field[__i, __j] = tasklet(field_in[__i, __j, 0])
+    ```
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._for_scopes_removed = 0
+
+    def __str__(self) -> str:
+        return "InlineVertical2DWrite"
+
+    def visit_ForScope(self, the_for: tn.ForScope) -> None:
+        """Inline `the_for` into its parent if it is a K loop whose bound and init differ by one.
+
+        Loops on other axes, loops whose init/bound cannot be evaluated, and
+        loops where `abs(bound - init) != 1` are left untouched.
+        """
+        if not is_axis_for(the_for, AxisIterator._K):
+            return
+
+        assert the_for.parent is not None  # just to keep pyright happy
+
+        # Retrieve init/bound value by executing the code and replace usage of it
+        # If the code cannot be executed (no-literal variable part of the op, etc.)
+        # we will _not_ inline
+        # NOTE: `exec`/`eval` run on code taken from the schedule tree itself.
+        try:
+            exec_locals: dict[str, Any] = {}
+            exec_globals: dict[str, Any] = {}
+            exec(
+                ast.unparse(the_for.loop.init_statement.code[0]),
+                exec_globals,
+                exec_locals,
+            )
+            init_value = exec_locals[the_for.loop.loop_variable]
+            bound_value = eval(
+                ast.unparse(the_for.loop.loop_condition.code[0].value.comparators)
+            )
+        except Exception:
+            # Deliberately broad: any failure to evaluate means "do not inline".
+            return
+
+        # NOTE(review): the comparison operator of the loop condition is not
+        # inspected; presumably a strict `<`/`>` bound is assumed - TODO confirm.
+        if abs(bound_value - init_value) != 1:
+            return
+
+        # Checked before the tree is touched, so a violation cannot leave a
+        # half-transformed tree behind.
+        assert len(the_for.children) > 0
+
+        ReplaceAxisSymbol(
+            {dace.symbol(the_for.loop.loop_variable): str(init_value)}
+        ).visit(the_for)
+
+        # Replace the ForScope by its children in the parent. The position comes
+        # from `list_index` (identity); `list.remove()` would compare by value and
+        # could take out a different, equal-looking sibling.
+        siblings = the_for.parent.children
+        position = list_index(siblings, the_for)
+        for child in the_for.children:
+            child.parent = the_for.parent
+        siblings[position : position + 1] = the_for.children
+
+        self._for_scopes_removed += 1
+
+    def visit_ScheduleTreeRoot(self, the_root: tn.ScheduleTreeRoot) -> None:
+        """Visit all children of `the_root` and log the number of inlined loops."""
+        self._for_scopes_removed = 0
+
+        for child in the_root.children:
+            self.visit(child)
+
+        ndsl_log.debug(f"🚀 {self}: {self._for_scopes_removed} inlined")
diff --git a/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py b/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py
new file mode 100644
index 00000000..e4b83b03
--- /dev/null
+++ b/ndsl/dsl/dace/stree/optimizations/replace_axis_symbol.py
@@ -0,0 +1,36 @@
+import itertools
+import re
+
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
+from dace.symbolic import symbol
+
+
+class ReplaceAxisSymbol(tn.ScheduleNodeVisitor):
+    """Rename axis symbols in tasklet memlets, masklet code and `if` conditions.
+
+    Args:
+        axis_replacements: Mapping from the symbol (or symbol name) to replace
+            to its replacement. Keys and values are converted with `str()`
+            whenever a code string is rewritten.
+    """
+
+    def __init__(self, axis_replacements: dict[str | symbol, str | symbol]) -> None:
+        super().__init__()
+        self._axis_replacements = axis_replacements
+
+    @staticmethod
+    def _replace_word(code: str, old: str | symbol, new: str | symbol) -> str:
+        """Return `code` with every whole-word occurrence of `old` replaced by `new`."""
+        replacement = str(new)
+        # `\b` anchors on word boundaries, e.g. `__k` is not matched inside `__k_0`.
+        # The key is escaped and the replacement is handed over through a
+        # callable, so neither of them is interpreted as regex syntax.
+        return re.sub(rf"\b{re.escape(str(old))}\b", lambda _match: replacement, code)
+
+    def visit_TaskletNode(self, node: tn.TaskletNode) -> None:
+        """Apply the replacements to all memlets and, for masklets, to the code."""
+        for memlet in itertools.chain(
+            node.in_memlets.values(), node.out_memlets.values()
+        ):
+            memlet.replace(self._axis_replacements)
+
+        # Only tasklets whose label starts with "masklet" get their code rewritten
+        if node.node.label.startswith("masklet"):
+            for old, new in self._axis_replacements.items():
+                node.node.code.as_string = self._replace_word(
+                    node.node.code.as_string, old, new
+                )
+
+    def visit_IfScope(self, node: tn.IfScope) -> None:
+        """Apply the replacements to the condition, then visit the children."""
+        for old, new in self._axis_replacements.items():
+            node.condition.as_string = self._replace_word(
+                node.condition.as_string, old, new
+            )
+
+        for child in node.children:
+            self.visit(child)
+
+    def __str__(self) -> str:
+        return "ReplaceAxisSymbol"
diff --git a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py
index 2583ec2d..9f7e4be4 100644
--- a/ndsl/dsl/dace/stree/optimizations/specialize_maps.py
+++ b/ndsl/dsl/dace/stree/optimizations/specialize_maps.py
@@ -1,15 +1,16 @@
-import dace.sdfg.analysis.schedule_tree.treenodes as stree
import dace.subsets as sbs
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
-class SpecializeCartesianMaps(stree.ScheduleNodeVisitor):
+class SpecializeCartesianMaps(tn.ScheduleNodeVisitor):
def __init__(self, mappings: dict[str, int]) -> None:
super().__init__()
self._mappings = mappings
- def visit_MapScope(self, node: stree.MapScope) -> None:
+ def visit_MapScope(self, node: tn.MapScope) -> None:
dims = []
for p in node.node.map.params:
+ assert isinstance(p, str)
if p == "__i":
dims.append((0, self._mappings["__I"], 1))
if p == "__j":
@@ -19,3 +20,6 @@ def visit_MapScope(self, node: stree.MapScope) -> None:
node.node.map.range = sbs.Range(dims)
self.visit(node.children)
+
+ def __str__(self) -> str:
+ return "SpecializeCartesianMaps"
diff --git a/ndsl/dsl/dace/stree/optimizations/statistics.py b/ndsl/dsl/dace/stree/optimizations/statistics.py
new file mode 100644
index 00000000..54c0b09d
--- /dev/null
+++ b/ndsl/dsl/dace/stree/optimizations/statistics.py
@@ -0,0 +1,110 @@
+import dataclasses
+
+import dace
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
+
+from ndsl.dsl.dace.stree.optimizations.common import (
+ AxisIterator,
+ is_axis_for,
+ is_axis_map,
+)
+
+
+class CountCartesianLoops(tn.ScheduleNodeVisitor):
+    """Count cartesian-axis maps, cartesian-axis `for` loops and 3-deep map nests.
+
+    Results are kept on the instance:
+      - `_maps` / `_fors`: one counter per axis, indexed by
+        `AxisIterator.as_cartesian_index()`.
+      - `_3D_kernels`: number of maps whose first child is a map whose first
+        child is a map as well.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._maps = [0, 0, 0]
+        self._fors = [0, 0, 0]
+        self._3D_kernels = 0
+
+    def visit_MapScope(self, node: tn.MapScope) -> None:
+        """Count `node` per axis, detect a 3-deep map nest, then recurse."""
+        for axis in AxisIterator:
+            if is_axis_map(node, axis):
+                self._maps[axis.as_cartesian_index()] += 1
+
+        # Only the first child is inspected, two levels down. A scope without
+        # children is a valid node: guard the indexing instead of letting it
+        # raise `IndexError`.
+        first_child = node.children[0] if node.children else None
+        if isinstance(first_child, tn.MapScope) and first_child.children:
+            if isinstance(first_child.children[0], tn.MapScope):
+                self._3D_kernels += 1
+
+        self.visit(node.children)
+
+    def visit_ForScope(self, node: tn.ForScope) -> None:
+        """Count `node` per axis, then recurse into its children."""
+        for axis in AxisIterator:
+            if is_axis_for(node, axis):
+                self._fors[axis.as_cartesian_index()] += 1
+
+        self.visit(node.children)
+
+
+class CountTransient(tn.ScheduleNodeVisitor):
+    """Histogram transient arrays by their number of dimensions of size other than 1.
+
+    `_counts[n]` is the number of transient `dace.data.Array` containers with
+    `n` such dimensions for `n` in 0..3; `_counts[4]` gathers 4 and more.
+    Only the containers of the root are read, children are not visited.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._counts = [0, 0, 0, 0, 0]
+
+    def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot) -> None:
+        """Add every transient array of `node.containers` to the histogram."""
+        last_bucket = len(self._counts) - 1
+        for data in node.containers.values():
+            # Filter first: only transient arrays are counted, so the shape of
+            # any other kind of container is never looked at.
+            if not (isinstance(data, dace.data.Array) and data.transient):
+                continue
+            non_atomic_dims_count = sum(1 for x in data.shape if x != 1)
+            # Everything above the last bucket is clamped into it ("4D+")
+            self._counts[min(non_atomic_dims_count, last_bucket)] += 1
+
+
+class TreeOptimizationStatistics:
+    """Collect loop and transient counts of a schedule tree before and after optimization."""
+
+    @dataclasses.dataclass
+    class Record:
+        """Snapshot of the counters gathered on one tree."""
+
+        cartesian_maps: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0])
+        cartesian_fors: list[int] = dataclasses.field(default_factory=lambda: [0, 0, 0])
+        threeD_kernels: int = 0
+        transients: list[int] = dataclasses.field(
+            default_factory=lambda: [0, 0, 0, 0, 0]
+        )
+
+    def __init__(self) -> None:
+        self._original_record = TreeOptimizationStatistics.Record()
+        self._optimized_record = TreeOptimizationStatistics.Record()
+
+    def _record(
+        self,
+        record: Record,
+        tree_root: tn.ScheduleTreeRoot,
+    ) -> None:
+        """Fill `record` with the counters obtained by visiting `tree_root`."""
+        loop_counter = CountCartesianLoops()
+        loop_counter.visit(tree_root)
+        record.cartesian_fors = loop_counter._fors
+        record.cartesian_maps = loop_counter._maps
+        record.threeD_kernels = loop_counter._3D_kernels
+
+        transient_counter = CountTransient()
+        transient_counter.visit(tree_root)
+        record.transients = transient_counter._counts
+
+    def original(self, tree_root: tn.ScheduleTreeRoot) -> None:
+        """Snapshot the tree as it is before any optimization ran."""
+        self._record(self._original_record, tree_root)
+
+    def optimized(self, tree_root: tn.ScheduleTreeRoot) -> None:
+        """Snapshot the tree once the optimizations have run."""
+        self._record(self._optimized_record, tree_root)
+
+    def report(self) -> str:
+        """Build a concise multi-line string comparing both snapshots."""
+        before = self._original_record
+        after = self._optimized_record
+        lines = [
+            "Tree optimization:",
+            f" Cartesian maps [I, J, K]: {before.cartesian_maps} -> {after.cartesian_maps}",
+            f" Cartesian fors [I, J, K]: {before.cartesian_fors} -> {after.cartesian_fors}",
+            f" Transients [Scalarized Array, 1D, 2D, 3D, 4D+]: {before.transients} -> {after.transients}",
+            f" Full 3D kernels: {before.threeD_kernels} -> {after.threeD_kernels}",
+        ]
+        # Every line, the last one included, is terminated by a newline
+        return "".join(f"{line}\n" for line in lines)
diff --git a/ndsl/dsl/dace/stree/pipeline.py b/ndsl/dsl/dace/stree/pipeline.py
index f9bc452f..50f2d32e 100644
--- a/ndsl/dsl/dace/stree/pipeline.py
+++ b/ndsl/dsl/dace/stree/pipeline.py
@@ -1,16 +1,24 @@
from pathlib import Path
-import dace.sdfg.analysis.schedule_tree.treenodes as stree
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
-from ndsl import ndsl_log_on_rank_0
-from ndsl.dsl.dace.stree.optimizations import AxisIterator, CartesianAxisMerge
+from ndsl import Backend, OptimizationConfig, ndsl_log_on_rank_0
+from ndsl.dsl.dace.stree.optimizations import (
+ CartesianMerge,
+ CartesianRefineTransients,
+ CleanUpScheduleTree,
+ InlineVertical2DWrite,
+ KernelizeMaps,
+ LocalOptimizations,
+ TreeOptimizationStatistics,
+)
class StreePipeline:
def __init__(
self,
*,
- passes: list[stree.ScheduleNodeTransformer],
+ passes: list[tn.ScheduleNodeVisitor],
cache_directory: Path | None = None,
) -> None:
if cache_directory is None:
@@ -27,10 +35,14 @@ def __repr__(self) -> str:
def run(
self,
- stree: stree.ScheduleTreeRoot,
+ stree: tn.ScheduleTreeScope,
verbose: bool = False,
- ) -> stree.ScheduleTreeRoot:
+ ) -> tn.ScheduleTreeScope:
+ tree_stats = TreeOptimizationStatistics()
+ tree_stats.original(stree)
+
for i, p in enumerate(self.passes):
+ path: Path | None = None
if verbose:
path = self.cache_directory / f"pass{i}_{p}.txt"
ndsl_log_on_rank_0.info(f"[Stree OPT] {p} (saving {path} after)")
@@ -38,23 +50,42 @@ def run(
p.visit(stree)
if verbose:
+ assert path is not None
with open(path, "w+") as f:
f.write(stree.as_string())
+ tree_stats.optimized(stree)
+ ndsl_log_on_rank_0.info(tree_stats.report())
return stree
class CPUPipeline(StreePipeline):
def __init__(
self,
+ config: OptimizationConfig,
+ backend: Backend,
*,
- passes: list[stree.ScheduleNodeTransformer] | None = None,
+ passes: list[tn.ScheduleNodeVisitor] | None = None,
cache_directory: Path | None = None,
) -> None:
+ if passes is None:
+ ppl_passes = [CleanUpScheduleTree(), LocalOptimizations(backend)]
+ if config.stree.inline_K_loops_size_one:
+ ppl_passes.append(InlineVertical2DWrite())
+ if config.stree.merger.enabled:
+ ppl_passes.append(
+ CartesianMerge(
+ backend,
+ overcompute=config.stree.merger.overcompute,
+ merge_order=config.stree.merger.order,
+ )
+ )
+ if config.stree.refine_transients:
+ ppl_passes.append(CartesianRefineTransients())
+ else:
+ ppl_passes = passes
super().__init__(
- passes=(
- passes if passes is not None else [CartesianAxisMerge(AxisIterator._K)]
- ),
+ passes=ppl_passes,
cache_directory=cache_directory,
)
@@ -62,10 +93,33 @@ def __init__(
class GPUPipeline(StreePipeline):
def __init__(
self,
- passes: list[stree.ScheduleNodeTransformer] | None = None,
+ config: OptimizationConfig,
+ backend: Backend,
+ *,
+ passes: list[tn.ScheduleNodeVisitor] | None = None,
cache_directory: Path | None = None,
) -> None:
+ if passes is None:
+ ppl_passes = [CleanUpScheduleTree(), LocalOptimizations(backend)]
+ if config.stree.inline_K_loops_size_one:
+ ppl_passes.append(InlineVertical2DWrite())
+ if config.stree.merger.enabled:
+ ppl_passes.append(
+ CartesianMerge(backend, overcompute=config.stree.merger.overcompute)
+ )
+ if config.stree.kernelize:
+ ppl_passes.append(KernelizeMaps(backend))
+ if config.stree.refine_transients:
+ # TODO
+ # 🐞 Transient refine can't be used
+ # because of bugs transients showing in code generation
+ # ppl_passes.append(CartesianRefineTransients(backend))
+ raise ValueError(
+ "Transient refinement is currently unavailable in the GPU pipeline."
+ )
+ else:
+ ppl_passes = passes
super().__init__(
- passes=passes if passes is not None else [],
+ passes=ppl_passes,
cache_directory=cache_directory,
)
diff --git a/ndsl/dsl/ndsl_runtime.py b/ndsl/dsl/ndsl_runtime.py
index a994c61a..492e62dc 100644
--- a/ndsl/dsl/ndsl_runtime.py
+++ b/ndsl/dsl/ndsl_runtime.py
@@ -5,6 +5,7 @@
from collections.abc import Callable
from typing import Any, Sequence
+from ndsl import OptimizationConfig
from ndsl.dsl.dace.orchestration import orchestrate
from ndsl.dsl.stencil import StencilFactory
from ndsl.dsl.typing import Float
@@ -21,10 +22,15 @@ class NDSLRuntime:
The __call__ function will automatically be orchestrated."""
- def __init__(self, stencil_factory: StencilFactory) -> None:
+ def __init__(
+ self,
+ stencil_factory: StencilFactory,
+ optimization_config: OptimizationConfig | None = None,
+ ) -> None:
self._stencil_factory = stencil_factory
# Use this flag to detect that the init wasn't done properly
self._base_class_was_properly_super_init = True
+ self._optimization_config = optimization_config
def __init_subclass__(cls: type[NDSLRuntime], **kwargs: dict[str, Any]) -> None:
# WARNING: no code outside the `init_decorator` this is cls
@@ -75,6 +81,7 @@ def check_for_quantity(object_: object) -> None:
orchestrate(
obj=self,
config=self._stencil_factory.config.dace_config,
+ optimization_config=self._optimization_config,
)
def __getattribute__(self, name: str) -> Any:
diff --git a/ndsl/dsl/optimization_config.py b/ndsl/dsl/optimization_config.py
new file mode 100644
index 00000000..fdcd7166
--- /dev/null
+++ b/ndsl/dsl/optimization_config.py
@@ -0,0 +1,53 @@
+import os
+from dataclasses import dataclass, field
+
+
+@dataclass
+class OptimizationConfig:
+    """Switches for the optimizations, grouped per stage (`stree`, `gpu`).
+
+    Defaults that depend on environment variables (`NDSL_STREE_OPT`,
+    `NDSL_STREE_OVERCOMPUTE_MERGE`) are read once, when the class is defined.
+    """
+
+    @dataclass
+    class Tree:
+        """Optimization using the Schedule Tree IR"""
+
+        @dataclass
+        class Merger:
+            """Options of the cartesian axis merging."""
+
+            enabled: bool = True
+            """Enable cartesian axis merging."""
+
+            overcompute: bool = (
+                os.getenv("NDSL_STREE_OVERCOMPUTE_MERGE", "True").lower() == "true"
+            )
+            """When merging allow maps of different sizes to merge by inserting an `if` guard."""
+
+            order: str = "default"
+            """
+            Allows to manually override the merging order (e.g. `KJI` will merge `K`, then `J`, then `I`).
+            The default follows loop order of the backend given to `CartesianMerge`.
+            """
+
+        enabled: bool = os.getenv("NDSL_STREE_OPT", "False").lower() == "true"
+        """Enable Schedule Tree transformations."""
+
+        # TODO: Is it safe? Deactivate by default for now
+        inline_K_loops_size_one: bool = False
+        """Remove serial for loops of size one in the K-axis."""
+
+        kernelize: bool = True
+        """Enable maximizing 3-axis kernelization by duplicating maps (GPU only)."""
+
+        merger: Merger = field(default_factory=Merger)
+        """Configuration object for cartesian axis merging."""
+
+        refine_transients: bool = True
+        """Reduce dimensionality of transient arrays based on their usage."""
+
+    @dataclass
+    class GPU:
+        """Optimization dedicated for GPU"""
+
+        common_gpu_xforms: bool = False
+        """DaCe common xforms bundled in `apply_gpu_transformations`"""
+
+    stree: Tree = field(default_factory=Tree)
+    gpu: GPU = field(default_factory=GPU)
+
+    name: str = "unset"
diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py
index d5d37264..96d00946 100644
--- a/ndsl/dsl/stencil.py
+++ b/ndsl/dsl/stencil.py
@@ -7,8 +7,8 @@
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import Any, cast
-import dace
import numpy as np
+from dace.config import Config as DaceConfig
from gt4py.cartesian import config as gt_config
from gt4py.cartesian import definitions as gt_definitions
from gt4py.cartesian import gtscript
@@ -321,7 +321,7 @@ def __init__(
BackendFramework.DACE
== self.stencil_config.compilation_config.backend.framework
):
- dace.Config.set(
+ DaceConfig.set(
"default_build_folder",
value="{gt_root}/{gt_cache}/dacecache".format(
gt_root=gt_config.cache_settings["root_path"],
@@ -881,6 +881,8 @@ def _origin_from_dims(self, dims: Iterable[str]) -> list[int]:
return_origin.append(self.origin[1])
elif dim in K_DIMS:
return_origin.append(self.origin[2])
+ else:
+ raise ValueError(f"Unknown dimension '{dim}'.")
return return_origin
def _domain_from_dims(self, dimensions: Iterable[str]) -> list[int]:
@@ -888,16 +890,18 @@ def _domain_from_dims(self, dimensions: Iterable[str]) -> list[int]:
for dimension in dimensions:
if dimension == I_DIM:
result.append(self.domain[0])
- if dimension == I_INTERFACE_DIM:
+ elif dimension == I_INTERFACE_DIM:
result.append(self.domain[0] + 1)
- if dimension == J_DIM:
+ elif dimension == J_DIM:
result.append(self.domain[1])
- if dimension == J_INTERFACE_DIM:
+ elif dimension == J_INTERFACE_DIM:
result.append(self.domain[1] + 1)
- if dimension == K_DIM:
+ elif dimension == K_DIM:
result.append(self.domain[2])
- if dimension == K_INTERFACE_DIM:
+ elif dimension == K_INTERFACE_DIM:
result.append(self.domain[2] + 1)
+ else:
+ raise ValueError(f"Unknown dimension '{dimension}'.")
return result
def get_shape(
diff --git a/ndsl/initialization/subtile_grid_sizer.py b/ndsl/initialization/subtile_grid_sizer.py
index 4a257080..c923afee 100644
--- a/ndsl/initialization/subtile_grid_sizer.py
+++ b/ndsl/initialization/subtile_grid_sizer.py
@@ -17,11 +17,18 @@ def __init__(
n_halo: int,
data_dimensions: dict[str, int],
backend: Backend,
+ *,
+ pad_non_interface_dimensions: bool = False,
) -> None:
super().__init__(nx, ny, nz, n_halo, data_dimensions)
fortran_style_memory = backend.is_fortran_aligned()
- self._pad_non_interface_dimensions = not fortran_style_memory
+
+ # TODO: pad_non_interface_dimensions should not be kept. In general
+ # this should _always_ be False and non-interface dimensions never padded by default
+ self._pad_non_interface_dimensions = (
+ not fortran_style_memory or pad_non_interface_dimensions
+ )
@classmethod
def from_tile_params(
@@ -36,6 +43,7 @@ def from_tile_params(
data_dimensions: dict[str, int] | None = None,
tile_partitioner: TilePartitioner | None = None,
tile_rank: int = 0,
+ pad_non_interface_dimensions: bool = False,
) -> Self:
"""Create a SubtileGridSizer from parameters about the full tile.
@@ -76,7 +84,15 @@ def from_tile_params(
"SubtileGridSizer::from_tile_params: Compute domain extent must be greater than halo size"
)
- return cls(nx, ny, nz, n_halo, data_dimensions, backend)
+ return cls(
+ nx,
+ ny,
+ nz,
+ n_halo,
+ data_dimensions,
+ backend,
+ pad_non_interface_dimensions=pad_non_interface_dimensions,
+ )
@classmethod
def from_namelist(
diff --git a/ndsl/quantity/local.py b/ndsl/quantity/local.py
index f69480a7..37aee3eb 100644
--- a/ndsl/quantity/local.py
+++ b/ndsl/quantity/local.py
@@ -31,6 +31,7 @@ def __init__(
# Initialize memory to obviously wrong value - Local should _not_ be expected
# to be zero'ed.
data[:] = 123456789
+ self._on_gpu = backend.is_gpu_backend()
super().__init__(
data,
@@ -45,5 +46,5 @@ def __init__(
def __descriptor__(self) -> Any:
"""Locals uses `Quantity.__descriptor__` and flag itself as transient."""
data = dace.data.create_datadescriptor(self._data)
- data.transient = True
+ data.transient = True if not self._on_gpu else False
return data
diff --git a/ndsl/quantity/quantity.py b/ndsl/quantity/quantity.py
index 0624a8c0..d493fe5f 100644
--- a/ndsl/quantity/quantity.py
+++ b/ndsl/quantity/quantity.py
@@ -6,7 +6,6 @@
from typing import Any, cast
import dace
-import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from gt4py import storage as gt_storage
@@ -287,7 +286,7 @@ def field(self) -> np.ndarray | cupy.ndarray:
def data(self) -> np.ndarray | cupy.ndarray:
"""The underlying array of data"""
warnings.warn(
- "Quantity.data accessor is now deprecated. Use a slicing operation directly on"
+ "Quantity.data accessor is now deprecated. Use a slicing operation directly on "
"the quantity, e.g. `my_quantity[:]` instead of `my_quantity.data[:]`",
category=UserWarning,
stacklevel=2,
@@ -459,6 +458,8 @@ def transpose(
return transposed
def plot_k_level(self, k_index: int = 0) -> None:
+ import matplotlib.pyplot as plt
+
field = self._data
plt.xlabel("I")
plt.ylabel("J")
diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py
index 6e5b17af..652e132a 100644
--- a/ndsl/stencils/testing/conftest.py
+++ b/ndsl/stencils/testing/conftest.py
@@ -105,6 +105,12 @@ def pytest_addoption(parser: pytest.Parser) -> None:
default=False,
help="Do not generate logging report or NetCDF in .translate-errors",
)
+ parser.addoption(
+ "--pad_non_interface_dimensions",
+ action="store_true",
+ default=False,
+ help="Pad the non interface dimensions in all backends. Default to False.",
+ )
def pytest_configure(config: pytest.Config) -> None:
@@ -255,6 +261,9 @@ def _sequential_savepoint_cases(
topology_mode = metafunc.config.getoption("topology")
sort_report = metafunc.config.getoption("sort_report")
no_report = metafunc.config.getoption("no_report")
+ pad_non_interface_dimensions = metafunc.config.getoption(
+ "pad_non_interface_dimensions"
+ )
return _savepoint_cases(
savepoint_names,
@@ -268,6 +277,7 @@ def _sequential_savepoint_cases(
topology_mode,
sort_report=sort_report,
no_report=no_report,
+ pad_non_interface_dimensions=pad_non_interface_dimensions,
)
@@ -283,6 +293,7 @@ def _savepoint_cases(
topology_mode: str,
sort_report: str,
no_report: bool,
+ pad_non_interface_dimensions: bool,
) -> list[SavepointCase]:
grid_params = grid_params_from_f90nml(namelist)
return_list = []
@@ -305,6 +316,7 @@ def _savepoint_cases(
rank=rank,
layout=grid_params["layout"],
backend=backend,
+ pad_non_interface_dimensions=pad_non_interface_dimensions,
).python_grid()
if grid_mode == "compute":
_compute_grid_data(
@@ -377,6 +389,9 @@ def _parallel_savepoint_cases(
savepoint_names = _parallel_savepoint_names(metafunc, data_path)
grid_mode = metafunc.config.getoption("grid")
savepoint_to_replay = _get_savepoint_restriction(metafunc)
+ pad_non_interface_dimensions = metafunc.config.getoption(
+ "pad_non_interface_dimensions"
+ )
return _savepoint_cases(
savepoint_names,
@@ -390,6 +405,7 @@ def _parallel_savepoint_cases(
topology_mode,
sort_report=sort_report,
no_report=no_report,
+ pad_non_interface_dimensions=pad_non_interface_dimensions,
)
diff --git a/ndsl/stencils/testing/grid.py b/ndsl/stencils/testing/grid.py
index 3af290e4..db24fd13 100644
--- a/ndsl/stencils/testing/grid.py
+++ b/ndsl/stencils/testing/grid.py
@@ -60,6 +60,7 @@ def _make(
layout: tuple[int, int],
rank: int,
backend: Backend,
+ pad_non_interface_dimensions: bool = False,
) -> "Grid":
shape_params = {
"npx": npx,
@@ -81,7 +82,15 @@ def _make(
"js": N_HALO_DEFAULT,
"je": ny + N_HALO_DEFAULT - 1,
}
- return cls(indices, shape_params, rank, layout, backend, local_indices=True)
+ return cls(
+ indices,
+ shape_params,
+ rank,
+ layout,
+ backend,
+ local_indices=True,
+ pad_non_interface_dimensions=pad_non_interface_dimensions,
+ )
@classmethod
def from_namelist(cls, namelist: Namelist, rank: int, backend: Backend) -> "Grid":
@@ -112,6 +121,7 @@ def __init__(
backend: Backend,
data_fields: dict | None = None,
local_indices: bool = False,
+ pad_non_interface_dimensions: bool = False,
) -> None:
if data_fields is None:
data_fields = {}
@@ -162,6 +172,7 @@ def __init__(
self._grid_data: GridData | None = None
self._driver_grid_data: DriverGridData | None = None
self._damping_coefficients: DampingCoefficients | None = None
+ self._pad_non_interface_dimensions = pad_non_interface_dimensions
@property
def sizer(self) -> GridSizer:
@@ -180,6 +191,7 @@ def sizer(self) -> GridSizer:
},
layout=self.layout,
backend=self.backend,
+ pad_non_interface_dimensions=self._pad_non_interface_dimensions,
)
return self._sizer
diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py
index 2bc0d4fc..66db8063 100644
--- a/ndsl/stencils/testing/test_translate.py
+++ b/ndsl/stencils/testing/test_translate.py
@@ -466,7 +466,7 @@ def _report_results(
os.makedirs(detail_dir, exist_ok=True)
# Summary
- header = f"{savepoint_name} w/ f{backend.as_humanly_readable()}"
+ header = f"{savepoint_name} w/ {backend.as_humanly_readable()}"
lines = []
for varname, metric in results.items():
lines.append(f"{varname}: {metric.one_line_report()}")
diff --git a/ndsl/stencils/testing/translate.py b/ndsl/stencils/testing/translate.py
index 29afc577..4011a1c0 100644
--- a/ndsl/stencils/testing/translate.py
+++ b/ndsl/stencils/testing/translate.py
@@ -68,10 +68,7 @@ def __init__(
self.ordered_input_vars = None
self.ignore_near_zero_errors: dict[str, Any] = {}
self.skip_test = skip_test
- if self.stencil_factory.backend.is_fortran_aligned():
- self.maxshape = self.grid.domain_shape_full()
- else:
- self.maxshape = self.grid.domain_shape_full(add=(1, 1, 1))
+ self.maxshape = self.grid.domain_shape_full(add=(1, 1, 1))
def extra_data_load(self, data_loader: DataLoader):
pass
@@ -322,7 +319,15 @@ def new_from_serialized_data(cls, serializer, rank, layout, backend: Backend):
grid_data[field] = read_serialized_data(serializer, grid_savepoint, field)
return cls(grid_data, rank, layout, backend=backend)
- def __init__(self, inputs, rank, layout, *, backend: Backend):
+ def __init__(
+ self,
+ inputs,
+ rank,
+ layout,
+ *,
+ backend: Backend,
+ pad_non_interface_dimensions: bool = False,
+ ):
self.backend = backend
self.indices = {}
self.shape_params = {}
@@ -338,6 +343,7 @@ def __init__(self, inputs, rank, layout, *, backend: Backend):
del inputs[index]
self.data = inputs
+ self._pad_non_interface_dimensions = pad_non_interface_dimensions
def _make_composite_var_storage(self, varname, data3d, shape, count):
for s in range(count):
@@ -444,7 +450,12 @@ def make_grid_storage(self, pygrid):
def python_grid(self):
pygrid = Grid(
- self.indices, self.shape_params, self.rank, self.layout, self.backend
+ self.indices,
+ self.shape_params,
+ self.rank,
+ self.layout,
+ self.backend,
+ pad_non_interface_dimensions=self._pad_non_interface_dimensions,
)
self.make_grid_storage(pygrid)
pygrid.add_data(self.data)
diff --git a/ndsl/testing/comparison.py b/ndsl/testing/comparison.py
index 3acfd723..e7fc93eb 100644
--- a/ndsl/testing/comparison.py
+++ b/ndsl/testing/comparison.py
@@ -339,7 +339,7 @@ def one_line_report(self) -> str:
return f"❌ Numerical failures: {failed_indices}/{all_indices} failed - metric: {metric_thresholds}"
def report(self, file_path: str | None = None) -> list[str]:
- failed_indices = np.logical_not(self.success).nonzero()
+ failed_indices = np.atleast_1d(np.logical_not(self.success)).nonzero()
# List all errors to terminal and file
bad_indices_count = len(failed_indices[0])
if self.changing_column_map is not None:
diff --git a/tests/dsl/dace/stree/__init__.py b/tests/dsl/dace/stree/__init__.py
index 2fa38d13..b43c1d92 100644
--- a/tests/dsl/dace/stree/__init__.py
+++ b/tests/dsl/dace/stree/__init__.py
@@ -1,7 +1,7 @@
-from .sdfg_stree_tools import StreeOptimization, get_SDFG_and_purge
+from .sdfg_stree_tools import StreePipeline, get_SDFG_and_purge
__all__ = [
- "StreeOptimization",
+ "StreePipeline",
"get_SDFG_and_purge",
]
diff --git a/tests/dsl/dace/stree/common/__init__.py b/tests/dsl/dace/stree/common/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/dsl/dace/stree/common/test_loops.py b/tests/dsl/dace/stree/common/test_loops.py
new file mode 100644
index 00000000..a2c12d76
--- /dev/null
+++ b/tests/dsl/dace/stree/common/test_loops.py
@@ -0,0 +1,96 @@
+from dace.sdfg import nodes
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
+from dace.sdfg.state import LoopRegion
+
+from ndsl.dsl.dace.stree.optimizations.common import (
+ AxisIterator,
+ is_axis_for,
+ is_axis_map,
+ is_cartesian_axis,
+)
+
+
+def test_is_axis_map_multiple_params() -> None:
+    """A map with two parameters is not an axis map of either axis."""
+    map_ij = nodes.Map("map_ij", ["__i", "__j"], [(0, 3, 1), (0, 4, 1)])
+    node = tn.MapScope(node=nodes.MapEntry(map_ij), children=[])
+
+    assert not is_axis_map(node, AxisIterator._I)
+    assert not is_axis_map(node, AxisIterator._J)
+
+
+def test_is_axis_map_I() -> None:
+    """A map over `__i` alone is an I-axis map."""
+    entry = nodes.MapEntry(nodes.Map("map_i", ["__i"], [(0, 3, 1)]))
+    node = tn.MapScope(node=entry, children=[])
+
+    assert is_axis_map(node, AxisIterator._I)
+
+
+def test_is_axis_map_not_I() -> None:
+    """A map over `__i0` is not an I-axis map."""
+    entry = nodes.MapEntry(nodes.Map("map_other_i", ["__i0"], [(0, 3, 1)]))
+    node = tn.MapScope(node=entry, children=[])
+
+    assert not is_axis_map(node, AxisIterator._I)
+
+
+def test_is_axis_map_K() -> None:
+    """A map over a suffixed K iterator (`__k_1234`) is a K-axis map."""
+    entry = nodes.MapEntry(nodes.Map("map_k", ["__k_1234"], [(0, 3, 1)]))
+    node = tn.MapScope(node=entry, children=[])
+
+    assert is_axis_map(node, AxisIterator._K)
+
+
+def test_is_axis_map_wrong_iterator() -> None:
+    """A map over `__i` is not a J-axis map."""
+    entry = nodes.MapEntry(nodes.Map("map_i", ["__i"], [(0, 3, 1)]))
+    node = tn.MapScope(node=entry, children=[])
+
+    assert not is_axis_map(node, AxisIterator._J)
+
+
+def test_is_cartesian_axis() -> None:
+    """I, J, K maps and a K `for` loop are cartesian axes; an `__i0` map is not."""
+    cartesian_maps = (("map_i", "__i"), ("map_j", "__j"), ("map_k", "__k_1234"))
+    for label, param in cartesian_maps:
+        entry = nodes.MapEntry(nodes.Map(label, [param], [(0, 3, 1)]))
+        assert is_cartesian_axis(tn.MapScope(node=entry, children=[]))
+
+    for_k = tn.ForScope(loop=LoopRegion("for_k", loop_var="__k_1234"), children=[])
+    assert is_cartesian_axis(for_k)
+
+    other_entry = nodes.MapEntry(nodes.Map("map_other_i", ["__i0"], [(0, 3, 1)]))
+    map_non_cartesian = tn.MapScope(node=other_entry, children=[])
+    assert not is_cartesian_axis(map_non_cartesian)
+
+
+def test_is_axis_for_k() -> None:
+    """A loop over a suffixed K iterator (`__k_1234`) is a K-axis `for`."""
+    loop = LoopRegion("for_k", loop_var="__k_1234")
+    node = tn.ForScope(loop=loop, children=[])
+
+    assert is_axis_for(node, AxisIterator._K)
+
+
+def test_is_axis_for_wrong_iterator() -> None:
+    """A loop over `__k_1234` is not an I-axis `for`."""
+    loop = LoopRegion("for_k", loop_var="__k_1234")
+    node = tn.ForScope(loop=loop, children=[])
+
+    assert not is_axis_for(node, AxisIterator._I)
+
+
+def test_is_axis_for_i() -> None:
+    """A loop over `__i` is an I-axis `for`."""
+    loop = LoopRegion("for_i", loop_var="__i")
+    node = tn.ForScope(loop=loop, children=[])
+
+    assert is_axis_for(node, AxisIterator._I)
+
+
+def test_is_axis_for_not_i() -> None:
+    """A loop over `__i0` is not an I-axis `for`."""
+    loop = LoopRegion("for_i", loop_var="__i0")
+    node = tn.ForScope(loop=loop, children=[])
+
+    assert not is_axis_for(node, AxisIterator._I)
diff --git a/tests/dsl/dace/stree/common/test_memlet.py b/tests/dsl/dace/stree/common/test_memlet.py
new file mode 100644
index 00000000..44fe15e1
--- /dev/null
+++ b/tests/dsl/dace/stree/common/test_memlet.py
@@ -0,0 +1,32 @@
+from dace.symbolic import symbol
+
+from ndsl.dsl.dace.stree.optimizations.common import AxisIterator
+from ndsl.dsl.dace.stree.optimizations.common.memlet import (
+ normalize_cartesian_indexation,
+)
+
+
+def test_normalize_cartesian_index():
+    """Check K-axis normalization on suffixed, offset, no-op and look-alike symbols."""
+    k = symbol("__k")
+    cases = [
+        # Case of __k_id(node) - original case
+        (symbol("__k_12345678789"), k),
+        # Case of offset
+        (1 + symbol("__k_12345678789"), k + 1),
+        # Case of no-op (with offset)
+        (1 + k, k + 1),
+        # Case of index named with _k - so not a cartesian axis
+        (1 + symbol("_kindex"), symbol("_kindex") + 1),
+    ]
+
+    for original_symbol, expected in cases:
+        norm_symbol = normalize_cartesian_indexation(original_symbol, AxisIterator._K)
+        assert norm_symbol == expected
diff --git a/tests/dsl/dace/stree/optimizations/__init__.py b/tests/dsl/dace/stree/optimizations/__init__.py
index e69de29b..e0e56d60 100644
--- a/tests/dsl/dace/stree/optimizations/__init__.py
+++ b/tests/dsl/dace/stree/optimizations/__init__.py
@@ -0,0 +1,6 @@
+from typing import TypeAlias
+
+from ndsl import QuantityFactory, StencilFactory
+
+
+Factories: TypeAlias = tuple[StencilFactory, QuantityFactory]
diff --git a/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py
new file mode 100644
index 00000000..5ddb9764
--- /dev/null
+++ b/tests/dsl/dace/stree/optimizations/test_kernelize_maps.py
@@ -0,0 +1,187 @@
+import pytest
+from dace import nodes
+from dace.sdfg.state import LoopRegion
+
+from ndsl import Backend, NDSLRuntime, OptimizationConfig, orchestrate
+from ndsl.boilerplate import get_factories_single_tile
+from ndsl.constants import I_DIM, J_DIM, K_DIM
+from ndsl.dsl.gt4py import BACKWARD, FORWARD, PARALLEL, computation, interval
+from ndsl.dsl.stencil import StencilFactory
+from ndsl.dsl.typing import FloatField
+from tests.dsl.dace.stree import get_SDFG_and_purge
+from tests.dsl.dace.stree.optimizations import Factories
+
+
+def stencil_kernelize(in_field: FloatField, out_field: FloatField) -> None: # type: ignore
+ with computation(PARALLEL), interval(...):
+ value = in_field * 2
+ tmp = value
+
+ with computation(FORWARD), interval(0, -1):
+ tmp = 0.5 * (tmp + tmp[0, 0, 1])
+
+ with computation(PARALLEL), interval(...):
+ out_field = tmp
+
+
+def stencil_only_serial_noop(
+ in_field: FloatField, out_field: FloatField
+) -> None: # type:ignore
+ with computation(FORWARD), interval(...):
+ tmp = in_field
+
+ with computation(BACKWARD), interval(...):
+ out_field = tmp
+
+
+def stencil_only_parallel_noop(
+ in_field: FloatField, out_field: FloatField
+) -> None: # type:ignore
+ with computation(PARALLEL), interval(0, 2):
+ out_field = in_field
+
+ with computation(PARALLEL), interval(-2, None):
+ out_field = in_field + 1
+
+
+class OrchestratedCode(NDSLRuntime):
+ def __init__(self, stencil_factory: StencilFactory) -> None:
+ optimization_config = OptimizationConfig(
+ stree=OptimizationConfig.Tree(
+ enabled=True,
+ merger=OptimizationConfig.Tree.Merger(enabled=True),
+ )
+ )
+ super().__init__(stencil_factory, optimization_config)
+
+ methods_to_orchestrate = [
+ "kernelize_k",
+ "only_serial_noop",
+ "only_parallel_noop",
+ ]
+ for method in methods_to_orchestrate:
+ orchestrate(
+ obj=self,
+ config=stencil_factory.config.dace_config,
+ method_to_orchestrate=method,
+ optimization_config=optimization_config,
+ )
+
+ self._stencil_kernelize_k = stencil_factory.from_dims_halo(
+ func=stencil_kernelize,
+ compute_dims=(I_DIM, J_DIM, K_DIM),
+ )
+ self._stencil_only_serial_noop = stencil_factory.from_dims_halo(
+ func=stencil_only_serial_noop,
+ compute_dims=(I_DIM, J_DIM, K_DIM),
+ )
+ self._stencil_only_parallel_noop = stencil_factory.from_dims_halo(
+ func=stencil_only_parallel_noop,
+ compute_dims=(I_DIM, J_DIM, K_DIM),
+ )
+
+ def kernelize_k(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore
+ self._stencil_kernelize_k(in_field, out_field)
+
+ def only_serial_noop(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore
+ self._stencil_only_serial_noop(in_field, out_field)
+
+ def only_parallel_noop(self, in_field: FloatField, out_field: FloatField) -> None: # type: ignore
+ self._stencil_only_parallel_noop(in_field, out_field)
+
+
+class TestKernelizeMaps:
+ @pytest.fixture(
+ params=[
+ "orch:dace:cpu:IJK",
+ pytest.param("orch:dace:gpu:IJK", marks=pytest.mark.gpu),
+ ]
+ )
+ def factories(self, request: pytest.FixtureRequest) -> Factories:
+ domain = (3, 4, 5)
+ return get_factories_single_tile(
+ nx=domain[0],
+ ny=domain[1],
+ nz=domain[2],
+ nhalo=0,
+ backend=Backend(request.param),
+ )
+
+ def test_kernelize_k_gpu(self, factories: Factories) -> None:
+ stencil_factory, quantity_factory = factories
+ code = OrchestratedCode(stencil_factory)
+
+ in_field = quantity_factory.ones((I_DIM, J_DIM, K_DIM), "")
+ out_field = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), "")
+
+ code.kernelize_k(in_field, out_field)
+
+ precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
+
+ if stencil_factory.backend.is_gpu_backend():
+ # check for kernelization
+ all_maps = [
+ node
+ for node, _ in precompiled_sdfg.sdfg.all_nodes_recursive()
+ if isinstance(node, nodes.MapEntry)
+ ]
+
+ ij_maps = 0
+ ijk_maps = 0
+ for map_entry in all_maps:
+ if map_entry.map.params == ["__i", "__j"]:
+ ij_maps += 1
+ elif len(map_entry.map.params) == 3:
+ params = map_entry.map.params
+ k_param = params[2]
+ if (
+ params[0:2] == ["__i", "__j"]
+ and isinstance(k_param, str)
+ and k_param.startswith("__k")
+ ):
+ ijk_maps += 1
+
+ # expect two IJK-maps and one IJ-map
+ assert ij_maps == 1
+ assert ijk_maps == 2
+ assert len(all_maps) == 3
+
+ all_loop_regions = [
+ node
+ for node, _ in precompiled_sdfg.sdfg.all_nodes_recursive()
+ if isinstance(node, LoopRegion)
+ ]
+ # expect one k-loop is preserved
+ assert len(all_loop_regions) == 1
+ assert all_loop_regions[0].loop_variable.startswith("__k")
+ else:
+ # check that we keep IJ loops merged
+ all_maps = [
+ node
+ for node, _ in precompiled_sdfg.sdfg.all_nodes_recursive()
+ if isinstance(node, nodes.MapEntry)
+ ]
+
+ ij_maps = 0
+ k_maps = 0
+ for map_entry in all_maps:
+ if map_entry.map.params == ["__i", "__j"]:
+ ij_maps += 1
+ elif len(map_entry.map.params) == 1:
+ param = map_entry.map.params[0]
+ if isinstance(param, str) and param.startswith("__k"):
+ k_maps += 1
+
+ # expect one IJ-map and two K-maps
+ assert ij_maps == 1
+ assert k_maps == 2
+ assert len(all_maps) == 3
+
+ all_loop_regions = [
+ node
+ for node, _ in precompiled_sdfg.sdfg.all_nodes_recursive()
+ if isinstance(node, LoopRegion)
+ ]
+ # expect one k-loop is preserved
+ assert len(all_loop_regions) == 1
+ assert all_loop_regions[0].loop_variable.startswith("__k")
diff --git a/tests/dsl/dace/stree/optimizations/test_merge.py b/tests/dsl/dace/stree/optimizations/test_merge.py
index d57e758a..8dd82624 100644
--- a/tests/dsl/dace/stree/optimizations/test_merge.py
+++ b/tests/dsl/dace/stree/optimizations/test_merge.py
@@ -1,18 +1,17 @@
-from typing import TypeAlias
-
import dace
import pytest
from dace import nodes
from dace.sdfg.analysis.schedule_tree import treenodes as tn
from dace.sdfg.state import LoopRegion
-from ndsl import QuantityFactory, StencilFactory, orchestrate
+from ndsl import OptimizationConfig, QuantityFactory, StencilFactory, orchestrate
from ndsl.boilerplate import get_factories_single_tile_orchestrated
from ndsl.config import Backend
from ndsl.constants import I_DIM, J_DIM, K_DIM
from ndsl.dsl.gt4py import FORWARD, PARALLEL, K, computation, interval
from ndsl.dsl.typing import FloatField
-from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge
+from tests.dsl.dace.stree import get_SDFG_and_purge
+from tests.dsl.dace.stree.optimizations import Factories
def stencil(in_field: FloatField, out_field: FloatField) -> None:
@@ -54,6 +53,12 @@ def __init__(
stencil_factory: StencilFactory,
quantity_factory: QuantityFactory,
) -> None:
+ config = OptimizationConfig(
+ stree=OptimizationConfig.Tree(
+ enabled=True,
+ merger=OptimizationConfig.Tree.Merger(enabled=True),
+ )
+ )
orchestratable_methods = [
"trivial_merge",
"missing_merge_of_forscope_and_map",
@@ -66,7 +71,21 @@ def __init__(
obj=self,
config=stencil_factory.config.dace_config,
method_to_orchestrate=method,
+ optimization_config=config,
)
+ orchestrate(
+ obj=self,
+ config=stencil_factory.config.dace_config,
+ method_to_orchestrate="no_overcompute_merge",
+ optimization_config=OptimizationConfig(
+ stree=OptimizationConfig.Tree(
+ enabled=True,
+ merger=OptimizationConfig.Tree.Merger(
+ enabled=True, overcompute=False
+ ),
+ )
+ ),
+ )
self.stencil = stencil_factory.from_dims_halo(
func=stencil,
@@ -120,6 +139,14 @@ def overcompute_merge(
self.stencil(in_field, out_field)
self.stencil_with_different_intervals(in_field, out_field)
+ def no_overcompute_merge(
+ self,
+ in_field: FloatField,
+ out_field: FloatField,
+ ) -> None:
+ self.stencil(in_field, out_field)
+ self.stencil_with_different_intervals(in_field, out_field)
+
def push_non_cartesian_for(
self,
in_field: FloatField,
@@ -130,9 +157,6 @@ def push_non_cartesian_for(
self.stencil(in_field, out_field)
-Factories: TypeAlias = tuple[StencilFactory, QuantityFactory]
-
-
class TestStreeMergeMapsIJK:
@pytest.fixture
def factories(self) -> Factories:
@@ -150,8 +174,7 @@ def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> No
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
- with StreeOptimization():
- code.trivial_merge(in_qty, out_qty)
+ code.trivial_merge(in_qty, out_qty)
precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
all_maps = [
@@ -160,7 +183,7 @@ def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> No
if isinstance(me, nodes.MapEntry)
]
- assert len(all_maps) == 3
+ assert len(all_maps) == 1 # all merged and collapsed
assert (out_qty.field[:] == 2).all()
def test_missing_merge_of_forscope_and_map(
@@ -170,8 +193,7 @@ def test_missing_merge_of_forscope_and_map(
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
- with StreeOptimization():
- code.missing_merge_of_forscope_and_map(in_qty, out_qty)
+ code.missing_merge_of_forscope_and_map(in_qty, out_qty)
sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
@@ -179,7 +201,7 @@ def test_missing_merge_of_forscope_and_map(
for map_entry, _ in sdfg.all_nodes_recursive()
if isinstance(map_entry, nodes.MapEntry)
]
- assert len(all_maps) == 4 # 2 IJ + 2 Ks
+ assert len(all_maps) == 3 # 1 IJ + 2 Ks
all_loops = [
loop
for loop, _ in sdfg.all_nodes_recursive()
@@ -194,8 +216,7 @@ def test_overcompute_merge(
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
- with StreeOptimization():
- code.overcompute_merge(in_qty, out_qty)
+ code.overcompute_merge(in_qty, out_qty)
sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
@@ -203,7 +224,34 @@ def test_overcompute_merge(
for me, state in sdfg.all_nodes_recursive()
if isinstance(me, nodes.MapEntry)
]
- assert len(all_maps) == 3 # All maps merged
+ assert len(all_maps) == 1 # All maps merged and collapsed
+
+ def test_no_overcompute_merge(
+ self, code: OrchestratedCode, factories: Factories
+ ) -> None:
+ stencil_factory, quantity_factory = factories
+ in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
+ out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
+
+ code.no_overcompute_merge(in_qty, out_qty)
+
+ sdfg = get_SDFG_and_purge(stencil_factory).sdfg
+
+ all_maps = [
+ me for me, _ in sdfg.all_nodes_recursive() if isinstance(me, nodes.MapEntry)
+ ]
+ k_maps = 0
+ ij_maps = 0
+ for map_entry in all_maps:
+ if len(map_entry.map.params) == 1 and map_entry.map.params[0].startswith(
+ "__k"
+ ):
+ k_maps += 1
+ if map_entry.map.params == ["__i", "__j"]:
+ ij_maps += 1
+
+ assert ij_maps == 1
+ assert k_maps == 2
def test_block_merge_when_dependencies_are_found(
self, code: OrchestratedCode, factories: Factories
@@ -212,9 +260,8 @@ def test_block_merge_when_dependencies_are_found(
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
- with StreeOptimization():
- # Forbid merging when data dependencies are detected
- code.block_merge_when_dependencies_are_found(in_qty, out_qty)
+ # Forbid merging when data dependencies are detected
+ code.block_merge_when_dependencies_are_found(in_qty, out_qty)
sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
@@ -222,7 +269,7 @@ def test_block_merge_when_dependencies_are_found(
for me, state in sdfg.all_nodes_recursive()
if isinstance(me, nodes.MapEntry)
]
- assert len(all_maps) == 4 # 2 IJ + 2 Ks (un-merged)
+ assert len(all_maps) == 3 # 1 IJ + 2 Ks (un-merged)
def test_push_non_cartesian_for(
self, code: OrchestratedCode, factories: Factories
@@ -231,10 +278,9 @@ def test_push_non_cartesian_for(
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
- with StreeOptimization():
- # Push non-cartesian ForScope inwards, which allow to potentially
- # merge cartesian maps
- code.push_non_cartesian_for(in_qty, out_qty)
+        # Push non-cartesian ForScope inwards, which potentially allows
+        # merging cartesian maps
+ code.push_non_cartesian_for(in_qty, out_qty)
sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
@@ -242,7 +288,7 @@ def test_push_non_cartesian_for(
for me, state in sdfg.all_nodes_recursive()
if isinstance(me, nodes.MapEntry)
]
- assert len(all_maps) == 3 # All merged
+ assert len(all_maps) == 1 # All merged & collapsed
for_loops = [
node
for node, _ in sdfg.all_nodes_recursive()
@@ -268,8 +314,7 @@ def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> No
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
- with StreeOptimization():
- code.trivial_merge(in_qty, out_qty)
+ code.trivial_merge(in_qty, out_qty)
precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
all_maps = [
@@ -278,7 +323,7 @@ def test_trivial_merge(self, code: OrchestratedCode, factories: Factories) -> No
if isinstance(me, nodes.MapEntry)
]
- assert len(all_maps) == 3
+ assert len(all_maps) == 1 # all maps merged and collapsed
assert (out_qty.field[:] == 2).all()
def test_missing_merge_of_forscope_and_map(
@@ -288,9 +333,8 @@ def test_missing_merge_of_forscope_and_map(
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
- with StreeOptimization():
- # K iterative loop - blocks all merges
- code.missing_merge_of_forscope_and_map(in_qty, out_qty)
+ # K iterative loop - blocks all merges
+ code.missing_merge_of_forscope_and_map(in_qty, out_qty)
sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
@@ -298,7 +342,7 @@ def test_missing_merge_of_forscope_and_map(
for map_entry, _ in sdfg.all_nodes_recursive()
if isinstance(map_entry, nodes.MapEntry)
]
- assert len(all_maps) == 8 # 2 KJI (all maps) + 1 for scope
+ assert len(all_maps) == 3 # 2 KJI (all maps) + 1 JI
all_loops = [
loop
for loop, _ in sdfg.all_nodes_recursive()
@@ -313,9 +357,8 @@ def test_overcompute_merge(
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
- with StreeOptimization():
- # Overcompute merge in K - we merge and introduce an If guard
- code.overcompute_merge(in_qty, out_qty)
+ # Overcompute merge in K - we merge and introduce an If guard
+ code.overcompute_merge(in_qty, out_qty)
sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
@@ -323,7 +366,7 @@ def test_overcompute_merge(
for me, state in sdfg.all_nodes_recursive()
if isinstance(me, nodes.MapEntry)
]
- assert len(all_maps) == 3 # All maps merged
+ assert len(all_maps) == 1 # All maps merged & collapsed
def test_block_merge_when_dependencies_are_found(
self, code: OrchestratedCode, factories: Factories
@@ -332,9 +375,8 @@ def test_block_merge_when_dependencies_are_found(
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
- with StreeOptimization():
- # Forbid merging when data dependencies are detected
- code.block_merge_when_dependencies_are_found(in_qty, out_qty)
+ # Forbid merging when data dependencies are detected
+ code.block_merge_when_dependencies_are_found(in_qty, out_qty)
sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
@@ -342,7 +384,7 @@ def test_block_merge_when_dependencies_are_found(
for me, state in sdfg.all_nodes_recursive()
if isinstance(me, nodes.MapEntry)
]
- assert len(all_maps) == 6 # 2 * KJI
+ assert len(all_maps) == 2 # 2 * KJI
def test_push_non_cartesian_for(
self, code: OrchestratedCode, factories: Factories
@@ -351,10 +393,9 @@ def test_push_non_cartesian_for(
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
- with StreeOptimization():
- # Push non-cartesian ForScope inwards, which allow to potentially
- # merge cartesian maps
- code.push_non_cartesian_for(in_qty, out_qty)
+        # Push non-cartesian ForScope inwards, which potentially allows
+        # merging cartesian maps
+ code.push_non_cartesian_for(in_qty, out_qty)
sdfg = get_SDFG_and_purge(stencil_factory).sdfg
all_maps = [
@@ -362,7 +403,7 @@ def test_push_non_cartesian_for(
for me, state in sdfg.all_nodes_recursive()
if isinstance(me, nodes.MapEntry)
]
- assert len(all_maps) == 3 # All merged
+ assert len(all_maps) == 1 # All merged and collapsed
for_loops = [
node
for node, _ in sdfg.all_nodes_recursive()
diff --git a/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py
new file mode 100644
index 00000000..6232903a
--- /dev/null
+++ b/tests/dsl/dace/stree/optimizations/test_offgrid_conditionals.py
@@ -0,0 +1,158 @@
+import pytest
+from dace import nodes
+
+from ndsl import (
+ Backend,
+ NDSLRuntime,
+ OptimizationConfig,
+ StencilFactory,
+ orchestrate,
+ stencils,
+)
+from ndsl.boilerplate import get_factories_single_tile
+from ndsl.constants import I_DIM, J_DIM, K_DIM
+from ndsl.dsl.typing import FloatField
+from tests.dsl.dace.stree import get_SDFG_and_purge
+from tests.dsl.dace.stree.optimizations import Factories
+
+
+class OrchestratedCode(NDSLRuntime):
+ def __init__(self, stencil_factory: StencilFactory) -> None:
+ config = OptimizationConfig(
+ stree=OptimizationConfig.Tree(
+ enabled=True,
+ merger=OptimizationConfig.Tree.Merger(enabled=True),
+ )
+ )
+ super().__init__(stencil_factory, config)
+
+ methods_to_orchestrate = [
+ "happy_case",
+ "happy_case_2",
+ "blocked_by_else",
+ "blocked_by_other_nodes",
+ ]
+
+ for method in methods_to_orchestrate:
+ orchestrate(
+ obj=self,
+ config=stencil_factory.config.dace_config,
+ method_to_orchestrate=method,
+ optimization_config=config,
+ )
+
+ self._copy_stencil = stencil_factory.from_dims_halo(
+ func=stencils.copy, compute_dims=[I_DIM, J_DIM, K_DIM]
+ )
+
+ def happy_case(self, in_field: FloatField, out_field: FloatField) -> None:
+ if in_field[0, 0, 0] > 0:
+ self._copy_stencil(in_field, out_field)
+ self._copy_stencil(in_field, out_field)
+
+ def happy_case_2(self, in_field: FloatField, out_field: FloatField) -> None:
+ if not in_field[0, 0, 0] > 0:
+ self._copy_stencil(in_field, out_field)
+ self._copy_stencil(in_field, out_field)
+
+ def blocked_by_else(self, in_field: FloatField, out_field: FloatField) -> None:
+ self._copy_stencil(in_field, out_field)
+
+ if in_field[0, 0, 0] > 0:
+ self._copy_stencil(in_field, out_field)
+ else:
+ self._copy_stencil(out_field, in_field)
+
+ def blocked_by_other_nodes(
+ self, in_field: FloatField, out_field: FloatField
+ ) -> None:
+ if in_field[0, 0, 0] > 0:
+ in_field[:] = 42.0
+ self._copy_stencil(in_field, out_field)
+ self._copy_stencil(in_field, out_field)
+
+
+class TestStreeInlineOffgridConditionals:
+ @pytest.fixture(params=["orch:dace:cpu:IJK", "orch:dace:cpu:KJI"])
+ def factories(self, request: pytest.FixtureRequest) -> Factories:
+ domain = (3, 3, 4)
+ return get_factories_single_tile(
+ domain[0], domain[1], domain[2], 0, backend=Backend(request.param)
+ )
+
+ def test_happy_case(self, factories: Factories) -> None:
+ stencil_factory, quantity_factory = factories
+
+ code = OrchestratedCode(stencil_factory)
+ in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
+ out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
+
+ code.happy_case(in_quantity, out_quantity)
+
+ precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
+
+ all_maps = [
+ (me, state)
+ for me, state in precompiled_sdfg.sdfg.all_nodes_recursive()
+ if isinstance(me, nodes.MapEntry)
+ ]
+ assert len(all_maps) == 1 # all merged and collapsed
+
+ def test_happy_case_2(self, factories: Factories) -> None:
+ stencil_factory, quantity_factory = factories
+
+ code = OrchestratedCode(stencil_factory)
+ in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
+ out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
+
+ code.happy_case_2(in_quantity, out_quantity)
+
+ precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
+
+ all_maps = [
+ (me, state)
+ for me, state in precompiled_sdfg.sdfg.all_nodes_recursive()
+ if isinstance(me, nodes.MapEntry)
+ ]
+ assert len(all_maps) == 1 # all merged and collapsed
+
+ def test_blocked_by_else(self, factories: Factories) -> None:
+ stencil_factory, quantity_factory = factories
+
+ code = OrchestratedCode(stencil_factory)
+ in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
+ out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
+
+ code.blocked_by_else(in_quantity, out_quantity)
+
+ precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
+
+ all_maps = [
+ (me, state)
+ for me, state in precompiled_sdfg.sdfg.all_nodes_recursive()
+ if isinstance(me, nodes.MapEntry)
+ ]
+ assert len(all_maps) == 3 # 3 * IJK/KJI
+
+ def test_blocked_by_other_nodes(self, factories: Factories) -> None:
+ stencil_factory, quantity_factory = factories
+
+ code = OrchestratedCode(stencil_factory)
+ in_quantity = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
+ out_quantity = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
+
+ code.blocked_by_other_nodes(in_quantity, out_quantity)
+
+ precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
+
+ all_maps = [
+ (me, state)
+ for me, state in precompiled_sdfg.sdfg.all_nodes_recursive()
+ if isinstance(me, nodes.MapEntry)
+ ]
+
+ # ⚠️ Dev note:
+ # This should be just `assert len(all_maps) == 2`, but currently, the K-loops
+ # can't merge because the K-iterators are different. To be fixed (and simplified
+ # here) with a subsequent commit.
+ assert len(all_maps) == 3
diff --git a/tests/dsl/dace/stree/optimizations/test_pipeline.py b/tests/dsl/dace/stree/optimizations/test_pipeline.py
index 677790bc..545f5ffa 100644
--- a/tests/dsl/dace/stree/optimizations/test_pipeline.py
+++ b/tests/dsl/dace/stree/optimizations/test_pipeline.py
@@ -1,10 +1,9 @@
-from ndsl import StencilFactory, orchestrate
+from ndsl import OptimizationConfig, StencilFactory, orchestrate
from ndsl.boilerplate import get_factories_single_tile_orchestrated
from ndsl.config import Backend
from ndsl.constants import I_DIM, J_DIM, K_DIM
from ndsl.dsl.gt4py import PARALLEL, computation, interval
from ndsl.dsl.typing import FloatField
-from tests.dsl.dace.stree import StreeOptimization
def double_map(in_field: FloatField, out_field: FloatField):
@@ -17,7 +16,16 @@ def double_map(in_field: FloatField, out_field: FloatField):
class TriviallyMergeableCode:
def __init__(self, stencil_factory: StencilFactory):
- orchestrate(obj=self, config=stencil_factory.config.dace_config)
+ config = OptimizationConfig(
+ stree=OptimizationConfig.Tree(
+ enabled=True, merger=OptimizationConfig.Tree.Merger(enabled=True)
+ )
+ )
+ orchestrate(
+ obj=self,
+ config=stencil_factory.config.dace_config,
+ optimization_config=config,
+ )
self.stencil = stencil_factory.from_dims_halo(
func=double_map,
compute_dims=[I_DIM, J_DIM, K_DIM],
@@ -27,7 +35,7 @@ def __call__(self, in_field: FloatField, out_field: FloatField):
self.stencil(in_field, out_field)
-def test_stree_roundtrip_no_opt():
+def test_stree_roundtrip():
domain = (3, 3, 4)
stencil_factory, quantity_factory = get_factories_single_tile_orchestrated(
domain[0], domain[1], domain[2], 0, backend=Backend.cpu()
@@ -37,7 +45,6 @@ def test_stree_roundtrip_no_opt():
in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
- with StreeOptimization():
- code(in_qty, out_qty)
+ code(in_qty, out_qty)
assert (out_qty.field[:] == 4).all()
diff --git a/tests/dsl/dace/stree/optimizations/test_remove_loops.py b/tests/dsl/dace/stree/optimizations/test_remove_loops.py
new file mode 100644
index 00000000..afb98023
--- /dev/null
+++ b/tests/dsl/dace/stree/optimizations/test_remove_loops.py
@@ -0,0 +1,269 @@
+import pytest
+from dace import nodes
+from dace.sdfg.state import LoopRegion
+
+from ndsl import OptimizationConfig, StencilFactory, orchestrate
+from ndsl.boilerplate import get_factories_single_tile
+from ndsl.config import Backend, BackendLoopOrder
+from ndsl.constants import I_DIM, J_DIM, K_DIM, Float
+from ndsl.dsl.gt4py import FORWARD, computation, interval
+from ndsl.dsl.typing import FloatField, FloatFieldIJ
+from ndsl.stencils import copy
+from tests.dsl.dace.stree import StreePipeline, get_SDFG_and_purge
+from tests.dsl.dace.stree.optimizations import Factories
+
+
+def stencil_simple_2D_write(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> None:
+ with computation(FORWARD), interval(0, 1):
+ out_fieldIJ = in_field
+
+
+def stencil_multiple_2D_write(
+ in_field: FloatField, out_fieldIJ: FloatFieldIJ, out_fieldIJ_2: FloatFieldIJ
+) -> None:
+ with computation(FORWARD), interval(0, 1):
+ out_fieldIJ = in_field
+ out_fieldIJ_2 = in_field + 1.0
+
+
+def stencil_2D_write_at_K(in_field: FloatField, out_fieldIJ: FloatFieldIJ) -> None:
+ with computation(FORWARD), interval(-1, None):
+ out_fieldIJ = in_field
+
+
+def stencil_forward_at_K(in_field: FloatField, out_field: FloatField) -> None:
+ with computation(FORWARD), interval(...):
+ out_field = in_field
+
+
+class OrchestratedCode:
+ def __init__(self, stencil_factory: StencilFactory) -> None:
+ config = OptimizationConfig(
+ stree=OptimizationConfig.Tree(
+ enabled=True,
+ inline_K_loops_size_one=True,
+ merger=OptimizationConfig.Tree.Merger(enabled=True),
+ )
+ )
+ methods_to_orchestrate = [
+ "write_at_0",
+ "write_at_top",
+ "do_not_inline",
+ "combined_stencils",
+ "multiple_statements",
+ ]
+ for method in methods_to_orchestrate:
+ orchestrate(
+ obj=self,
+ config=stencil_factory.config.dace_config,
+ method_to_orchestrate=method,
+ optimization_config=config,
+ )
+
+ self.stencil_simple_2D_write = stencil_factory.from_dims_halo(
+ func=stencil_simple_2D_write,
+ compute_dims=[I_DIM, J_DIM, K_DIM],
+ )
+ self.stencil_2D_write_at_K = stencil_factory.from_dims_halo(
+ func=stencil_2D_write_at_K,
+ compute_dims=[I_DIM, J_DIM, K_DIM],
+ )
+ self.stencil_do_not_inline = stencil_factory.from_dims_halo(
+ func=stencil_forward_at_K,
+ compute_dims=[I_DIM, J_DIM, K_DIM],
+ )
+ self.stencil_copy = stencil_factory.from_dims_halo(
+ func=copy,
+ compute_dims=[I_DIM, J_DIM, K_DIM],
+ )
+ self.stencil_multiple_2D_write = stencil_factory.from_dims_halo(
+ func=stencil_multiple_2D_write,
+ compute_dims=[I_DIM, J_DIM, K_DIM],
+ )
+
+ def write_at_0(
+ self,
+ in_field: FloatField,
+ out_field: FloatFieldIJ,
+ ) -> None:
+ self.stencil_simple_2D_write(in_field, out_field)
+
+ def write_at_top(
+ self,
+ in_field: FloatField,
+ out_field: FloatFieldIJ,
+ ) -> None:
+ self.stencil_2D_write_at_K(in_field, out_field)
+
+ def do_not_inline(
+ self,
+ in_field: FloatField,
+ out_field: FloatField,
+ ) -> None:
+ self.stencil_do_not_inline(in_field, out_field)
+
+ def combined_stencils(
+ self, field: FloatField, field2: FloatField, fieldIJ: FloatFieldIJ
+ ) -> None:
+ self.stencil_copy(field, field2)
+ self.stencil_simple_2D_write(field2, fieldIJ)
+
+ def multiple_statements(
+ self, in_field: FloatField, out_field: FloatFieldIJ, out_field2: FloatFieldIJ
+ ) -> None:
+ self.stencil_copy(in_field, in_field)
+ self.stencil_multiple_2D_write(in_field, out_field, out_field2)
+
+
+class TestStree2DWriteInline:
+ @pytest.fixture(params=["orch:dace:cpu:IJK", "orch:dace:cpu:KJI"])
+ def factories(self, request: pytest.FixtureRequest) -> Factories:
+
+ domain = (3, 3, 4)
+ return get_factories_single_tile(
+ domain[0], domain[1], domain[2], 0, backend=Backend(request.param)
+ )
+
+ def test_common_2D_write(self, factories: Factories) -> None:
+ stencil_factory, quantity_factory = factories
+ code = OrchestratedCode(stencil_factory)
+
+ in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
+ out_qty = quantity_factory.zeros([I_DIM, J_DIM], "")
+ in_qty.field[:, :, 0] = Float(32.0)
+
+ with StreePipeline():
+ code.write_at_0(in_qty, out_qty)
+
+ precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
+ all_maps = [
+ (me, state)
+ for me, state in precompiled_sdfg.sdfg.all_nodes_recursive()
+ if isinstance(me, nodes.MapEntry)
+ ]
+ all_loop_region = [
+ (me, state)
+ for me, state in precompiled_sdfg.sdfg.all_nodes_recursive()
+ if isinstance(me, LoopRegion)
+ ]
+
+ assert len(all_maps) == 1 # IJ/JI collapsed
+ assert len(all_loop_region) == 0
+ assert (out_qty.field[:] == Float(32.0)).all()
+
+ def test_2D_write_K_top(self, factories: Factories) -> None:
+ stencil_factory, quantity_factory = factories
+ code = OrchestratedCode(stencil_factory)
+
+ in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
+ out_qty = quantity_factory.zeros([I_DIM, J_DIM], "")
+ in_qty.field[:, :, -1] = Float(32.0)
+
+ with StreePipeline():
+ code.write_at_top(in_qty, out_qty)
+
+ precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
+ all_maps = [
+ (me, state)
+ for me, state in precompiled_sdfg.sdfg.all_nodes_recursive()
+ if isinstance(me, nodes.MapEntry)
+ ]
+ all_loop_region = [
+ (me, state)
+ for me, state in precompiled_sdfg.sdfg.all_nodes_recursive()
+ if isinstance(me, LoopRegion)
+ ]
+
+ assert len(all_maps) == 1 # IJ/JI collapsed
+ assert len(all_loop_region) == 0
+ assert (out_qty.field[:] == Float(32.0)).all()
+
+ def test_do_not_inline(self, factories: Factories) -> None:
+ stencil_factory, quantity_factory = factories
+ code = OrchestratedCode(stencil_factory)
+
+ in_qty = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
+ out_qty = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
+
+ with StreePipeline():
+ code.do_not_inline(in_qty, out_qty)
+
+ precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
+ all_maps = [
+ (me, state)
+ for me, state in precompiled_sdfg.sdfg.all_nodes_recursive()
+ if isinstance(me, nodes.MapEntry)
+ ]
+ all_loop_region = [
+ (me, state)
+ for me, state in precompiled_sdfg.sdfg.all_nodes_recursive()
+ if isinstance(me, LoopRegion)
+ ]
+
+ assert len(all_maps) == 1 # IJ/JI collapsed
+ assert len(all_loop_region) == 1
+ assert (out_qty.field[:] == Float(1)).all()
+
+    def test_combined_stencils(self, factories: Factories) -> None:
+        stencil_factory, quantity_factory = factories
+        code = OrchestratedCode(stencil_factory)
+
+        field = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
+        field_2 = quantity_factory.zeros([I_DIM, J_DIM, K_DIM], "")
+        field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "")
+
+        with StreePipeline():
+            code.combined_stencils(field, field_2, field_IJ)
+
+        precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
+        all_maps = [
+            (me, state)
+            for me, state in precompiled_sdfg.sdfg.all_nodes_recursive()
+            if isinstance(me, nodes.MapEntry)
+        ]
+        all_loop_region = [
+            (me, state)
+            for me, state in precompiled_sdfg.sdfg.all_nodes_recursive()
+            if isinstance(me, LoopRegion)
+        ]
+
+        # Parenthesize the expected count: `x == 2 if c else 2` parses as
+        # `(x == 2) if c else 2`, so the KJI case used to assert a bare `2`.
+        is_ijk = stencil_factory.backend.loop_order == BackendLoopOrder.IJK
+        expected_maps = 2 if is_ijk else 2  # IJK: IJ + K / KJI: KJI + JI
+        assert len(all_maps) == expected_maps
+        assert len(all_loop_region) == 0
+        assert (field_IJ.field[:] == Float(1)).all()
+
+    def test_multiple_statements(self, factories: Factories) -> None:
+        stencil_factory, quantity_factory = factories
+        code = OrchestratedCode(stencil_factory)
+
+        field = quantity_factory.ones([I_DIM, J_DIM, K_DIM], "")
+        field_IJ = quantity_factory.zeros([I_DIM, J_DIM], "")
+        field_IJ_2 = quantity_factory.zeros([I_DIM, J_DIM], "")
+
+        field.field[:, :, 0] = Float(42.0)
+        with StreePipeline():
+            code.multiple_statements(field, field_IJ, field_IJ_2)
+
+        precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
+        all_maps = [
+            (me, state)
+            for me, state in precompiled_sdfg.sdfg.all_nodes_recursive()
+            if isinstance(me, nodes.MapEntry)
+        ]
+        all_loop_region = [
+            (me, state)
+            for me, state in precompiled_sdfg.sdfg.all_nodes_recursive()
+            if isinstance(me, LoopRegion)
+        ]
+
+        # Parenthesize the expected count: `x == 2 if c else 2` parses as
+        # `(x == 2) if c else 2`, so the KJI case used to assert a bare `2`.
+        is_ijk = stencil_factory.backend.loop_order == BackendLoopOrder.IJK
+        expected_maps = 2 if is_ijk else 2  # IJK: IJ + K / KJI: KJI + JI
+        assert len(all_maps) == expected_maps
+        assert len(all_loop_region) == 0
+        assert (field_IJ.field[:] == Float(42.0)).all()
+        assert (field_IJ_2.field[:] == Float(43.0)).all()
diff --git a/tests/dsl/dace/stree/optimizations/test_transient_refine.py b/tests/dsl/dace/stree/optimizations/test_transient_refine.py
index 9795957a..d3c7604f 100644
--- a/tests/dsl/dace/stree/optimizations/test_transient_refine.py
+++ b/tests/dsl/dace/stree/optimizations/test_transient_refine.py
@@ -1,10 +1,17 @@
-from ndsl import NDSLRuntime, Quantity, QuantityFactory, StencilFactory, orchestrate
+from ndsl import (
+ NDSLRuntime,
+ OptimizationConfig,
+ Quantity,
+ QuantityFactory,
+ StencilFactory,
+ orchestrate,
+)
from ndsl.boilerplate import get_factories_single_tile_orchestrated
from ndsl.config import Backend
from ndsl.constants import I_DIM, J_DIM, K_DIM
from ndsl.dsl.gt4py import IJK, PARALLEL, Field, J, K, computation, interval
from ndsl.dsl.typing import Float, FloatField
-from tests.dsl.dace.stree import StreeOptimization, get_SDFG_and_purge
+from tests.dsl.dace.stree import get_SDFG_and_purge
DATADIM_SIZE = 8
@@ -39,7 +46,13 @@ class TransientRefineableCode(NDSLRuntime):
def __init__(
self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory
) -> None:
- super().__init__(stencil_factory)
+ config = OptimizationConfig(
+ stree=OptimizationConfig.Tree(
+ enabled=True,
+ merger=OptimizationConfig.Tree.Merger(enabled=True),
+ )
+ )
+ super().__init__(stencil_factory, optimization_config=config)
orchestratable_methods = [
"refine_to_scalar",
"refine_to_K_buffer",
@@ -51,6 +64,7 @@ def __init__(
obj=self,
config=stencil_factory.config.dace_config,
method_to_orchestrate=method,
+ optimization_config=config,
)
self.stencil = stencil_factory.from_dims_halo(
func=stencil,
@@ -105,40 +119,39 @@ def test_stree_roundtrip_transient_is_refined() -> None:
code = TransientRefineableCode(stencil_factory, quantity_factory)
- with StreeOptimization():
- # Refine to scalar
- code.refine_to_scalar(in_qty, out_qty)
- precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
- for array in precompiled_sdfg.sdfg.arrays.values():
- if array.transient:
- assert array.shape == (1, 1, 1)
-
- # Refine cartesian axis to buffers
- # IJ merges - K is a buffer
- code.refine_to_K_buffer(in_qty, out_qty)
- precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
- for array in precompiled_sdfg.sdfg.arrays.values():
- if array.transient:
- assert array.shape == (
- 1,
- 1,
- domain[2] + 1, # Quantity are domain size + 1
- )
-
- # I merges - JK buffer
- code.refine_to_JK_buffer(in_qty, out_qty)
- precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
- for array in precompiled_sdfg.sdfg.arrays.values():
- if array.transient:
- assert array.shape == (
- 1,
- domain[1] + 1, # Quantity are domain size + 1
- domain[2] + 1,
- )
-
- # Refine to remaining data dimensions
- code.do_not_refine_datadims(in_qty_ddim, out_qty_ddim)
- precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
- for array in precompiled_sdfg.sdfg.arrays.values():
- if array.transient:
- assert array.shape == (1, 1, 1, DATADIM_SIZE) or len(array.shape) == 1
+ # Refine to scalar
+ code.refine_to_scalar(in_qty, out_qty)
+ precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
+ for array in precompiled_sdfg.sdfg.arrays.values():
+ if array.transient:
+ assert array.shape == (1, 1, 1)
+
+ # Refine cartesian axis to buffers
+ # IJ merges - K is a buffer
+ code.refine_to_K_buffer(in_qty, out_qty)
+ precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
+ for array in precompiled_sdfg.sdfg.arrays.values():
+ if array.transient:
+ assert array.shape == (
+ 1,
+ 1,
+ domain[2] + 1, # Quantity are domain size + 1
+ )
+
+ # I merges - JK buffer
+ code.refine_to_JK_buffer(in_qty, out_qty)
+ precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
+ for array in precompiled_sdfg.sdfg.arrays.values():
+ if array.transient:
+ assert array.shape == (
+ 1,
+ domain[1] + 1, # Quantity are domain size + 1
+ domain[2] + 1,
+ )
+
+ # Refine to remaining data dimensions
+ code.do_not_refine_datadims(in_qty_ddim, out_qty_ddim)
+ precompiled_sdfg = get_SDFG_and_purge(stencil_factory)
+ for array in precompiled_sdfg.sdfg.arrays.values():
+ if array.transient:
+ assert array.shape == (1, 1, 1, DATADIM_SIZE) or len(array.shape) == 1
diff --git a/tests/dsl/dace/stree/sdfg_stree_tools.py b/tests/dsl/dace/stree/sdfg_stree_tools.py
index 6c664205..aeb149a5 100644
--- a/tests/dsl/dace/stree/sdfg_stree_tools.py
+++ b/tests/dsl/dace/stree/sdfg_stree_tools.py
@@ -1,6 +1,7 @@
from types import TracebackType
import dace
+from dace.sdfg.analysis.schedule_tree import treenodes as tn
import ndsl.dsl.dace.orchestration as orch
from ndsl import StencilFactory
@@ -20,9 +21,13 @@ def get_SDFG_and_purge(stencil_factory: StencilFactory) -> dace.CompiledSDFG:
return sdfg
-class StreeOptimization:
+class StreePipeline:
+ def __init__(self, *, passes: list[tn.ScheduleNodeVisitor] | None = None) -> None:
+ self.passes = passes
+
def __enter__(self) -> None:
- orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = True
+ self.original_passes = orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES
+ orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES = self.passes
def __exit__(
self,
@@ -30,4 +35,4 @@ def __exit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
- orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION = False
+ orch._INTERNAL__SCHEDULE_TREE_OPTIMIZATION_PASSES = self.original_passes
diff --git a/tests/dsl/orchestration/test_boundaries_k.py b/tests/dsl/orchestration/test_boundaries_k.py
new file mode 100644
index 00000000..80ea4a84
--- /dev/null
+++ b/tests/dsl/orchestration/test_boundaries_k.py
@@ -0,0 +1,196 @@
+import numpy as np
+import pytest
+
+from ndsl import Backend, NDSLRuntime, StencilFactory, orchestrate
+from ndsl.boilerplate import get_factories_single_tile
+from ndsl.constants import I_DIM, J_DIM, K_DIM, K_INTERFACE_DIM
+from ndsl.dsl.gt4py import BACKWARD, FORWARD, computation, interval
+from ndsl.dsl.typing import FloatField
+from tests.dsl.dace.stree.optimizations import Factories
+
+
+def accumulate_down(in_field: FloatField, out_field: FloatField) -> None:  # type: ignore
+    # out_field[k] = sum of in_field over K levels k .. top
+    with computation(BACKWARD):
+        # handle top layer separately: it only copies in_field
+        with interval(-1, None):
+            out_field = in_field
+        # accumulate "downwards": add in_field to the level above (K offset +1)
+        with interval(0, -1):
+            out_field = out_field[0, 0, 1] + in_field
+
+
+def accumulate_down_from_interface_field(interface_field: FloatField, out_field: FloatField) -> None:  # type: ignore
+    # out_field[k] = sum of interface_field over K levels k .. top + 1
+    with computation(BACKWARD):
+        # handle top layer separately: it also reads the level above (K offset +1)
+        with interval(-1, None):
+            out_field = interface_field + interface_field[0, 0, 1]
+        # accumulate "downwards": add interface_field to the level above
+        with interval(0, -1):
+            out_field = out_field[0, 0, 1] + interface_field
+
+
+def accumulate_on_interface(interface_field: FloatField, out_field: FloatField) -> None:  # type: ignore
+    # Neither interval below covers the last K level, so out_field is not written there.
+    with computation(BACKWARD):
+        # seed one level below the last one, reading the last level at K offset +1
+        with interval(-2, -1):
+            out_field = interface_field + interface_field[0, 0, 1]
+        # accumulate "downwards" over the remaining lower levels
+        with interval(0, -2):
+            out_field = out_field[0, 0, 1] + interface_field
+
+
+def accumulate_up(in_field: FloatField, out_field: FloatField) -> None:  # type: ignore
+    # out_field[k] = sum of in_field over K levels 0 .. k
+    with computation(FORWARD):
+        # handle bottom layer separately: it only copies in_field
+        with interval(0, 1):
+            out_field = in_field
+        # accumulate "upwards": add in_field to the level below (K offset -1)
+        with interval(1, None):
+            out_field = out_field[0, 0, -1] + in_field
+
+
+def accumulate_up_interface(in_field: FloatField, interface_field: FloatField) -> None:  # type: ignore
+    # interface_field[0] = in_field[0]; above that, each level adds in_field[k - 1]
+    with computation(FORWARD):
+        # handle bottom layer separately: it only copies in_field
+        with interval(0, 1):
+            interface_field = in_field
+        # accumulate "upwards": both operands are read one level below (K offset -1)
+        with interval(1, None):
+            interface_field = interface_field[0, 0, -1] + in_field[0, 0, -1]
+
+
+class OrchestratedCode(NDSLRuntime):
+    """Runtime exposing each accumulation stencil as an orchestrated method."""
+
+    def __init__(self, stencil_factory: StencilFactory) -> None:
+        super().__init__(stencil_factory)
+        methods_to_orchestrate = [
+            "accumulate_down",
+            "accumulate_down_from_interface_field",
+            "accumulate_on_interface",
+            "accumulate_up",
+            "accumulate_up_interface",
+        ]
+        for method in methods_to_orchestrate:
+            orchestrate(
+                obj=self,
+                method_to_orchestrate=method,
+                config=stencil_factory.config.dace_config,
+            )
+        # One stencil per method; only the K compute dimension differs between them.
+        self._accumulate_down = stencil_factory.from_dims_halo(
+            func=accumulate_down, compute_dims=(I_DIM, J_DIM, K_DIM)
+        )
+
+        self._accumulate_down_from_interface_field = stencil_factory.from_dims_halo(
+            func=accumulate_down_from_interface_field,
+            compute_dims=(I_DIM, J_DIM, K_DIM),
+        )
+
+        self._accumulate_on_interface = stencil_factory.from_dims_halo(
+            func=accumulate_on_interface, compute_dims=(I_DIM, J_DIM, K_INTERFACE_DIM)
+        )
+
+        self._accumulate_up = stencil_factory.from_dims_halo(
+            func=accumulate_up, compute_dims=(I_DIM, J_DIM, K_DIM)
+        )
+
+        self._accumulate_up_interface = stencil_factory.from_dims_halo(
+            func=accumulate_up_interface, compute_dims=(I_DIM, J_DIM, K_INTERFACE_DIM)
+        )
+
+    def accumulate_down(self, in_field: FloatField, out_field: FloatField) -> None:  # type: ignore
+        self._accumulate_down(in_field, out_field)
+
+    def accumulate_down_from_interface_field(self, interface_field: FloatField, out_field: FloatField) -> None:  # type: ignore
+        self._accumulate_down_from_interface_field(interface_field, out_field)
+
+    def accumulate_on_interface(self, interface_field: FloatField, out_field: FloatField) -> None:  # type: ignore
+        self._accumulate_on_interface(interface_field, out_field)
+
+    def accumulate_up(self, in_field: FloatField, out_field: FloatField) -> None:  # type: ignore
+        self._accumulate_up(in_field, out_field)
+
+    def accumulate_up_interface(self, in_field: FloatField, interface_field: FloatField) -> None:  # type: ignore
+        self._accumulate_up_interface(in_field, interface_field)
+
+
+class TestBoundariesK:
+    """K-boundary accumulation stencils, run on each parametrized backend."""
+
+    @pytest.fixture(
+        params=[
+            "orch:dace:cpu:IJK",
+            "orch:dace:cpu:KJI",
+            "st:dace:cpu:IJK",
+            "st:dace:cpu:KJI",
+        ]
+    )
+    def factories(self, request: pytest.FixtureRequest) -> Factories:
+        """Build single-tile factories (3x4x5 domain, no halo) per backend."""
+        nx, ny, nz = 3, 4, 5
+        return get_factories_single_tile(
+            nx=nx, ny=ny, nz=nz, nhalo=0, backend=Backend(request.param)
+        )
+
+    def test_accumulate_down(self, factories: Factories) -> None:
+        """Accumulating ones downwards gives 5..1 along K."""
+        stencil_factory, quantity_factory = factories
+        runtime = OrchestratedCode(stencil_factory)
+        dims = (I_DIM, J_DIM, K_DIM)
+        source = quantity_factory.ones(dims, units="")
+        target = quantity_factory.zeros(dims, units="")
+        expected = [5, 4, 3, 2, 1]
+
+        runtime.accumulate_down(source, target)
+        assert np.array_equal(target.field[0, 0, :], expected)
+
+    def test_accumulate_interface_field(self, factories: Factories) -> None:
+        """Accumulating interface ones downwards gives 6..2 along K."""
+        stencil_factory, quantity_factory = factories
+        runtime = OrchestratedCode(stencil_factory)
+        source = quantity_factory.ones((I_DIM, J_DIM, K_INTERFACE_DIM), units="")
+        target = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), units="")
+        expected = [6, 5, 4, 3, 2]
+
+        runtime.accumulate_down_from_interface_field(source, target)
+        assert np.array_equal(target.field[0, 0, :], expected)
+
+    def test_accumulate_interface_domain(self, factories: Factories) -> None:
+        """Accumulating on the interface domain also gives 6..2 along K."""
+        stencil_factory, quantity_factory = factories
+        runtime = OrchestratedCode(stencil_factory)
+        source = quantity_factory.ones((I_DIM, J_DIM, K_INTERFACE_DIM), units="")
+        target = quantity_factory.zeros((I_DIM, J_DIM, K_DIM), units="")
+        expected = [6, 5, 4, 3, 2]
+
+        runtime.accumulate_on_interface(source, target)
+        assert np.array_equal(target.field[0, 0, :], expected)
+
+    def test_accumulate_up(self, factories: Factories) -> None:
+        """Accumulating ones upwards gives 1..5 along K."""
+        stencil_factory, quantity_factory = factories
+        runtime = OrchestratedCode(stencil_factory)
+        dims = (I_DIM, J_DIM, K_DIM)
+        source = quantity_factory.ones(dims, units="")
+        target = quantity_factory.zeros(dims, units="")
+        expected = [1, 2, 3, 4, 5]
+
+        runtime.accumulate_up(source, target)
+        assert np.array_equal(target.field[0, 0, :], expected)
+
+    def test_accumulate_up_interface(self, factories: Factories) -> None:
+        """Accumulating ones upwards into an interface field gives 1..6."""
+        stencil_factory, quantity_factory = factories
+        runtime = OrchestratedCode(stencil_factory)
+        source = quantity_factory.ones((I_DIM, J_DIM, K_DIM), units="")
+        target = quantity_factory.zeros((I_DIM, J_DIM, K_INTERFACE_DIM), units="")
+        expected = [1, 2, 3, 4, 5, 6]
+
+        runtime.accumulate_up_interface(source, target)
+        assert np.array_equal(target.field[0, 0, :], expected)
diff --git a/tests/test_ndsl_runtime.py b/tests/test_ndsl_runtime.py
index 67e4f226..83694274 100644
--- a/tests/test_ndsl_runtime.py
+++ b/tests/test_ndsl_runtime.py
@@ -2,20 +2,19 @@
import pytest
-from ndsl import NDSLRuntime, QuantityFactory, StencilFactory
+from ndsl import (
+ NDSLRuntime,
+ OptimizationConfig,
+ QuantityFactory,
+ StencilFactory,
+ stencils,
+)
from ndsl.boilerplate import (
get_factories_single_tile,
get_factories_single_tile_orchestrated,
)
from ndsl.config import Backend
from ndsl.constants import I_DIM, J_DIM, K_DIM
-from ndsl.dsl.gt4py import PARALLEL, computation, interval
-from ndsl.dsl.typing import FloatField
-
-
-def the_copy_stencil(from_: FloatField, to: FloatField) -> None:
- with computation(PARALLEL), interval(...):
- to = from_
class Code(NDSLRuntime):
@@ -24,7 +23,7 @@ def __init__(
) -> None:
super().__init__(stencil_factory)
self.copy = stencil_factory.from_dims_halo(
- the_copy_stencil, compute_dims=[I_DIM, J_DIM, K_DIM]
+ stencils.copy, compute_dims=[I_DIM, J_DIM, K_DIM]
)
self.local = self.make_local(quantity_factory, [I_DIM, J_DIM, K_DIM])
@@ -105,3 +104,36 @@ def test_runtime_fail_when_not_super_init() -> None:
RuntimeError, match="inherit from NDSLRuntime but didn't call super()"
):
bad_code = BadCode_NoSuperInit()
+
+
+def test_runtime_with_performance_config() -> None:
+    """Copy a field through a runtime given a default OptimizationConfig."""
+
+    class CustomPerformanceConfig(NDSLRuntime):
+        def __init__(
+            self,
+            stencil_factory: StencilFactory,
+            optimization_config: OptimizationConfig,
+        ) -> None:
+            super().__init__(stencil_factory, optimization_config)
+            self.copy = stencil_factory.from_dims_halo(
+                stencils.copy, compute_dims=[I_DIM, J_DIM, K_DIM]
+            )
+
+        def __call__(self, src, dst) -> None:  # type: ignore[no-untyped-def]
+            self.copy(src, dst)
+
+    factories = get_factories_single_tile_orchestrated(
+        nx=5, ny=5, nz=3, nhalo=0, backend=Backend.cpu()
+    )
+    stencil_factory, quantity_factory = factories
+
+    # runtime under test, using an all-defaults optimization config
+    runtime = CustomPerformanceConfig(stencil_factory, OptimizationConfig())
+
+    # one field of ones to copy from, one field of zeros to copy into
+    source = quantity_factory.ones(dims=[I_DIM, J_DIM, K_DIM], units="n/a")
+    target = quantity_factory.zeros(dims=[I_DIM, J_DIM, K_DIM], units="n/a")
+
+    runtime(source, target)
+    assert (source.field[:] == target.field[:]).all()