diff --git a/VTUNE.md b/VTUNE.md
index d77424c..d53be05 100644
--- a/VTUNE.md
+++ b/VTUNE.md
@@ -1,6 +1,9 @@
# VTune GPU Profiling
-Hardware-counter profiling for Triton kernels on Intel XPU using Intel VTune Profiler.
+Hardware-counter profiling on Intel XPU using Intel VTune Profiler. Two paths:
+**Triton/PyTorch** kernels (`gpu-offload` around a generated Python runner) and
+**SYCL/CUTLASS** `.cpp` kernels (`gpu-hotspots` characterization on the compiled
+binary). The path is selected by `--dsl`; see [SYCL kernels](#sycl-kernels) below.
---
@@ -125,6 +128,67 @@ xe-forge -i kernel.py -s spec.yaml --vtune --engine claude --workspace ./workspa
---
+## SYCL kernels
+
+SYCL/CUTLASS kernels are compiled `.cpp` binaries, so they are profiled
+differently from Triton: there is no Python runner. `SyclProfiler`
+(`core/sycl_profiler.py`) compiles the kernel via `SyclExecutor`, generates the
+same deterministic file-IO inputs the benchmark uses, and runs the binary
+directly under VTune `gpu-hotspots` in **characterization** mode — which exposes
+richer Intel Xe metrics than `gpu-offload`.
+
+```bash
+# Profile a SYCL kernel (point --vtune-bin at a 2026.x build if needed)
+xe-forge-skill profile examples/sycl/gemm.cpp \
+ --spec examples/sycl/gemm.yaml --dsl sycl --variant bench-xpu \
+ --iters 200 --vtune-bin /data/swtools/intel/vtune/2026.0/bin64/vtune
+```
+
+Under the hood:
+
+```bash
+vtune -collect gpu-hotspots \
+ -knob gpu-profiling-mode=characterization \
+ -knob characterization-mode=overview \
+ -result-dir <result_dir> \
+ -- <kernel_binary> --m=M --n=N --k=K --input_dir=<input_dir> --output_dir=<output_dir> \
+ --iterations=200 --verify=0
+```
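The invocation above can be assembled programmatically. The following is a minimal sketch, not the actual `SyclProfiler` code: the helper name `build_vtune_cmd` and every path in the usage line are hypothetical, and only the VTune options shown in the command above are used.

```python
import shlex


def build_vtune_cmd(vtune_bin, result_dir, kernel_bin, m, n, k,
                    input_dir, output_dir, iterations=200):
    """Return the argv list for a gpu-hotspots characterization run.

    Hypothetical helper: it mirrors the documented command line only.
    """
    return [
        vtune_bin,
        "-collect", "gpu-hotspots",
        "-knob", "gpu-profiling-mode=characterization",
        "-knob", "characterization-mode=overview",
        "-result-dir", result_dir,
        "--",  # everything after this is the profiled program and its args
        kernel_bin,
        f"--m={m}", f"--n={n}", f"--k={k}",
        f"--input_dir={input_dir}", f"--output_dir={output_dir}",
        f"--iterations={iterations}", "--verify=0",
    ]


# Illustrative paths only; pass the list to subprocess.run() to execute it.
cmd = build_vtune_cmd("vtune", "/tmp/vtune_result", "./kernel",
                      1024, 1024, 1024, "/tmp/in", "/tmp/out")
print(shlex.join(cmd))
```

Building an argv list rather than a shell string avoids quoting problems when directories contain spaces.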
+
+### Metrics collected (SYCL)
+
+| Metric | Meaning |
+|--------|---------|
+| XVE Active / Stalled / Idle | Xe Vector Engine execution / stall / idle time |
+| Peak XVE Threads Occupancy | Thread occupancy (with Work-Size / SLM / Barrier sub-limiters) |
+| XMX (DPAS) Active | Fraction of time the matrix engine is busy — the key GEMM-efficiency signal |
+| GPU L3 Miss Ratio | L3 cache miss ratio |
+| GPU Memory Bandwidth Read/Write | GB/s to/from GPU memory |
+
+### Metric → CUTLASS knob (SYCL)
+
+| Condition | Diagnosis | Action | KB |
+|-----------|-----------|--------|----|
+| XVE Stalled > Active | Memory-bound mainloop | ↑ PipelineStages; 2D-block/VNNI copy atoms; ↓ TileK | `sycl_vtune.yaml` |
+| Peak occupancy < 50% | Grid too small / register pressure | Smaller TileShape (256→128); check 256-GRF | `sycl_vtune.yaml` |
+| XVE Idle > 30% | Work-distribution / tail | TileShape vs M/N; stream-K / persistent scheduler | `sycl_vtune.yaml` |
+| XMX active < 20% | Matrix engine underutilized | Larger N-per-subgroup; SubgroupLayout vs DPAS atom | `sycl_vtune.yaml` |
+| L3 miss > 50% | Cache thrashing | Reduce tiles; improve K-blocking/reuse | `sycl_vtune.yaml` |
+| Mem BW ≈ peak, low TFLOPS | Bandwidth-bound | Accept, or change algorithm | `sycl_vtune.yaml` |
+
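As a rough illustration of how the table above could be applied, the sketch below encodes its first five rows as threshold checks. The function name, the dictionary keys, and the assumption that every metric is already a percentage are hypothetical and are not taken from `sycl_vtune.yaml`. The bandwidth row is omitted because the table gives it no numeric threshold.

```python
def diagnose(metrics):
    """Map VTune metrics (all assumed to be percentages) to table diagnoses."""
    findings = []
    if metrics["xve_stalled"] > metrics["xve_active"]:
        findings.append("Memory-bound mainloop")
    if metrics["peak_occupancy"] < 50.0:
        findings.append("Grid too small / register pressure")
    if metrics["xve_idle"] > 30.0:
        findings.append("Work-distribution / tail")
    if metrics["xmx_active"] < 20.0:
        findings.append("Matrix engine underutilized")
    if metrics["l3_miss"] > 50.0:
        findings.append("Cache thrashing")
    return findings


# Made-up sample values, chosen only to trigger two of the rules.
sample = {
    "xve_active": 30.0,
    "xve_stalled": 45.0,
    "xve_idle": 25.0,
    "peak_occupancy": 75.0,
    "xmx_active": 12.0,
    "l3_miss": 20.0,
}
print(diagnose(sample))
```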
+### SYCL-specific notes
+
+- **Self-checker false negative**: `vtune-self-checker.sh` may report GPU
+ profiling as unsupported (its bundled DPC++ app fails to launch), yet a real
+ AOT-compiled `bmg-g31` kernel profiles fine. Don't gate on the self-checker.
+- **`xe` kernel driver** (newer than `i915`) is supported by VTune 2026 for
+ `gpu-hotspots`; `perf_event_paranoid=0` helps.
+- **VTune version**: the config default `vtune_bin` may point at an older build;
+ pass `--vtune-bin /data/swtools/intel/vtune/2026.0/bin64/vtune` (or set
+ `VTUNE_BIN`) to use 2026.x.
+
+---
+
## Troubleshooting
**"VTune not found"** -- Ensure `vtune` is on `$PATH` after sourcing the oneAPI environment, or specify the path with `--vtune-bin`.
diff --git a/examples/sycl/README.md b/examples/sycl/README.md
new file mode 100644
index 0000000..7af5099
--- /dev/null
+++ b/examples/sycl/README.md
@@ -0,0 +1,80 @@
+# SYCL Claude Engine Example (Intel Xe / Battlemage)
+
+A worked, hardware-verified example of the **Claude Code engine for SYCL XPU
+kernels**: the agent rewrites a whole CUTLASS SYCL `.cpp` GEMM each trial, then
+compiles (`icpx`) + benchmarks it via `SyclExecutor`, with correctness checked
+against a **golden PyTorch reference** (`numpy.allclose` on the kernel's dumped
+`D2.bin`).
+
+| File | Role |
+|------|------|
+| `gemm.cpp` | Baseline kernel (t0). CUTLASS BMG GEMM `D = A·B0`, tile `256×256×32`. Honours the file-IO contract. |
+| `gemm_t1.cpp` | One optimization trial (t1). Same kernel, tile `128×128×32` — ~1.7× faster at 1024³ on an Arc Pro B70. |
+| `gemm.yaml` | KernelBench-style spec: GEMM dims `M=N=K=1024`, bf16, `rtol=atol=0.02`. |
+| `gemm_pytorch.py` | Golden PyTorch reference: `Model.forward(A, B0) -> A.float() @ B0.float()`. |
+
+## The file-IO contract
+
+Every SYCL kernel optimized by this engine is a standalone executable invoked as:
+
+```
+./kernel --m=<M> --n=<N> --k=<K> --input_dir=<dir> --output_dir=<dir> --iterations=<n> --verify=<0|1>
+```
+
+It reads `A.bin` `[M,K]` and `B0.bin` `[K,N]` (raw row-major, bf16 stored as
+int16 bits) from `--input_dir`, computes `D = A·B0`, writes `D2.bin` `[M,N]`
+(float32, row-major) to `--output_dir`, and prints a `… TFlop/s … ms` line.
+Full spec: [`knowledge_base/sycl/xpu/sycl_io_contract.yaml`](../../knowledge_base/sycl/xpu/sycl_io_contract.yaml).
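To make the byte layout concrete, here is a standard-library-only sketch of the Python side of the contract. It is an illustration under stated assumptions, not the `SyclExecutor` implementation: all function names are hypothetical, little-endian byte order is assumed, and the `D2.bin` written here is a stand-in for what a kernel would produce.

```python
import os
import struct
import tempfile


def f32_to_bf16_bits(x):
    """Round a finite float to bf16 (nearest-even) and return its 16-bit pattern."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias; NaN/inf are not handled
    return (bits >> 16) & 0xFFFF


def write_bf16_matrix(path, rows):
    """Write a row-major matrix as raw bf16 bit patterns (2 bytes per element)."""
    flat = [f32_to_bf16_bits(v) for row in rows for v in row]
    with open(path, "wb") as f:
        f.write(struct.pack(f"<{len(flat)}H", *flat))


def read_f32_matrix(path, m, n):
    """Read an [m, n] row-major float32 matrix, checking the file size first."""
    with open(path, "rb") as f:
        data = f.read()
    if len(data) != m * n * 4:
        raise ValueError(f"{path}: got {len(data)} bytes, want {m * n * 4}")
    return list(struct.unpack(f"<{m * n}f", data))


def allclose(actual, expected, rtol=0.02, atol=0.02):
    """Same elementwise criterion as numpy.allclose: |a - e| <= atol + rtol * |e|."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


workdir = tempfile.mkdtemp()
a_path = os.path.join(workdir, "A.bin")
d_path = os.path.join(workdir, "D2.bin")
write_bf16_matrix(a_path, [[1.0, 2.0], [3.0, 4.0]])
write_bf16_matrix(os.path.join(workdir, "B0.bin"), [[0.5, 0.0], [0.0, 0.5]])

# Stand-in for the kernel: write the known product A.B0 as float32.
expected = [0.5, 1.0, 1.5, 2.0]
with open(d_path, "wb") as f:
    f.write(struct.pack("<4f", *expected))

result = read_f32_matrix(d_path, 2, 2)
print(allclose(result, expected))
```

The size check in `read_f32_matrix` mirrors the short-read check the C++ `read_bin` helper performs on the kernel side.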
+
+## Environment (Intel XPU box)
+
+```bash
+export SYCL_TLA_DIR=/path/to/sycl-tla # CUTLASS SYCL checkout
+export AIBENCH_SYCL_TARGET=bmg-g31 # AOT target (Battlemage: B580/B570/B70)
+export MKL_INCLUDE=/path/to/oneapi/include
+export ONEAPI_DEVICE_SELECTOR="level_zero:gpu"
+export IGC_ExtraOCLOptions="-cl-intel-256-GRF-per-thread"
+export SYCL_PROGRAM_COMPILE_OPTIONS="-ze-opt-large-register-file -gline-tables-only"
+```
+
+## Reproduce the benchmark
+
+t0 — compiles both kernels, times the baseline, checks the trial vs the golden ref:
+
+```bash
+xe-forge-skill benchmark examples/sycl/gemm.cpp examples/sycl/gemm_t1.cpp \
+ --spec examples/sycl/gemm.yaml --dsl sycl --variant bench-xpu
+```
+
+```
+Correctness: PASSED
+Performance: baseline_us=193.80, triton_us=109.90, speedup=1.76x, tflops=19.54, util=12.2%
+```
+
+t1+ — reuse the cached baseline (no baseline recompile/rerun):
+
+```bash
+xe-forge-skill benchmark examples/sycl/gemm.cpp examples/sycl/gemm_t1.cpp \
+ --spec examples/sycl/gemm.yaml --dsl sycl --variant bench-xpu --baseline-us 193.80
+```
+
+The `triton_us=` token is kept verbatim across DSLs so the trial tooling parses
+uniformly; for SYCL it carries the optimized kernel's time in microseconds.
+`tflops=` is the achieved throughput and `util=` is its percentage of the
+device's theoretical peak (`peak_tflops`, default 160 TFLOPS bf16 for the B70;
+override with the `PEAK_TFLOPS` env var) — so `util=12.2%` means this trial
+reaches 12.2% of peak.
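The reported numbers can be re-derived by hand. The sketch below is an assumption about how the fields relate, not the harness's code: it parses a kernel timing line with the regex given in `sycl_io_contract.yaml`, then computes speedup, TFLOPS (using the spec's `2*M*N*K` flop count) and utilization. The function name and the sample line are hypothetical; the sample is constructed to match the README figures.

```python
import re

# Regex quoted in knowledge_base/sycl/xpu/sycl_io_contract.yaml.
PERF_RE = re.compile(r"\[([0-9.]+)\]\s*TFlop/s\s+\(([0-9.]+)\)\s*ms")


def summarize(kernel_stdout, baseline_us, m, n, k, peak_tflops=160.0):
    """Derive speedup, TFLOPS and utilization from a kernel's timing line."""
    match = PERF_RE.search(kernel_stdout)
    if match is None:
        raise ValueError("no '[x]TFlop/s (y)ms' line found")
    kernel_us = float(match.group(2)) * 1000.0  # ms -> us
    tflops = 2.0 * m * n * k / (kernel_us * 1e-6) / 1e12
    return {
        "kernel_us": kernel_us,
        "speedup": baseline_us / kernel_us,
        "tflops": tflops,
        "util_pct": 100.0 * tflops / peak_tflops,
    }


# Sample line built from the README's 109.90 us trial time.
line = "Cutlass GEMM Performance: [19.540]TFlop/s (0.1099)ms"
stats = summarize(line, baseline_us=193.80, m=1024, n=1024, k=1024)
print({key: round(value, 2) for key, value in stats.items()})
```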
+
+## Generate an agentic workspace
+
+```bash
+python -m xe_forge.cli --input examples/sycl/gemm.cpp --name gemm \
+ --dsl sycl --engine claude --spec examples/sycl/gemm.yaml \
+ --variant bench-xpu --workspace /tmp/ws_sycl
+```
+
+This scaffolds a SYCL `CLAUDE.md`, an `/optimize-kernel` command wired with
+`--dsl sycl`, the kernel + `gemm_pytorch.py` golden reference under
+`test_kernels/`, and a `knowledge_base/` symlink. Passing a PyTorch-only `.py`
+input instead substitutes a compilable starter `.cpp` (a copy of `gemm.cpp`)
+and uses the `.py` as the golden reference.
diff --git a/examples/sycl/gemm.cpp b/examples/sycl/gemm.cpp
new file mode 100644
index 0000000..9cf75d5
--- /dev/null
+++ b/examples/sycl/gemm.cpp
@@ -0,0 +1,289 @@
+/*
+ Example SYCL baseline kernel (t0) for the Xe-Forge Claude engine — a minimal,
+ compilable CUTLASS BMG GEMM (D = A * B0) that honours the Xe-Forge file-IO
+ contract. This is the same source the workspace generator emits as the
+ starter stub; the optimizer iterates from here (see gemm_t1.cpp for one such
+ trial that swaps the 256x256x32 tile for 128x128x32).
+
+ File-IO contract (see knowledge_base/sycl/xpu/sycl_io_contract.yaml):
+ --m/--n/--k problem dims
+ --input_dir=<dir> read A.bin [M,K], B0.bin [K,N] (bf16 as raw int16 bits)
+ --output_dir=<dir> write D2.bin [M,N] as float32, row-major
+ --iterations=<n> timed iterations
+ --verify=<0|1> ignored (correctness checked in Python vs golden ref)
+
+ Prints a "[<tflops>]TFlop/s (<ms>)ms" line that SyclExecutor parses.
+
+ Derived from sycl-tla/examples/00_bmg_gemm/00_bmg_gemm.cpp. CORRECT but
+ unoptimized — tune TileShape, SubgroupLayout, PipelineStages, copy atoms,
+ and the epilogue.
+*/
+
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/collective/xe_epilogue.hpp"
+#include "cutlass/epilogue/fusion/xe_callbacks.hpp"
+#include "cutlass/gemm/device/gemm_universal.h"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/collective/collective_mma.hpp"
+#include "cutlass/util/GPU_Clock.hpp"
+
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/device_memory.h"
+#include "cutlass/util/packed_stride.hpp"
+#include "cutlass/util/reference/device/gemm_complex.h"
+#include "cutlass/util/reference/device/tensor_compare.h"
+#include "sycl_common.hpp"
+#include "helper.h"
+
+using namespace cute;
+
+struct Options {
+ bool help = false;
+ bool error = false;
+ int m = 5120, n = 4096, k = 4096, l = 1, iterations = 20, verify = 1;
+ float alpha = 1.f, beta = 0.f;
+ std::string input_dir;
+ std::string output_dir;
+
+ void parse(int argc, char const **args) {
+ cutlass::CommandLine cmd(argc, args);
+ if (cmd.check_cmd_line_flag("help")) { help = true; return; }
+ cmd.get_cmd_line_argument("m", m, 5120);
+ cmd.get_cmd_line_argument("n", n, 4096);
+ cmd.get_cmd_line_argument("k", k, 4096);
+ cmd.get_cmd_line_argument("l", l, 1);
+ cmd.get_cmd_line_argument("alpha", alpha, 1.f);
+ cmd.get_cmd_line_argument("beta", beta, 0.f);
+ cmd.get_cmd_line_argument("iterations", iterations, 20);
+ cmd.get_cmd_line_argument("verify", verify, 1);
+ cmd.get_cmd_line_argument("input_dir", input_dir, std::string(""));
+ cmd.get_cmd_line_argument("output_dir", output_dir, std::string(""));
+ }
+};
+
+template <typename T>
+static std::vector<T> read_bin(const std::string& dir, const std::string& name, size_t count) {
+ std::string path = dir + "/" + name;
+ std::ifstream f(path, std::ios::binary);
+ if (!f) { std::cerr << "Could not open " << path << std::endl; std::exit(2); }
+ std::vector<T> buf(count);
+ f.read(reinterpret_cast<char*>(buf.data()), count * sizeof(T));
+ if (static_cast<size_t>(f.gcount()) != count * sizeof(T)) {
+ std::cerr << "Short read on " << path << ": got " << f.gcount()
+ << " want " << count * sizeof(T) << std::endl;
+ std::exit(2);
+ }
+ return buf;
+}
+
+template <typename T>
+static void write_bin(const std::string& dir, const std::string& name, const std::vector<T>& buf) {
+ std::string path = dir + "/" + name;
+ std::ofstream f(path, std::ios::binary);
+ if (!f) { std::cerr << "Could not open for write " << path << std::endl; std::exit(2); }
+ f.write(reinterpret_cast<const char*>(buf.data()), buf.size() * sizeof(T));
+}
+
+template <class Gemm>
+struct ExampleRunner {
+ using StrideA = typename Gemm::GemmKernel::StrideA;
+ using StrideB = typename Gemm::GemmKernel::StrideB;
+ using StrideC = typename Gemm::GemmKernel::StrideC;
+ using StrideD = typename Gemm::GemmKernel::StrideD;
+
+ using LayoutA = typename Gemm::LayoutA;
+ using LayoutB = typename Gemm::LayoutB;
+ using LayoutC = typename Gemm::LayoutC;
+ using LayoutD = typename Gemm::LayoutD;
+
+ using ElementA = typename Gemm::ElementA;
+ using ElementB = typename Gemm::ElementB;
+ using ElementAccumulator = typename Gemm::ElementAccumulator;
+
+ using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
+ using ElementC = typename Gemm::ElementC;
+ using ElementOutput = typename CollectiveEpilogue::ElementOutput;
+ using ElementCompute = typename CollectiveEpilogue::ElementCompute;
+
+ using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
+
+ StrideA stride_A;
+ StrideB stride_B;
+ StrideC stride_C;
+ StrideD stride_D;
+ uint64_t seed = 0;
+
+ cutlass::DeviceAllocation<ElementA> block_A;
+ cutlass::DeviceAllocation<ElementB> block_B;
+ cutlass::DeviceAllocation<ElementC> block_C;
+ cutlass::DeviceAllocation<ElementOutput> block_D;
+
+ void initialize(const ProblemShapeType& problem_size, const Options& options) {
+ auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
+ auto [M, N, K, L] = problem_shape_MNKL;
+
+ stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
+ stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
+ stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
+ stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
+
+ block_A.reset(static_cast<std::size_t>(M) * K * L);
+ block_B.reset(static_cast<std::size_t>(K) * N * L);
+ block_C.reset(static_cast<std::size_t>(M) * N * L);
+ block_D.reset(static_cast<std::size_t>(M) * N * L);
+
+ // A [M,K] and B0 [K,N] are raw bf16 bits (bit-identical to torch bf16).
+ auto host_A = read_bin<ElementA>(options.input_dir, "A.bin",
+ static_cast<std::size_t>(M) * K * L);
+ auto host_B = read_bin<ElementB>(options.input_dir, "B0.bin",
+ static_cast<std::size_t>(K) * N * L);
+ block_A.copy_from_host(host_A.data());
+ block_B.copy_from_host(host_B.data());
+
+ // C unused (beta = 0); zero-fill so the epilogue reads valid memory.
+ std::vector<ElementC> host_C(static_cast<std::size_t>(M) * N * L, ElementC(0));
+ block_C.copy_from_host(host_C.data());
+ }
+
+ void dump_output(const ProblemShapeType& problem_size, const Options& options) {
+ auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
+ auto [M, N, K, L] = problem_shape_MNKL;
+ std::vector<ElementOutput> host_D(static_cast<std::size_t>(M) * N * L);
+ block_D.copy_to_host(host_D.data());
+ write_bin(options.output_dir, "D2.bin", host_D); // ElementOutput = float32
+ }
+
+ cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
+ ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
+
+ initialize(problem_size, options);
+
+ typename Gemm::GemmKernel::Arguments arguments{
+ cutlass::gemm::GemmUniversalMode::kGemm,
+ problem_size,
+ {block_A.get(), stride_A, block_B.get(), stride_B},
+ {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D},
+ hw_info
+ };
+
+ Gemm gemm_op;
+
+ size_t workspace_size = Gemm::get_workspace_size(arguments);
+ cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+
+ if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess) {
+ std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x'
+ << options.k << 'x' << options.l << std::endl;
+ std::exit(1);
+ }
+
+ CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
+ CUTLASS_CHECK(gemm_op.run());
+ compat::wait();
+
+ if (!options.output_dir.empty()) {
+ dump_output(problem_size, options);
+ }
+ std::cout << "Disposition: Passed" << std::endl;
+
+ if (options.iterations > 0) {
+ GPU_Clock timer;
+ timer.start();
+ for (int i = 0; i < options.iterations; ++i) {
+ gemm_op.run();
+ }
+ compat::wait();
+
+ float cute_time = timer.seconds() / options.iterations;
+ double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12;
+ std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x'
+ << options.k << 'x' << options.l << std::endl;
+ printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n",
+ tflops / cute_time, cute_time * 1000);
+ }
+
+ return cutlass::Status::kSuccess;
+ }
+};
+
+int main(int argc, const char** argv) {
+ Options options;
+ options.parse(argc, argv);
+ if (options.help) { std::cout << "Xe-Forge SYCL starter GEMM\n"; return 0; }
+ if (options.error) { std::cerr << "Aborting execution." << std::endl; return -1; }
+
+ cutlass::KernelHardwareInfo hw_info;
+ hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+
+ using ElementAccumulator = float;
+ using ElementComputeEpilogue = float;
+ using ElementInputA = bfloat16_t;
+ using ElementInputB = bfloat16_t;
+ using ElementOutput = float;
+
+ using LayoutA = cutlass::layout::RowMajor;
+ using LayoutB = cutlass::layout::RowMajor;
+ using LayoutC = cutlass::layout::RowMajor;
+ using LayoutD = cutlass::layout::RowMajor;
+
+ using GmemTiledCopyA = void;
+ using GmemTiledCopyB = void;
+
+ // Workgroup tile — the primary thing to tune.
+ using TileShape = Shape<_256, _256, _32>;
+
+ using TiledMma = typename TiledMMAHelper>,
+ Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA;
+
+ constexpr int PipelineStages = 2;
+ using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
+ using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;
+
+ using EpilogueOp = cutlass::epilogue::fusion::LinearCombination;
+
+ using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks;
+
+ using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
+ EpilogueDispatchPolicy,
+ TileShape,
+ void,
+ ElementAccumulator,
+ cutlass::gemm::TagToStrideC_t<LayoutC>,
+ ElementOutput,
+ cutlass::gemm::TagToStrideC_t<LayoutD>,
+ FusionCallbacks,
+ void,
+ void>;
+
+ using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
+ GEMMDispatchPolicy,
+ TileShape,
+ ElementInputA,
+ cutlass::gemm::TagToStrideA_t<LayoutA>,
+ ElementInputB,
+ cutlass::gemm::TagToStrideB_t<LayoutB>,
+ TiledMma,
+ GmemTiledCopyA, void, void, cute::identity,
+ GmemTiledCopyB, void, void, cute::identity
+ >;
+
+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+ Shape<int, int, int, int>,
+ CollectiveMainloop,
+ CollectiveEpilogue
+ >;
+
+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+ ExampleRunner<Gemm> runner;
+ CUTLASS_CHECK(runner.run(options, hw_info));
+
+ return 0;
+}
diff --git a/examples/sycl/gemm.yaml b/examples/sycl/gemm.yaml
new file mode 100644
index 0000000..4a65ee6
--- /dev/null
+++ b/examples/sycl/gemm.yaml
@@ -0,0 +1,18 @@
+inputs:
+ A:
+ shape: [M, K]
+ dtype: bfloat16
+ B0:
+ shape: [K, N]
+ dtype: bfloat16
+
+bench-xpu:
+ - params: [A, B0]
+ dtype: bfloat16
+ dims:
+ M: 1024
+ N: 1024
+ K: 1024
+ flop: "2*M*N*K"
+ rtol: 0.02
+ atol: 0.02
diff --git a/examples/sycl/gemm_pytorch.py b/examples/sycl/gemm_pytorch.py
new file mode 100644
index 0000000..2940050
--- /dev/null
+++ b/examples/sycl/gemm_pytorch.py
@@ -0,0 +1,23 @@
+"""Golden PyTorch reference for a plain bf16 GEMM: D = A @ B0 (f32 accumulate)."""
+
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+ def forward(self, A, B0):
+ # f32 accumulate to match the kernel's ElementAccumulator = float.
+ return A.float() @ B0.float()
+
+
+def get_inputs():
+ # Shapes are illustrative; the benchmark harness feeds A/B0 from the .bin
+ # files generated for the spec's dims, not from here.
+ return [
+ torch.randn(1024, 1024, dtype=torch.bfloat16),
+ torch.randn(1024, 1024, dtype=torch.bfloat16),
+ ]
+
+
+def get_init_inputs():
+ return []
diff --git a/examples/sycl/gemm_t1.cpp b/examples/sycl/gemm_t1.cpp
new file mode 100644
index 0000000..552f346
--- /dev/null
+++ b/examples/sycl/gemm_t1.cpp
@@ -0,0 +1,287 @@
+/*
+ Example SYCL trial kernel (t1) for the Xe-Forge Claude engine — gemm.cpp with
+ a smaller workgroup tile (128x128x32 instead of 256x256x32). On an Intel Arc
+ Pro B70 at M=N=K=1024 this is ~1.7x faster than the t0 baseline while still
+ matching the PyTorch golden reference. Illustrates one optimization step the
+ agent takes; the file-IO scaffolding is unchanged from gemm.cpp.
+
+ File-IO contract (see knowledge_base/sycl/xpu/sycl_io_contract.yaml):
+ --m/--n/--k problem dims
+ --input_dir=<dir> read A.bin [M,K], B0.bin [K,N] (bf16 as raw int16 bits)
+ --output_dir=<dir> write D2.bin [M,N] as float32, row-major
+ --iterations=<n> timed iterations
+ --verify=<0|1> ignored (correctness checked in Python vs golden ref)
+
+ Prints a "[<tflops>]TFlop/s (<ms>)ms" line that SyclExecutor parses.
+
+ Derived from sycl-tla/examples/00_bmg_gemm/00_bmg_gemm.cpp.
+*/
+
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/collective/xe_epilogue.hpp"
+#include "cutlass/epilogue/fusion/xe_callbacks.hpp"
+#include "cutlass/gemm/device/gemm_universal.h"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/collective/collective_mma.hpp"
+#include "cutlass/util/GPU_Clock.hpp"
+
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/device_memory.h"
+#include "cutlass/util/packed_stride.hpp"
+#include "cutlass/util/reference/device/gemm_complex.h"
+#include "cutlass/util/reference/device/tensor_compare.h"
+#include "sycl_common.hpp"
+#include "helper.h"
+
+using namespace cute;
+
+struct Options {
+ bool help = false;
+ bool error = false;
+ int m = 5120, n = 4096, k = 4096, l = 1, iterations = 20, verify = 1;
+ float alpha = 1.f, beta = 0.f;
+ std::string input_dir;
+ std::string output_dir;
+
+ void parse(int argc, char const **args) {
+ cutlass::CommandLine cmd(argc, args);
+ if (cmd.check_cmd_line_flag("help")) { help = true; return; }
+ cmd.get_cmd_line_argument("m", m, 5120);
+ cmd.get_cmd_line_argument("n", n, 4096);
+ cmd.get_cmd_line_argument("k", k, 4096);
+ cmd.get_cmd_line_argument("l", l, 1);
+ cmd.get_cmd_line_argument("alpha", alpha, 1.f);
+ cmd.get_cmd_line_argument("beta", beta, 0.f);
+ cmd.get_cmd_line_argument("iterations", iterations, 20);
+ cmd.get_cmd_line_argument("verify", verify, 1);
+ cmd.get_cmd_line_argument("input_dir", input_dir, std::string(""));
+ cmd.get_cmd_line_argument("output_dir", output_dir, std::string(""));
+ }
+};
+
+template <typename T>
+static std::vector<T> read_bin(const std::string& dir, const std::string& name, size_t count) {
+ std::string path = dir + "/" + name;
+ std::ifstream f(path, std::ios::binary);
+ if (!f) { std::cerr << "Could not open " << path << std::endl; std::exit(2); }
+ std::vector<T> buf(count);
+ f.read(reinterpret_cast<char*>(buf.data()), count * sizeof(T));
+ if (static_cast<size_t>(f.gcount()) != count * sizeof(T)) {
+ std::cerr << "Short read on " << path << ": got " << f.gcount()
+ << " want " << count * sizeof(T) << std::endl;
+ std::exit(2);
+ }
+ return buf;
+}
+
+template <typename T>
+static void write_bin(const std::string& dir, const std::string& name, const std::vector<T>& buf) {
+ std::string path = dir + "/" + name;
+ std::ofstream f(path, std::ios::binary);
+ if (!f) { std::cerr << "Could not open for write " << path << std::endl; std::exit(2); }
+ f.write(reinterpret_cast<const char*>(buf.data()), buf.size() * sizeof(T));
+}
+
+template <class Gemm>
+struct ExampleRunner {
+ using StrideA = typename Gemm::GemmKernel::StrideA;
+ using StrideB = typename Gemm::GemmKernel::StrideB;
+ using StrideC = typename Gemm::GemmKernel::StrideC;
+ using StrideD = typename Gemm::GemmKernel::StrideD;
+
+ using LayoutA = typename Gemm::LayoutA;
+ using LayoutB = typename Gemm::LayoutB;
+ using LayoutC = typename Gemm::LayoutC;
+ using LayoutD = typename Gemm::LayoutD;
+
+ using ElementA = typename Gemm::ElementA;
+ using ElementB = typename Gemm::ElementB;
+ using ElementAccumulator = typename Gemm::ElementAccumulator;
+
+ using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
+ using ElementC = typename Gemm::ElementC;
+ using ElementOutput = typename CollectiveEpilogue::ElementOutput;
+ using ElementCompute = typename CollectiveEpilogue::ElementCompute;
+
+ using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
+
+ StrideA stride_A;
+ StrideB stride_B;
+ StrideC stride_C;
+ StrideD stride_D;
+ uint64_t seed = 0;
+
+ cutlass::DeviceAllocation<ElementA> block_A;
+ cutlass::DeviceAllocation<ElementB> block_B;
+ cutlass::DeviceAllocation<ElementC> block_C;
+ cutlass::DeviceAllocation<ElementOutput> block_D;
+
+ void initialize(const ProblemShapeType& problem_size, const Options& options) {
+ auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
+ auto [M, N, K, L] = problem_shape_MNKL;
+
+ stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
+ stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
+ stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
+ stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
+
+ block_A.reset(static_cast<std::size_t>(M) * K * L);
+ block_B.reset(static_cast<std::size_t>(K) * N * L);
+ block_C.reset(static_cast<std::size_t>(M) * N * L);
+ block_D.reset(static_cast<std::size_t>(M) * N * L);
+
+ // A [M,K] and B0 [K,N] are raw bf16 bits (bit-identical to torch bf16).
+ auto host_A = read_bin<ElementA>(options.input_dir, "A.bin",
+ static_cast<std::size_t>(M) * K * L);
+ auto host_B = read_bin<ElementB>(options.input_dir, "B0.bin",
+ static_cast<std::size_t>(K) * N * L);
+ block_A.copy_from_host(host_A.data());
+ block_B.copy_from_host(host_B.data());
+
+ // C unused (beta = 0); zero-fill so the epilogue reads valid memory.
+ std::vector<ElementC> host_C(static_cast<std::size_t>(M) * N * L, ElementC(0));
+ block_C.copy_from_host(host_C.data());
+ }
+
+ void dump_output(const ProblemShapeType& problem_size, const Options& options) {
+ auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
+ auto [M, N, K, L] = problem_shape_MNKL;
+ std::vector<ElementOutput> host_D(static_cast<std::size_t>(M) * N * L);
+ block_D.copy_to_host(host_D.data());
+ write_bin(options.output_dir, "D2.bin", host_D); // ElementOutput = float32
+ }
+
+ cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
+ ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
+
+ initialize(problem_size, options);
+
+ typename Gemm::GemmKernel::Arguments arguments{
+ cutlass::gemm::GemmUniversalMode::kGemm,
+ problem_size,
+ {block_A.get(), stride_A, block_B.get(), stride_B},
+ {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D},
+ hw_info
+ };
+
+ Gemm gemm_op;
+
+ size_t workspace_size = Gemm::get_workspace_size(arguments);
+ cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+
+ if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess) {
+ std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x'
+ << options.k << 'x' << options.l << std::endl;
+ std::exit(1);
+ }
+
+ CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
+ CUTLASS_CHECK(gemm_op.run());
+ compat::wait();
+
+ if (!options.output_dir.empty()) {
+ dump_output(problem_size, options);
+ }
+ std::cout << "Disposition: Passed" << std::endl;
+
+ if (options.iterations > 0) {
+ GPU_Clock timer;
+ timer.start();
+ for (int i = 0; i < options.iterations; ++i) {
+ gemm_op.run();
+ }
+ compat::wait();
+
+ float cute_time = timer.seconds() / options.iterations;
+ double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12;
+ std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x'
+ << options.k << 'x' << options.l << std::endl;
+ printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n",
+ tflops / cute_time, cute_time * 1000);
+ }
+
+ return cutlass::Status::kSuccess;
+ }
+};
+
+int main(int argc, const char** argv) {
+ Options options;
+ options.parse(argc, argv);
+ if (options.help) { std::cout << "Xe-Forge SYCL starter GEMM\n"; return 0; }
+ if (options.error) { std::cerr << "Aborting execution." << std::endl; return -1; }
+
+ cutlass::KernelHardwareInfo hw_info;
+ hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+
+ using ElementAccumulator = float;
+ using ElementComputeEpilogue = float;
+ using ElementInputA = bfloat16_t;
+ using ElementInputB = bfloat16_t;
+ using ElementOutput = float;
+
+ using LayoutA = cutlass::layout::RowMajor;
+ using LayoutB = cutlass::layout::RowMajor;
+ using LayoutC = cutlass::layout::RowMajor;
+ using LayoutD = cutlass::layout::RowMajor;
+
+ using GmemTiledCopyA = void;
+ using GmemTiledCopyB = void;
+
+ // Workgroup tile — the primary thing to tune.
+ using TileShape = Shape<_128, _128, _32>;
+
+ using TiledMma = typename TiledMMAHelper>,
+ Layout, Layout, Stride<_4, _1, _0>>>::TiledMMA;
+
+ constexpr int PipelineStages = 2;
+ using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
+ using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;
+
+ using EpilogueOp = cutlass::epilogue::fusion::LinearCombination;
+
+ using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks;
+
+ using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
+ EpilogueDispatchPolicy,
+ TileShape,
+ void,
+ ElementAccumulator,
+ cutlass::gemm::TagToStrideC_t<LayoutC>,
+ ElementOutput,
+ cutlass::gemm::TagToStrideC_t<LayoutD>,
+ FusionCallbacks,
+ void,
+ void>;
+
+ using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
+ GEMMDispatchPolicy,
+ TileShape,
+ ElementInputA,
+ cutlass::gemm::TagToStrideA_t<LayoutA>,
+ ElementInputB,
+ cutlass::gemm::TagToStrideB_t<LayoutB>,
+ TiledMma,
+ GmemTiledCopyA, void, void, cute::identity,
+ GmemTiledCopyB, void, void, cute::identity
+ >;
+
+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+ Shape<int, int, int, int>,
+ CollectiveMainloop,
+ CollectiveEpilogue
+ >;
+
+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+ ExampleRunner<Gemm> runner;
+ CUTLASS_CHECK(runner.run(options, hw_info));
+
+ return 0;
+}
diff --git a/knowledge_base/sycl/xpu/sycl_io_contract.yaml b/knowledge_base/sycl/xpu/sycl_io_contract.yaml
new file mode 100644
index 0000000..0d77dff
--- /dev/null
+++ b/knowledge_base/sycl/xpu/sycl_io_contract.yaml
@@ -0,0 +1,121 @@
+# Xe-Forge SYCL runner harness contract for Intel Xe (Battlemage)
+#
+# This documents the file-IO + CLI contract that every SYCL kernel optimized by
+# the Claude Code engine MUST honour. The contract is enforced by SyclExecutor
+# (src/xe_forge/core/sycl_executor.py): inputs are generated once as .bin files,
+# the kernel reads them, computes, and dumps its output as a .bin file that is
+# compared in Python against a PyTorch/numpy golden reference.
+#
+# A kernel that does not honour this contract cannot be benchmarked — it will be
+# reported as a correctness failure ("no D2.bin produced").
+
+io_contract:
+ - id: sycl_cli_arguments
+ name: "Required command-line arguments"
+ description: |
+ The kernel binary is invoked as:
+
+ ./kernel --m=<M> --n=<N> --k=<K> \
+ --input_dir=<dir> --output_dir=<dir> \
+ --iterations=<n> --verify=<0|1>
+
+ Parse them with cutlass::CommandLine (see cutlass/util/command_line.h):
+
+ cutlass::CommandLine cmd(argc, args);
+ cmd.get_cmd_line_argument("m", m, 5120);
+ cmd.get_cmd_line_argument("n", n, 4096);
+ cmd.get_cmd_line_argument("k", k, 4096);
+ cmd.get_cmd_line_argument("iterations", iterations, 20);
+ cmd.get_cmd_line_argument("verify", verify, 1);
+ std::string input_dir, output_dir;
+ cmd.get_cmd_line_argument("input_dir", input_dir, std::string(""));
+ cmd.get_cmd_line_argument("output_dir", output_dir, std::string(""));
+
+ `--verify` is supplied but should be IGNORED for file-IO runs: the runner
+ sets verify=0 because correctness is checked in Python against the golden
+ reference, not by the kernel's internal reference GEMM.
+ applies_to: [gemm, all_sycl_kernels]
+
+ - id: sycl_input_layout
+ name: "Input tensor .bin layout (read from --input_dir)"
+ description: |
+ The runner writes these files into --input_dir before launching the kernel
+ (generated once, deterministic seed=42, cached across trials):
+
+ A.bin : shape [M, K], row-major
+ B0.bin : shape [K, N], row-major
+ B1.bin : shape [K, N], row-major (dual-GEMM second operand; ignore for
+ plain GEMM)
+
+ Element encoding by dtype:
+ * bfloat16 : raw 16-bit bf16 bit-pattern (torch stores bf16 as int16
+ bits; cute::bfloat16_t is bit-identical, so a raw byte read into a
+ bfloat16_t* buffer is correct — NO conversion needed).
+ * float16 : raw IEEE half (2 bytes).
+ * float32 : raw IEEE float (4 bytes).
+
+ Read into host memory then copy to a cutlass::DeviceAllocation:
+
+ std::ifstream f(input_dir + "/A.bin", std::ios::binary);
+ std::vector<ElementA> hostA(size_t(M) * K);
+ f.read(reinterpret_cast<char*>(hostA.data()), hostA.size() * sizeof(ElementA));
+ block_A.copy_from_host(hostA.data());
+
+ Matrix A is the left operand [M,K]; B0 is the right operand [K,N]; the GEMM
+ computes D = A * B0 (alpha=1, beta=0 by default).
+ applies_to: [gemm, dual_gemm]
+
+ - id: sycl_output_layout
+ name: "Output tensor .bin layout (write to --output_dir)"
+ description: |
+ After the (untimed) correctness run, dump the result to --output_dir:
+
+ D2.bin : shape [M, N], row-major, FLOAT32 (raw IEEE float, M*N*4 bytes)
+
+ Always store D2.bin as float32 regardless of the input dtype — the golden
+ reference is compared in float32. Copy device output back to host first:
+
+ std::vector<float> hostD(size_t(M) * N);
+ block_D.copy_to_host(hostD.data()); // ElementOutput must be float
+ std::ofstream o(output_dir + "/D2.bin", std::ios::binary);
+ o.write(reinterpret_cast<const char*>(hostD.data()), hostD.size() * sizeof(float));
+
+ Only write D2.bin when --output_dir is non-empty. The name is literally
+ "D2.bin" (a historical dual-GEMM artifact; it is the single GEMM output D).
+ applies_to: [gemm, dual_gemm]
+
+ - id: sycl_perf_line
+ name: "Performance output line (parsed by SyclExecutor)"
+ description: |
+ Print a timing line in EXACTLY this format so _parse_raw_output matches it
+ (regex: `\[([0-9.]+)\]\s*TFlop/s\s+\(([0-9.]+)\)\s*ms`):
+
+ printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n",
+ tflops / cute_time, cute_time * 1000);
+
+ where tflops = 2*M*N*K*L * 1e-12 is the work per iteration in TFLOP (not a rate) and cute_time is seconds per iteration.
+ Also print "Disposition: Passed" — though the Python golden check is
+ authoritative, the parser reads this token.
+
+ Time the kernel over --iterations launches with GPU_Clock + compat::wait(),
+ NOT including the file IO or the correctness run.
+ applies_to: [gemm, all_sycl_kernels]
+
+ - id: sycl_correctness_rules
+ name: "SYCL/CUTLASS correctness constraints (Xe Battlemage)"
+ description: |
+ 1. TileShape (e.g. Shape<_256,_256,_32>) and SubgroupLayout (e.g.
+ Shape<_8,_4,_1>) must be consistent with the DPAS MMA atom
+ (XE_DPAS_TT<8, float, bfloat16_t> = 8x16x16). M-per-subgroup =
+ TileM / SgRows, N-per-subgroup = TileN / SgCols must be integers.
+ 2. SLM budget on Battlemage is 128 KB/work-group. Double-buffered staging
+ of A+B tiles must fit: 2*(TileM*TileK + TileK*TileN)*sizeof(elem).
+ 3. Accumulate in float32 (ElementAccumulator = float) for bf16/f16 inputs.
+ 4. bf16 inputs use VNNI-friendly layouts internally; the COLLECTIVE builder
+ handles this — keep LayoutA/B = RowMajor for the A[M,K], B[K,N] contract.
+ 5. ElementOutput MUST be float for D2.bin (the golden reference is f32).
+
+ Tolerances: a correct bf16 GEMM matches an f32 PyTorch matmul at
+ rtol=1e-2, atol=1e-2 (empirically max abs diff ~2e-5 at K=512; scales with
+ K). These are spec-driven via the YAML rtol/atol fields.
+ applies_to: [gemm, bf16, correctness]
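
The perf-line contract in `sycl_perf_line` can be checked offline with a short Python sketch. The regex is the one quoted in that entry; the helper name `parse_perf_line` and the sample string are hypothetical and are not taken from `SyclExecutor._parse_raw_output`, whose actual body is not shown in this diff.

```python
import re

# Regex copied from the sycl_perf_line entry; everything else here is illustrative.
PERF_RE = re.compile(r"\[([0-9.]+)\]\s*TFlop/s\s+\(([0-9.]+)\)\s*ms")


def parse_perf_line(stdout: str):
    """Return (tflops, milliseconds) from the first matching line, or None."""
    match = PERF_RE.search(stdout)
    if match is None:
        return None
    return float(match.group(1)), float(match.group(2))


# Hypothetical kernel output in the documented printf format.
sample = (
    "Disposition: Passed\n"
    "Cutlass GEMM Performance: [12.345]TFlop/s (0.1234)ms\n"
)
print(parse_perf_line(sample))
```

Note that the pattern requires at least one whitespace character between `TFlop/s` and the opening parenthesis, so a kernel that drops that space would not be parsed.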
diff --git a/knowledge_base/sycl/xpu/sycl_vtune.yaml b/knowledge_base/sycl/xpu/sycl_vtune.yaml
new file mode 100644
index 0000000..dda914f
--- /dev/null
+++ b/knowledge_base/sycl/xpu/sycl_vtune.yaml
@@ -0,0 +1,123 @@
+# VTune GPU profiling for SYCL/CUTLASS kernels on Intel Xe (Battlemage)
+#
+# How Xe-Forge profiles a compiled SYCL GEMM kernel and how to turn the VTune
+# gpu-hotspots hardware metrics into CUTLASS tuning decisions. Used by
+# core/sycl_profiler.py (SyclProfiler) and surfaced to the agent in the Profile
+# step of CLAUDE.md. Verbs/columns confirmed on VTune 2026.0 + Arc Pro B70.
+
+profiling_harness:
+ - id: sycl_vtune_collect
+ name: "How a SYCL kernel is profiled (gpu-hotspots characterization)"
+ description: |
+ The kernel is compiled (icpx, via SyclExecutor) and run as a standalone
+ binary directly under VTune — there is NO Python runner (unlike the Triton
+ path's gpu-offload). The binary loops --iterations internally; VTune
+ samples GPU hardware metrics over that loop.
+
+ Collect:
+ vtune -collect gpu-hotspots \
+ -knob gpu-profiling-mode=characterization \
+ -knob characterization-mode=overview \
+ -result-dir \
+ -- --m=M --n=N --k=K \
+ --input_dir= --output_dir= --iterations=200 --verify=0
+
+ Report (per GPU kernel):
+ vtune -report hotspots -r -group-by computing-task \
+ -column "" -format csv -csv-delimiter tab
+
+ Notes / gotchas:
+ - Use a high --iterations (e.g. 200) for stable sampling; correctness is
+ NOT checked here (golden comparison happens during benchmarking).
+ - The CUTLASS kernel appears under its full templated GemmUniversal<...>
+ type name; overhead tasks (zeCommandListAppendMemoryCopy, "[Outside any
+ task]") are filtered out — the primary kernel is the longest-running
+ non-overhead computing task.
+ - Column names containing a literal comma ("GPU Memory Bandwidth,
+ GB/sec:Read") must be requested in a SEPARATE report pass, or the
+ -column comma-list parser mis-splits them.
+ - VTune's vtune-self-checker.sh may FALSELY report GPU profiling as
+ unsupported on this box (its bundled DPC++ test app fails to launch);
+ a real AOT bmg-g31 kernel profiles correctly regardless.
+ applies_to: [gemm, profiling, vtune]
+
+metrics_to_knobs:
+ - id: vtune_memory_bound
+ name: "XVE Stalled > Active -> memory-bound mainloop"
+ description: |
+ When XVE Array:Stalled exceeds XVE Array:Active, execution units are
+ waiting on data rather than computing. For a CUTLASS Xe GEMM this points
+ at the mainloop's global->register/SLM feed.
+
+ Actions (in rough order):
+ - Increase PipelineStages (deeper A/B prefetch ahead of the DPAS).
+ e.g. MainloopXeL1Staged<3> or <4> instead of <2>.
+ - Use explicit 2D-block / VNNI copy atoms (XE_LOAD_2D / XE_LOAD_2D_VNNI)
+ instead of the auto (void) copy.
+ - Reduce TileK so each K-iteration's load is smaller / better overlapped.
+ applies_to: [gemm, memory_bound]
+
+ - id: vtune_low_occupancy
+ name: "Peak occupancy < 50% -> too few threads resident"
+ description: |
+ Low Peak XVE Threads Occupancy means the GPU's 256 XVEs aren't filled.
+ Check the occupancy sub-limiters VTune reports:
+ - Work Size Limit low -> the grid is too small for the problem (common
+ for small M/N, or an over-large TileShape that yields too few
+ work-groups). Try a SMALLER TileShape (e.g. 256x256 -> 128x128) so more
+ work-groups are launched.
+ - SLM Use Limit / Barriers Use Limit low -> per-work-group SLM or barrier
+ pressure caps residency; shrink staged tiles or stages.
+ Also confirm 256-GRF mode is on (IGC_ExtraOCLOptions). For genuinely small
+ problems, work-size-limited occupancy is expected, not a bug.
+ applies_to: [gemm, occupancy]
+
+ - id: vtune_high_idle
+ name: "XVE Idle > 30% -> work-distribution / tail effects"
+ description: |
+ High XVE Array:Idle means the XVEs have no work scheduled, usually because of
+ uneven tile distribution or tail work-groups. Revisit TileShape vs the M/N
+ extents so tiles divide evenly; consider a stream-K or persistent tile
+ scheduler to balance load across XVEs.
+ applies_to: [gemm, scheduling]
+
+ - id: vtune_low_xmx
+ name: "XMX (DPAS) active low -> matrix engine underutilized"
+ description: |
+ XVE Pipelines:XMX active is the fraction of time the DPAS/XMX matrix engine
+ is busy — the metric that most directly reflects GEMM efficiency. Low XMX%
+ with high stall/idle means the kernel rarely reaches the matmul. Raise
+ arithmetic intensity per subgroup: larger N-per-subgroup, verify the
+ SubgroupLayout (e.g. Shape<_8,_4,_1>) is consistent with the DPAS atom
+ (XE_DPAS_TT<8, float, bfloat16_t> = 8x16x16), and prefer larger tiles when
+ occupancy allows.
+ applies_to: [gemm, dpas, compute_bound]
+
+ - id: vtune_l3_thrashing
+ name: "L3 miss ratio high -> cache thrashing / poor reuse"
+ description: |
+ High GPU L3:Miss Ratio indicates the working set isn't being reused in L3.
+ Reduce tile sizes or improve K-blocking / data reuse so A and B sub-tiles
+ stay resident across the inner loop.
+ applies_to: [gemm, cache]
+
+ - id: vtune_bandwidth_bound
+ name: "GPU memory bandwidth near peak with low TFLOPS"
+ description: |
+ If GPU Memory Bandwidth (Read+Write) is near the device peak while TFLOPS /
+ XMX% stay low, the kernel is genuinely bandwidth-bound — no tile change will
+ help much. Accept the result for that shape, or change the algorithm
+ (e.g. fuse to cut traffic). Cross-check against the util% from benchmarking
+ (peak ~160 TFLOPS bf16 on the B70).
+ applies_to: [gemm, bandwidth_bound]
+
+reference_reading:
+ description: |
+ Example gpu-hotspots reading for the starter GEMM (TileShape 256x256x32,
+ M=N=K=1024) on the Arc Pro B70: XVE Active 7.3%, Stalled 39.4%, Idle 53.3%,
+ Peak occupancy 25% (work-size-limited), XMX active 5%, L3 miss 0.7%. Diagnosis:
+ memory-bound AND under-occupied — the 256x256 tile produces too few
+ work-groups at 1024^3 to fill 256 XVEs, and the matrix engine is idle most of
+ the time. Consistent with the ~12-16% peak utilization seen in benchmarking;
+ a smaller tile (128x128x32) both raises occupancy and improves measured
+ throughput.
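
As a rough illustration of how the `metrics_to_knobs` thresholds combine, the following Python sketch applies three of the rules (Stalled > Active, peak occupancy < 50%, Idle > 30%) to the reference reading above. The function name `diagnose` and the returned labels are hypothetical; the real mapping lives in `core/sycl_profiler.py`, which this diff does not show.

```python
def diagnose(active: float, stalled: float, idle: float, occupancy: float) -> list:
    """Map XVE percentages to rule ids from metrics_to_knobs (illustrative only)."""
    findings = []
    if stalled > active:      # vtune_memory_bound: stalled exceeds active
        findings.append("vtune_memory_bound")
    if occupancy < 50.0:      # vtune_low_occupancy: peak occupancy below 50%
        findings.append("vtune_low_occupancy")
    if idle > 30.0:           # vtune_high_idle: idle above 30%
        findings.append("vtune_high_idle")
    return findings


# Reference reading: Active 7.3%, Stalled 39.4%, Idle 53.3%, occupancy 25%.
print(diagnose(7.3, 39.4, 53.3, 25.0))
```

For the reference reading all three rules fire, which matches the "memory-bound AND under-occupied" diagnosis and also flags the high idle share.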
diff --git a/src/xe_forge/claude/generator.py b/src/xe_forge/claude/generator.py
index 86eadb7..e93d75d 100644
--- a/src/xe_forge/claude/generator.py
+++ b/src/xe_forge/claude/generator.py
@@ -31,6 +31,24 @@ def _render(template_name: str, **context: object) -> str:
return _env.get_template(template_name).render(**context)
+def _read_template_raw(template_name: str) -> str:
+ """Read a template file verbatim, bypassing Jinja.
+
+ Used for the C++ starter kernel: its initializer lists contain ``{{ }}``
+ which would collide with Jinja's variable delimiters. The starter kernel
+ needs no substitution, so raw text is exactly right.
+ """
+ return (_TEMPLATES_DIR / template_name).read_text()
+
+
+# DSLs whose kernel source is C++ (.cpp) rather than Python (.py).
+_CPP_DSLS = ("sycl", "cuda")
+
+
+def _kernel_ext(dsl: str) -> str:
+ return ".cpp" if dsl in _CPP_DSLS else ".py"
+
+
def generate_workspace(
workspace: Path,
config: Config,
@@ -47,12 +65,21 @@ def generate_workspace(
dsl = config.device_config.dsl
device = config.device_config.device
+ # Select DSL-specific templates. SYCL gets its own CLAUDE.md / command
+ # (C++ workflow, file-IO contract, no PyTorch-AST analyze step); other DSLs
+ # keep the existing Triton/PyTorch templates.
+ is_sycl = dsl == "sycl"
+ claude_template = "CLAUDE.sycl.md.j2" if is_sycl else "CLAUDE.md.j2"
+ optimize_template = "optimize-kernel.sycl.md.j2" if is_sycl else "optimize-kernel.md.j2"
+ ext = _kernel_ext(dsl)
+
(workspace / "CLAUDE.md").write_text(
_render(
- "CLAUDE.md.j2",
+ claude_template,
dsl=dsl,
device=device,
kernel_name=kernel_name,
+ kernel_ext=ext,
)
)
(workspace / "config.yaml").write_text(
@@ -66,13 +93,13 @@ def generate_workspace(
cmd_dir = workspace / ".claude" / "commands"
cmd_dir.mkdir(parents=True, exist_ok=True)
- (cmd_dir / "optimize-kernel.md").write_text(_render("optimize-kernel.md.j2", dsl=dsl))
+ (cmd_dir / "optimize-kernel.md").write_text(_render(optimize_template, dsl=dsl, kernel_ext=ext))
agent_dir = workspace / ".claude" / "agents"
agent_dir.mkdir(parents=True, exist_ok=True)
(agent_dir / "tool-runner.md").write_text(_render("tool-runner.md.j2"))
- _write_kernel_files(workspace, kernel_name, kernel_code, reference_code, spec_path)
+ _write_kernel_files(workspace, kernel_name, kernel_code, reference_code, spec_path, dsl)
_symlink_knowledge_base(workspace)
if config.engine.git_init:
@@ -85,11 +112,24 @@ def _write_kernel_files(
kernel_code: str,
reference_code: str,
spec_path: str | None,
+ dsl: str = "triton",
) -> None:
tk_dir = workspace / "test_kernels"
tk_dir.mkdir(parents=True, exist_ok=True)
- (tk_dir / f"{kernel_name}.py").write_text(kernel_code)
+ ext = _kernel_ext(dsl)
+
+ if dsl in _CPP_DSLS:
+ # If the input is PyTorch-only or spec-only (no C++ source), substitute a
+ # compilable starter stub honouring the file-IO contract so the t0
+ # baseline produces a real timing.
+ if "#include" not in kernel_code:
+ kernel_code = _read_template_raw("starter_kernel.sycl.cpp.j2")
+
+ (tk_dir / f"{kernel_name}{ext}").write_text(kernel_code)
+
+ # The reference is PyTorch even when the kernel is C++; always write it as
+ # {name}_pytorch.py so the benchmark skill can locate the golden reference.
if reference_code:
(tk_dir / f"{kernel_name}_pytorch.py").write_text(reference_code)
if spec_path and Path(spec_path).exists():
diff --git a/src/xe_forge/claude/templates/CLAUDE.sycl.md.j2 b/src/xe_forge/claude/templates/CLAUDE.sycl.md.j2
new file mode 100644
index 0000000..7e11265
--- /dev/null
+++ b/src/xe_forge/claude/templates/CLAUDE.sycl.md.j2
@@ -0,0 +1,108 @@
+# Xe-Forge SYCL Kernel Optimizer
+
+Optimize kernels into high-performance **{{ dsl | upper }}** (CUTLASS SYCL C++) implementations for **{{ device | upper }}** (Intel Xe / Battlemage).
+
+## CONFIGURATION — Read `config.yaml` first
+
+At the start of every session, read `config.yaml` in the workspace root. It controls:
+- **`max_trials`** — hard cap on optimization trials
+- **`vtune_enabled`** — whether VTune profiling is available this session
+- **`vtune_bin`** — path to VTune binary
+
+All runtime behavior below is gated by these values; re-read `config.yaml` if anything looks off.
+
+## RULES — NEVER VIOLATE
+
+1. **ONLY create** kernel files (trial files `t{{ kernel_ext }}` or output files).
+2. **NEVER create** benchmark scripts, test scripts, helper utilities, or any other files.
+3. **NEVER write custom scripts** to measure performance — ONLY use `xe-forge-skill benchmark`.
+4. If a tool fails, **STOP and report the error**. Do NOT work around it with custom scripts.
+5. Generated kernels must be **self-contained, standalone C++ executables** with a `main()` — all logic in one `.cpp`.
+6. You **MUST run all `max_trials` trials** (as set in `config.yaml`). Do NOT stop early due to plateau. Only stop early if speedup > 5x.
+
+## THE FILE-IO CONTRACT — every kernel MUST honour it
+
+The kernel is a standalone executable invoked by the harness as:
+
+```
+./kernel --m= --n= --k= --input_dir= --output_dir= --iterations= --verify=
+```
+
+It MUST:
+1. Parse all of `--m --n --k --input_dir --output_dir --iterations --verify` (use `cutlass::CommandLine`).
+2. Read inputs from `--input_dir`: `A.bin` `[M,K]`, `B0.bin` `[K,N]` — raw row-major, **bf16 stored as raw int16 bits** (read straight into a `bfloat16_t*` buffer, no conversion). `B1.bin` `[K,N]` exists too (dual-GEMM second operand); ignore it for plain GEMM.
+3. Compute `D = A * B0` (alpha=1, beta=0).
+4. Write the output to `--output_dir/D2.bin`: `[M,N]`, **float32**, raw row-major. Always f32 regardless of input dtype.
+5. Print a perf line in exactly this format (parsed by the harness):
+ `printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops/time, time*1000);`
+ and a `Disposition: Passed` line.
+
+`--verify` is supplied but should be **ignored** for file-IO runs — correctness is checked in Python via `numpy.allclose` against a **PyTorch golden reference** computed on the same bit-identical inputs, NOT by the kernel's internal reference GEMM.
+
+The starter kernel in `test_kernels/{{ kernel_name }}{{ kernel_ext }}` already implements this contract correctly — read it first and keep the IO scaffolding intact; only change the GEMM configuration (TileShape, SubgroupLayout, PipelineStages, copy atoms, epilogue).
+
+Full details: `knowledge_base/sycl/xpu/sycl_io_contract.yaml`.
+
+## MANDATORY TOOLS — Use these and ONLY these
+
+**Delegate all tool execution** to the `tool-runner` agent to keep the main context clean.
+
+**CRITICAL — Single-XPU serialization**: There is only ONE XPU. NEVER dispatch multiple tool-runner agents in parallel if any runs `benchmark` or `profile`. These GPU workloads must execute strictly one at a time.
+
+| Task | Command |
+|------|---------|
+| **Validate** | `xe-forge-skill validate --dsl {{ dsl }}` |
+| **Benchmark** | `xe-forge-skill benchmark --spec --dsl {{ dsl }} --variant bench-xpu [--baseline-us ]` |
+| **Init trials** | `xe-forge-skill trial init ` |
+| **Save trial** | `xe-forge-skill trial save [--parent ] [--strategy "..."]` |
+| **Record result** | `xe-forge-skill trial result --correctness --speedup --baseline-us --triton-us --tflops ` |
+| **Check status** | `xe-forge-skill trial status ` |
+| **Best trial** | `xe-forge-skill trial best ` |
+| **Baseline time** | `xe-forge-skill trial baseline-us ` |
+| **Finalize** | `xe-forge-skill trial finalize ` |
+| **Profile** (only when `vtune_enabled: true`) | `xe-forge-skill profile --spec --dsl {{ dsl }} --variant bench-xpu` |
+
+Note: `--triton-us` and `triton_us=` are kept as uniform token names across all DSLs — for SYCL they carry the **optimized kernel's** time in microseconds. The benchmark `Performance:` line also reports `tflops=` (achieved throughput) and `util=` (percentage of the device's theoretical peak, ~160 TFLOPS bf16 on the B70) — record `tflops` and track utilization across trials.
+
+## WORKFLOW — Follow these steps in order
+
+### Step 1: Analyze
+- Read the baseline kernel `test_kernels/{{ kernel_name }}{{ kernel_ext }}` — note the TileShape, SubgroupLayout, dispatch policy, dtypes, and the IO scaffolding.
+- Read the PyTorch golden reference `test_kernels/{{ kernel_name }}_pytorch.py` (if present) to understand the exact math the kernel must reproduce.
+- Read `knowledge_base/sycl/xpu/` patterns: `cutlass_sycl_framework.yaml` (tile configs, dispatch policies, epilogues), `xetla_patterns.yaml`, and `sycl_io_contract.yaml` (the harness contract).
+
+### Step 2: Initialize
+```bash
+xe-forge-skill trial init {{ kernel_name }} test_kernels/{{ kernel_name }}{{ kernel_ext }}
+```
+
+### Step 3: Trial Loop (always run all `max_trials` trials from `config.yaml`)
+For each trial:
+1. **Write kernel** — start from the baseline `.cpp` (or the best prior trial) and change the GEMM config. Keep the IO contract intact. Write to `t.cpp`.
+2. **Validate** — `xe-forge-skill validate --dsl {{ dsl }}` (fix until no errors; heed the `missing_io_contract` warning).
+3. **Save** — `xe-forge-skill trial save {{ kernel_name }} --parent --strategy "description"`.
+4. **Benchmark** (MANDATORY every trial):
+ - **Trial t0:** `xe-forge-skill benchmark test_kernels/{{ kernel_name }}{{ kernel_ext }} --spec test_kernels/{{ kernel_name }}.yaml --dsl {{ dsl }} --variant bench-xpu`
+ - **Trials t1+:** Get cached baseline via `xe-forge-skill trial baseline-us {{ kernel_name }}`, then `xe-forge-skill benchmark test_kernels/{{ kernel_name }}{{ kernel_ext }} --spec test_kernels/{{ kernel_name }}.yaml --dsl {{ dsl }} --variant bench-xpu --baseline-us `
+5. **Record** — `xe-forge-skill trial result {{ kernel_name }} --correctness --speedup --baseline-us --triton-us --tflops ` (pass the `tflops=` value from the benchmark output)
+6. **Profile** — if `vtune_enabled: true` in `config.yaml` and this is trial t1 or later, run `xe-forge-skill profile --spec --dsl {{ dsl }} --variant bench-xpu`. VTune `gpu-hotspots` reports XVE Active/Stalled/Idle, occupancy, XMX (DPAS) utilization, and L3/memory metrics, then maps them to CUTLASS knobs (see `knowledge_base/sycl/xpu/sycl_vtune.yaml`). Use those recommendations to choose the next trial's TileShape / PipelineStages / copy-atom change. If `vtune_enabled: false`, skip this step.
+7. **Decide next action** (reason about `util%` — distance from the ~160 TFLOPS bf16 peak — not just relative speedup):
+ - Speedup > 5x -> stop, finalize
+ - Speedup improved -> continue on this branch
+ - Speedup regressed -> branch back to best trial, try different strategy
+ - Correctness failed -> fix on same branch (check accumulate dtype, layouts, D2.bin f32 dump)
+ - Plateau -> try fundamentally different approach (tile shape, copy atoms, pipeline stages, stream-K)
+ - Low utilization (util well below peak) -> the kernel is leaving FLOPS on the table; prioritize tile shape / pipeline-stage / copy-atom changes that raise compute throughput
+
+### Step 4: Finalize
+```bash
+xe-forge-skill trial finalize {{ kernel_name }} output/{{ kernel_name }}_optimized.cpp
+```
+
+## CRITICAL CORRECTNESS CONSTRAINTS (SYCL / CUTLASS on Xe)
+
+- **TileShape ↔ SubgroupLayout ↔ MMA atom** must be consistent. With `XE_DPAS_TT<8, float, bfloat16_t>` (8x16x16) and a `Shape<_8,_4,_1>` subgroup layout, `TileM/8` and `TileN/4` must be integers.
+- **SLM budget** on Battlemage is 128 KB/work-group. Double-buffered A+B staging must fit: `2*(TileM*TileK + TileK*TileN)*sizeof(elem)`.
+- **Accumulate in float32** (`ElementAccumulator = float`) for bf16/f16 inputs — never accumulate in the input dtype.
+- **bf16 input layout**: keep `LayoutA = LayoutB = RowMajor` for the `A[M,K]`, `B0[K,N]` contract; the collective builder handles VNNI internally.
+- **Output dtype**: `ElementOutput` MUST be `float` so `D2.bin` is float32 (the golden reference is f32).
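
The tile constraints listed above can be sanity-checked before compiling a trial. The sketch below is a minimal Python check, assuming the 128 KB budget and the `2*(TileM*TileK + TileK*TileN)*sizeof(elem)` formula exactly as stated; the helper names `slm_staging_bytes` and `check_tile` are hypothetical and not part of `xe-forge-skill validate`.

```python
SLM_BUDGET_BYTES = 128 * 1024  # Battlemage per-work-group budget quoted above


def slm_staging_bytes(tile_m: int, tile_n: int, tile_k: int, elem_bytes: int) -> int:
    """Double-buffered A+B staging footprint, per the formula in the constraints."""
    return 2 * (tile_m * tile_k + tile_k * tile_n) * elem_bytes


def check_tile(tile_m, tile_n, tile_k, sg_rows=8, sg_cols=4, elem_bytes=2):
    """Return a list of violated constraints (empty when the tile passes)."""
    problems = []
    if tile_m % sg_rows:
        problems.append("TileM is not divisible by the subgroup rows")
    if tile_n % sg_cols:
        problems.append("TileN is not divisible by the subgroup columns")
    if slm_staging_bytes(tile_m, tile_n, tile_k, elem_bytes) > SLM_BUDGET_BYTES:
        problems.append("double-buffered staging exceeds the SLM budget")
    return problems


# Starter configuration: Shape<_256,_256,_32> with bf16 (2-byte) inputs.
print(slm_staging_bytes(256, 256, 32, 2), check_tile(256, 256, 32))
```

The starter tile uses 65536 bytes of staging, half the stated budget, so it passes; a 512x512x64 bf16 tile would need 262144 bytes and be rejected by the same check.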
diff --git a/src/xe_forge/claude/templates/optimize-kernel.sycl.md.j2 b/src/xe_forge/claude/templates/optimize-kernel.sycl.md.j2
new file mode 100644
index 0000000..a3aa7ba
--- /dev/null
+++ b/src/xe_forge/claude/templates/optimize-kernel.sycl.md.j2
@@ -0,0 +1,19 @@
+Optimize `$ARGUMENTS` into a high-performance {{ dsl | upper }} (CUTLASS SYCL) kernel.
+
+Resolve the argument to a file in `test_kernels/` — it could be a full filename, partial name, or just a number. Glob to find the match; if ambiguous, ask.
+
+First read `config.yaml` for session settings (max_trials, vtune_enabled).
+Follow the CLAUDE.md workflow exactly — every step, in order:
+
+1. Read the baseline kernel `test_kernels/{{ kernel_ext }}` — note its TileShape, SubgroupLayout, dispatch policy, and the file-IO scaffolding (which you must keep intact).
+2. Read the PyTorch golden reference `test_kernels/_pytorch.py` (if present) to understand the exact math.
+3. Read relevant `knowledge_base/sycl/xpu/` files — especially `cutlass_sycl_framework.yaml` (tile configs) and `sycl_io_contract.yaml` (the harness contract).
+4. Initialize: `xe-forge-skill trial init test_kernels/{{ kernel_ext }}`
+5. Run ALL max_trials trials (from config.yaml). For EACH trial you MUST:
+ - Validate with `xe-forge-skill validate --dsl {{ dsl }}`
+ - Save with `xe-forge-skill trial save`
+ - Benchmark with `xe-forge-skill benchmark --spec --dsl {{ dsl }} --variant bench-xpu` — NEVER create custom test scripts
+ - Record with `xe-forge-skill trial result`
+6. Finalize the best correct trial
+
+CRITICAL: ONLY create kernel `.cpp` files — NO benchmark scripts, NO test scripts, NO helpers. Every kernel must remain a standalone executable honouring the file-IO contract (read A.bin/B0.bin from --input_dir, write D2.bin as f32 to --output_dir, print the TFlop/s line).
diff --git a/src/xe_forge/claude/templates/starter_kernel.sycl.cpp.j2 b/src/xe_forge/claude/templates/starter_kernel.sycl.cpp.j2
new file mode 100644
index 0000000..f6db872
--- /dev/null
+++ b/src/xe_forge/claude/templates/starter_kernel.sycl.cpp.j2
@@ -0,0 +1,287 @@
+/*
+ Xe-Forge SYCL starter kernel — minimal compilable CUTLASS BMG GEMM (D = A * B0)
+ that honours the Xe-Forge file-IO contract. Generated as the t0 baseline when
+ the optimization input is PyTorch-only or spec-only (no C++ source supplied).
+
+ File-IO contract (see knowledge_base/sycl/xpu/sycl_io_contract.yaml):
+ --m/--n/--k problem dims
+ --input_dir= read A.bin [M,K], B0.bin [K,N] (bf16 as raw int16 bits)
+ --output_dir= write D2.bin [M,N] as float32, row-major
+ --iterations= timed iterations
+ --verify= ignored (correctness checked in Python vs golden ref)
+
+ Prints a "[]TFlop/s ()ms" line that SyclExecutor parses.
+
+ Derived from sycl-tla/examples/00_bmg_gemm/00_bmg_gemm.cpp. This is a CORRECT
+ but unoptimized starting point — the optimizer should iterate on TileShape,
+ SubgroupLayout, PipelineStages, copy atoms, and the epilogue.
+*/
+
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/collective/xe_epilogue.hpp"
+#include "cutlass/epilogue/fusion/xe_callbacks.hpp"
+#include "cutlass/gemm/device/gemm_universal.h"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/collective/collective_mma.hpp"
+#include "cutlass/util/GPU_Clock.hpp"
+
+#include
+#include
+#include
+#include
+
+#include "cutlass/util/command_line.h"
+#include "cutlass/util/device_memory.h"
+#include "cutlass/util/packed_stride.hpp"
+#include "cutlass/util/reference/device/gemm_complex.h"
+#include "cutlass/util/reference/device/tensor_compare.h"
+#include "sycl_common.hpp"
+#include "helper.h"
+
+using namespace cute;
+
+struct Options {
+ bool help = false;
+ bool error = false;
+ int m = 5120, n = 4096, k = 4096, l = 1, iterations = 20, verify = 1;
+ float alpha = 1.f, beta = 0.f;
+ std::string input_dir;
+ std::string output_dir;
+
+ void parse(int argc, char const **args) {
+ cutlass::CommandLine cmd(argc, args);
+ if (cmd.check_cmd_line_flag("help")) { help = true; return; }
+ cmd.get_cmd_line_argument("m", m, 5120);
+ cmd.get_cmd_line_argument("n", n, 4096);
+ cmd.get_cmd_line_argument("k", k, 4096);
+ cmd.get_cmd_line_argument("l", l, 1);
+ cmd.get_cmd_line_argument("alpha", alpha, 1.f);
+ cmd.get_cmd_line_argument("beta", beta, 0.f);
+ cmd.get_cmd_line_argument("iterations", iterations, 20);
+ cmd.get_cmd_line_argument("verify", verify, 1);
+ cmd.get_cmd_line_argument("input_dir", input_dir, std::string(""));
+ cmd.get_cmd_line_argument("output_dir", output_dir, std::string(""));
+ }
+};
+
+template <typename T>
+static std::vector<T> read_bin(const std::string& dir, const std::string& name, size_t count) {
+ std::string path = dir + "/" + name;
+ std::ifstream f(path, std::ios::binary);
+ if (!f) { std::cerr << "Could not open " << path << std::endl; std::exit(2); }
+ std::vector<T> buf(count);
+ f.read(reinterpret_cast<char*>(buf.data()), count * sizeof(T));
+ if (static_cast<size_t>(f.gcount()) != count * sizeof(T)) {
+ std::cerr << "Short read on " << path << ": got " << f.gcount()
+ << " want " << count * sizeof(T) << std::endl;
+ std::exit(2);
+ }
+ return buf;
+}
+
+template <typename T>
+static void write_bin(const std::string& dir, const std::string& name, const std::vector<T>& buf) {
+ std::string path = dir + "/" + name;
+ std::ofstream f(path, std::ios::binary);
+ if (!f) { std::cerr << "Could not open for write " << path << std::endl; std::exit(2); }
+ f.write(reinterpret_cast<const char*>(buf.data()), buf.size() * sizeof(T));
+}
+
+template <class Gemm>
+struct ExampleRunner {
+ using StrideA = typename Gemm::GemmKernel::StrideA;
+ using StrideB = typename Gemm::GemmKernel::StrideB;
+ using StrideC = typename Gemm::GemmKernel::StrideC;
+ using StrideD = typename Gemm::GemmKernel::StrideD;
+
+ using LayoutA = typename Gemm::LayoutA;
+ using LayoutB = typename Gemm::LayoutB;
+ using LayoutC = typename Gemm::LayoutC;
+ using LayoutD = typename Gemm::LayoutD;
+
+ using ElementA = typename Gemm::ElementA;
+ using ElementB = typename Gemm::ElementB;
+ using ElementAccumulator = typename Gemm::ElementAccumulator;
+
+ using CollectiveEpilogue = typename Gemm::CollectiveEpilogue;
+ using ElementC = typename Gemm::ElementC;
+ using ElementOutput = typename CollectiveEpilogue::ElementOutput;
+ using ElementCompute = typename CollectiveEpilogue::ElementCompute;
+
+ using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
+
+ StrideA stride_A;
+ StrideB stride_B;
+ StrideC stride_C;
+ StrideD stride_D;
+ uint64_t seed = 0;
+
+ cutlass::DeviceAllocation<ElementA> block_A;
+ cutlass::DeviceAllocation<ElementB> block_B;
+ cutlass::DeviceAllocation<ElementC> block_C;
+ cutlass::DeviceAllocation<ElementOutput> block_D;
+
+ void initialize(const ProblemShapeType& problem_size, const Options& options) {
+ auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
+ auto [M, N, K, L] = problem_shape_MNKL;
+
+ stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
+ stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
+ stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
+ stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
+
+ block_A.reset(static_cast<size_t>(M) * K * L);
+ block_B.reset(static_cast<size_t>(K) * N * L);
+ block_C.reset(static_cast<size_t>(M) * N * L);
+ block_D.reset(static_cast<size_t>(M) * N * L);
+
+ // A [M,K] and B0 [K,N] are raw bf16 bits (bit-identical to torch bf16).
+ auto host_A = read_bin<ElementA>(options.input_dir, "A.bin",
+ static_cast<size_t>(M) * K * L);
+ auto host_B = read_bin<ElementB>(options.input_dir, "B0.bin",
+ static_cast<size_t>(K) * N * L);
+ block_A.copy_from_host(host_A.data());
+ block_B.copy_from_host(host_B.data());
+
+ // C unused (beta = 0); zero-fill so the epilogue reads valid memory.
+ std::vector<ElementC> host_C(static_cast<size_t>(M) * N * L, ElementC(0));
+ block_C.copy_from_host(host_C.data());
+ }
+
+ void dump_output(const ProblemShapeType& problem_size, const Options& options) {
+ auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
+ auto [M, N, K, L] = problem_shape_MNKL;
+ std::vector<ElementOutput> host_D(static_cast<size_t>(M) * N * L);
+ block_D.copy_to_host(host_D.data());
+ write_bin(options.output_dir, "D2.bin", host_D); // ElementOutput = float32
+ }
+
+ cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
+ ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};
+
+ initialize(problem_size, options);
+
+ typename Gemm::GemmKernel::Arguments arguments{
+ cutlass::gemm::GemmUniversalMode::kGemm,
+ problem_size,
+ {block_A.get(), stride_A, block_B.get(), stride_B},
+ {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D},
+ hw_info
+ };
+
+ Gemm gemm_op;
+
+ size_t workspace_size = Gemm::get_workspace_size(arguments);
+ cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+
+ if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess) {
+ std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x'
+ << options.k << 'x' << options.l << std::endl;
+ std::exit(1);
+ }
+
+ CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get()));
+ CUTLASS_CHECK(gemm_op.run());
+ compat::wait();
+
+ if (!options.output_dir.empty()) {
+ dump_output(problem_size, options);
+ }
+ std::cout << "Disposition: Passed" << std::endl;
+
+ if (options.iterations > 0) {
+ GPU_Clock timer;
+ timer.start();
+ for (int i = 0; i < options.iterations; ++i) {
+ gemm_op.run();
+ }
+ compat::wait();
+
+ float cute_time = timer.seconds() / options.iterations;
+ double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12;
+ std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x'
+ << options.k << 'x' << options.l << std::endl;
+ printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n",
+ tflops / cute_time, cute_time * 1000);
+ }
+
+ return cutlass::Status::kSuccess;
+ }
+};
+
+int main(int argc, const char** argv) {
+ Options options;
+ options.parse(argc, argv);
+ if (options.help) { std::cout << "Xe-Forge SYCL starter GEMM\n"; return 0; }
+ if (options.error) { std::cerr << "Aborting execution." << std::endl; return -1; }
+
+ cutlass::KernelHardwareInfo hw_info;
+ hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+
+ using ElementAccumulator = float;
+ using ElementComputeEpilogue = float;
+ using ElementInputA = bfloat16_t;
+ using ElementInputB = bfloat16_t;
+ using ElementOutput = float;
+
+ using LayoutA = cutlass::layout::RowMajor;
+ using LayoutB = cutlass::layout::RowMajor;
+ using LayoutC = cutlass::layout::RowMajor;
+ using LayoutD = cutlass::layout::RowMajor;
+
+ using GmemTiledCopyA = void;
+ using GmemTiledCopyB = void;
+
+ // Workgroup tile — the primary thing to tune.
+ using TileShape = Shape<_256, _256, _32>;
+
+ using TiledMma = typename TiledMMAHelper<MMA_Atom<XE_DPAS_TT<8, float, bfloat16_t>>,
+ Layout<TileShape>, Layout<Shape<_8, _4, _1>, Stride<_4, _1, _0>>>::TiledMMA;
+
+ constexpr int PipelineStages = 2;
+ using GEMMDispatchPolicy = cutlass::gemm::MainloopXeL1Staged<PipelineStages>;
+ using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeGeneric;
+
+ using EpilogueOp = cutlass::epilogue::fusion::LinearCombination;
+
+ using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks;
+
+ using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
+ EpilogueDispatchPolicy,
+ TileShape,
+ void,
+ ElementAccumulator,
+ cutlass::gemm::TagToStrideC_t<LayoutC>,
+ ElementOutput,
+ cutlass::gemm::TagToStrideC_t<LayoutD>,
+ FusionCallbacks,
+ void,
+ void>;
+
+ using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
+ GEMMDispatchPolicy,
+ TileShape,
+ ElementInputA,
+ cutlass::gemm::TagToStrideA_t<LayoutA>,
+ ElementInputB,
+ cutlass::gemm::TagToStrideB_t<LayoutB>,
+ TiledMma,
+ GmemTiledCopyA, void, void, cute::identity,
+ GmemTiledCopyB, void, void, cute::identity
+ >;
+
+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+ Shape,
+ CollectiveMainloop,
+ CollectiveEpilogue
+ >;
+
+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+ ExampleRunner<Gemm> runner;
+ CUTLASS_CHECK(runner.run(options, hw_info));
+
+ return 0;
+}
diff --git a/src/xe_forge/claude/templates/tool-runner.md.j2 b/src/xe_forge/claude/templates/tool-runner.md.j2
index 400aa43..b1448f1 100644
--- a/src/xe_forge/claude/templates/tool-runner.md.j2
+++ b/src/xe_forge/claude/templates/tool-runner.md.j2
@@ -31,7 +31,7 @@ Safe to parallelize: analyze, validate, trial (CPU-only).
## Output Rules
For benchmark:
-- Extract ONLY: Correctness (PASSED/FAILED), Performance (baseline_us, triton_us, speedup), Errors.
+- Extract ONLY: Correctness (PASSED/FAILED), Performance (baseline_us, triton_us, speedup, and tflops/util when present), Errors.
- Do NOT include configuration header or decorative separators.
For profile:
diff --git a/src/xe_forge/cli.py b/src/xe_forge/cli.py
index 8f0ad00..8e4787a 100644
--- a/src/xe_forge/cli.py
+++ b/src/xe_forge/cli.py
@@ -516,16 +516,29 @@ def _run_optimize(parser, args, config: Config) -> int:
with open(args.input) as f:
kernel_code = f.read()
- # Read reference implementation (Python DSLs only)
+ # Read reference implementation.
+ # * Python DSLs: {input_stem}_pytorch.py is the optimization reference.
+ # * SYCL/CUDA: the golden correctness reference is also PyTorch, located by
+ # the same {input_stem}_pytorch.py convention. If absent and the input is
+ # itself a .py, the input *is* the reference (spec/PyTorch-only start).
reference_code = ""
- if dsl not in ("sycl", "cuda"):
- try:
- with open(f"{os.path.splitext(args.input)[0]}_pytorch.py") as f:
- reference_code = f.read()
- except FileNotFoundError:
- print(
- f"No PyTorch reference file found at {os.path.splitext(args.input)[0]}_pytorch.py"
- )
+ input_stem = os.path.splitext(args.input)[0]
+ ref_path = f"{input_stem}_pytorch.py"
+ try:
+ with open(ref_path) as f:
+ reference_code = f.read()
+ except FileNotFoundError:
+ if dsl in ("sycl", "cuda"):
+ if args.input.endswith(".py"):
+ reference_code = kernel_code
+ print(f"Using {args.input} as the PyTorch golden reference")
+ else:
+ print(
+ f"No PyTorch golden reference at {ref_path}; "
+ "SYCL correctness will be skipped during benchmarking"
+ )
+ else:
+ print(f"No PyTorch reference file found at {ref_path}")
# Create engine and optimize
from xe_forge.engines import create_engine
diff --git a/src/xe_forge/config.py b/src/xe_forge/config.py
index 9301b22..944dd42 100644
--- a/src/xe_forge/config.py
+++ b/src/xe_forge/config.py
@@ -73,6 +73,9 @@ class DeviceConfig:
preferred_tile_m: int = 256
preferred_tile_n: int = 256
preferred_tile_k: int = 32
+ # Theoretical peak throughput for utilization reporting (TFLOPS). Default is
+ # the Intel Arc Pro B70 fp16/bf16 DPAS peak; override per device/dtype.
+ peak_tflops: float = 160.0
@dataclass
@@ -87,6 +90,7 @@ class XPUConfig(DeviceConfig):
preferred_tile_n: int = 256
preferred_tile_k: int = 32
group_size_m: int = 4
+ peak_tflops: float = 160.0
@dataclass
@@ -307,6 +311,7 @@ def _build_device_config(self, device_type: str, dsl: str) -> DeviceConfig:
preferred_tile_n=self._get_env("PREFERRED_TILE_N", 256, int),
preferred_tile_k=self._get_env("PREFERRED_TILE_K", 32, int),
group_size_m=self._get_env("GROUP_SIZE_M", 4, int),
+ peak_tflops=self._get_env("PEAK_TFLOPS", 160.0, float),
)
def override(self, **kwargs) -> "ConfigManager":
diff --git a/src/xe_forge/core/sycl_executor.py b/src/xe_forge/core/sycl_executor.py
index 9c11dad..cc1dcf2 100644
--- a/src/xe_forge/core/sycl_executor.py
+++ b/src/xe_forge/core/sycl_executor.py
@@ -35,6 +35,8 @@
_DEVICE_NAME_TO_TARGET: dict[str, str] = {
"b580": "bmg-g31",
"b570": "bmg-g31",
+ "b70": "bmg-g31",
+ "arc pro b": "bmg-g31",
"battlemage": "bmg-g31",
"bmg": "bmg-g31",
"a770": "acm-g10",
@@ -658,6 +660,138 @@ def compare_kernels(
feedback_message=msg,
)
+ def compare_with_reference(
+ self,
+ golden_output: np.ndarray,
+ optimized_code: str | None = None,
+ optimized_path: str | None = None,
+ m: int = 1024,
+ n: int = 1024,
+ k: int = 1024,
+ dims: dict[str, int | float] | None = None,
+ rtol: float = 1e-2,
+ atol: float = 1e-2,
+ input_dir: str | None = None,
+ seed: int = 42,
+ ) -> SyclComparisonResult:
+ """Compile + run a SYCL kernel and check it against a golden numpy array.
+
+ Mirrors :meth:`compare_kernels`, but instead of comparing the optimized
+ kernel's output to an *original* kernel's ``D2.bin``, it compares against
+ a precomputed **PyTorch/numpy golden reference** (the Claude-engine path).
+ Only the optimized kernel is run; there is no baseline rerun here.
+
+ Data layout (kept identical to ``get_or_create_inputs`` / ``_save_tensor``):
+ * ``A.bin`` : ``[M, K]`` row-major, bf16 stored as raw int16 bits.
+ * ``B0.bin`` : ``[K, N]`` row-major, bf16 stored as raw int16 bits.
+ * ``B1.bin`` : ``[K, N]`` (dual-GEMM second operand; unused by plain GEMM).
+ * ``D2.bin`` : ``[M, N]`` row-major float32 (read back flat, reshaped to
+ ``golden_output.shape``).
+
+ Args:
+ golden_output: Reference output as a numpy array (any shape; the
+ kernel's flat ``D2.bin`` is reshaped to match it).
+ optimized_code/path: Kernel source string or ``.cpp`` path.
+ m, n, k: GEMM dims (backward compat; ``dims`` takes precedence).
+ dims: Generic dimension dict from the spec.
+ rtol, atol: Tolerances for ``numpy.allclose`` (spec-driven; do not
+ rely on ``compare_outputs`` defaults).
+ input_dir: Directory holding the shared input ``.bin`` files. When
+ ``None``, inputs are generated/cached via ``get_or_create_inputs``.
+ seed: Seed used when generating inputs (only when ``input_dir`` is None).
+
+ Returns:
+ SyclComparisonResult carrying ``optimized_time_ms``,
+ ``optimized_tflops``, ``optimized_correct``, and a feedback message.
+ ``original_*`` fields are left at their defaults (no baseline here).
+ """
+ if input_dir is None:
+ effective_dims = dims or {"M": m, "N": n, "K": k}
+ input_dir = self.get_or_create_inputs(effective_dims, seed=seed)
+
+ out_dir = tempfile.mkdtemp(prefix="sycl_golden_")
+ opt_result = self.execute(
+ kernel_code=optimized_code,
+ kernel_path=optimized_path,
+ m=m,
+ n=n,
+ k=k,
+ dims=dims,
+ output_name="optimized_sycl",
+ input_dir=input_dir,
+ output_dir=out_dir,
+ )
+
+ if not opt_result.success:
+ try:
+ shutil.rmtree(out_dir)
+ except Exception:
+ pass
+ return SyclComparisonResult(
+ original_time_ms=float("inf"),
+ optimized_time_ms=float("inf"),
+ speedup=0.0,
+ optimized_correct=False,
+ feedback_message=(
+ f"FAILURE: Kernel failed: {opt_result.error_message}. "
+ "Fix compilation or runtime errors."
+ ),
+ )
+
+ opt_ms = opt_result.execution_time_ms or float("inf")
+ opt_tflops = opt_result.tflops
+
+ opt_correct = True
+ correctness_msg = ""
+ d2_path = os.path.join(out_dir, "D2.bin")
+ if os.path.exists(d2_path):
+ opt_flat = self.load_output(d2_path, np.float32)
+ golden = np.asarray(golden_output, dtype=np.float32)
+ if opt_flat.size == golden.size:
+ opt_out = opt_flat.reshape(golden.shape)
+ passed, detail = self.compare_outputs(golden, opt_out, rtol=rtol, atol=atol)
+ else:
+ passed = False
+ detail = (
+ f"Size mismatch: D2.bin has {opt_flat.size} elems, golden has {golden.size}"
+ )
+ opt_correct = passed
+ correctness_msg = (
+ " Correctness: PASSED." if passed else f" CORRECTNESS FAILED: {detail}."
+ )
+ logger.info(f"Golden comparison (rtol={rtol}, atol={atol}): {detail}")
+ else:
+ opt_correct = False
+ correctness_msg = " (no D2.bin produced — kernel did not honour the IO contract)"
+ logger.warning("D2.bin not found — cannot verify against golden reference")
+
+ try:
+ shutil.rmtree(out_dir)
+ except Exception:
+ pass
+
+ if not opt_correct:
+ msg = (
+ f"CORRECTNESS FAILURE: Kernel produces wrong results vs the golden "
+ f"reference.{correctness_msg} Optimized: {opt_ms:.4f}ms. "
+ "Fix numerical correctness before optimizing for speed."
+ )
+ else:
+ tflops_str = f" ({opt_tflops:.3f} TFlop/s)" if opt_tflops else ""
+ msg = (
+ f"SUCCESS: Correct vs golden reference. "
+ f"Optimized: {opt_ms:.4f}ms{tflops_str}.{correctness_msg}"
+ )
+
+ return SyclComparisonResult(
+ original_time_ms=float("inf"),
+ optimized_time_ms=opt_ms,
+ speedup=0.0,
+ optimized_tflops=opt_tflops,
+ optimized_correct=opt_correct,
+ feedback_message=msg,
+ )
+
def __del__(self):
if self._build_dir is not None:
try:
diff --git a/src/xe_forge/core/sycl_profiler.py b/src/xe_forge/core/sycl_profiler.py
new file mode 100644
index 0000000..feb55a7
--- /dev/null
+++ b/src/xe_forge/core/sycl_profiler.py
@@ -0,0 +1,425 @@
+"""VTune GPU profiler for SYCL/CUTLASS kernels on Intel Xe.
+
+Independent of the Triton profiler (``core/profiler.py``): that one generates a
+Python runner that imports a PyTorch ``Model`` and collects ``gpu-offload``.
+A SYCL kernel is a compiled ``.cpp`` binary, so this module instead:
+
+ 1. compiles the kernel via :class:`SyclExecutor` (reusing the benchmark build),
+ 2. generates the same deterministic file-IO inputs the benchmark uses,
+ 3. runs the binary directly under ``vtune -collect gpu-hotspots`` in
+ characterization mode (richer Xe hardware metrics than ``gpu-offload``),
+ 4. parses the per-computing-task hotspots report and maps the metrics to
+ CUTLASS tuning knobs (TileShape, SubgroupLayout, PipelineStages, copy
+ atoms) via :data:`SYCL_RECOMMENDATION_RULES`.
+
+Verbs, knobs, and column names were confirmed on VTune 2026.0 + Arc Pro B70
+(see knowledge_base/sycl/xpu/sycl_vtune.yaml). Gracefully degrades to a result
+with ``error`` set when VTune is absent or collection fails, so the trial loop
+continues without profiling.
+"""
+
+from __future__ import annotations
+
+import csv
+import io
+import logging
+import os
+import re
+import shutil
+import subprocess
+import tempfile
+from dataclasses import dataclass, field
+from pathlib import Path
+
+import torch
+
+from xe_forge.core.sycl_executor import KernelType, SyclExecutor
+
+logger = logging.getLogger(__name__)
+
+# VTune gpu-hotspots collection knobs (confirmed on VTune 2026.0 / B70).
+_COLLECT_KNOBS = [
+ "gpu-profiling-mode=characterization",
+ "characterization-mode=overview",
+]
+
+# Overhead / non-kernel computing tasks to skip when picking the primary kernel.
+_OVERHEAD_PATTERNS = [
+ re.compile(r"zeCommandListAppendMemoryCopy"),
+ re.compile(r"\[Outside any task\]"),
+ re.compile(r"clEnqueue"),
+]
+
+# Report columns to request. The memory-bandwidth columns are fetched in a
+# SEPARATE pass via a comma-free SUBSTRING filter ("GB/sec"): the full names
+# ("GPU Memory Bandwidth, GB/sec:Read") contain a literal comma that VTune's
+# -column parser mis-splits, but -column does substring matching, so "GB/sec"
+# selects both Read/Write columns cleanly. _build_metrics reads the full
+# header names that come back.
+_CORE_COLUMNS = [
+ "Computing Task:Total Time",
+ "XVE Array:Active",
+ "XVE Array:Stalled",
+ "XVE Array:Idle",
+ "Peak XVE Threads Occupancy",
+ "XVE Pipelines:XMX active",
+ "GPU L3:Miss Ratio",
+]
+_BW_COLUMNS = ["GB/sec"]
+
+
+def _is_overhead(name: str) -> bool:
+ return any(p.search(name) for p in _OVERHEAD_PATTERNS)
+
+
+@dataclass
+class SyclRecommendation:
+ category: str
+ message: str
+ kb_reference: str = ""
+
+
+@dataclass
+class SyclProfileMetrics:
+ xve_active_pct: float | None = None
+ xve_stalled_pct: float | None = None
+ xve_idle_pct: float | None = None
+ peak_occupancy_pct: float | None = None
+ xmx_active_pct: float | None = None
+ l3_miss_pct: float | None = None
+ gpu_mem_bw_read_gbps: float | None = None
+ gpu_mem_bw_write_gbps: float | None = None
+
+
+@dataclass
+class SyclProfileResult:
+ primary_kernel: str = ""
+ metrics: SyclProfileMetrics = field(default_factory=SyclProfileMetrics)
+ recommendations: list[SyclRecommendation] = field(default_factory=list)
+ raw_counters: dict = field(default_factory=dict)
+ error: str | None = None
+
+ def format_for_llm(self) -> str:
+ """Structured digest for the agent / tool-runner (mirrors XPUProfiler)."""
+ if self.error:
+ return f"Profiling error: {self.error}"
+ if not self.primary_kernel:
+ return "No profiling data available."
+
+ # The CUTLASS kernel name is a giant templated type; show a short tag.
+ short = self.primary_kernel.split("<", 1)[0]
+ parts = [f"== VTune GPU Profile: {short} ==", "", "Metrics:"]
+ m = self.metrics
+ rows = [
+ ("XVE Active", m.xve_active_pct, "%"),
+ ("XVE Stalled", m.xve_stalled_pct, "%"),
+ ("XVE Idle", m.xve_idle_pct, "%"),
+ ("Peak Occupancy", m.peak_occupancy_pct, "%"),
+ ("XMX (DPAS) Active", m.xmx_active_pct, "%"),
+ ("L3 Miss Ratio", m.l3_miss_pct, "%"),
+ ("GPU Mem BW Read", m.gpu_mem_bw_read_gbps, " GB/s"),
+ ("GPU Mem BW Write", m.gpu_mem_bw_write_gbps, " GB/s"),
+ ]
+ for label, val, unit in rows:
+ if val is not None:
+ parts.append(f" {label}: {val:.1f}{unit}")
+
+ if self.recommendations:
+ parts.append("")
+ parts.append("Recommendations:")
+ for rec in self.recommendations:
+ parts.append(f" [{rec.category}] {rec.message}")
+ if rec.kb_reference:
+ parts.append(f" -> {rec.kb_reference}")
+ return "\n".join(parts)
+
+
+class SyclProfiler:
+ """VTune ``gpu-hotspots`` profiler for compiled SYCL GEMM kernels."""
+
+ def __init__(
+ self,
+ vtune_bin: str = "vtune",
+ sycl_tla_dir: str | None = None,
+ device_target: str | None = None,
+ kernel_type: KernelType | str = KernelType.GEMM,
+ iterations: int = 200,
+ collect_timeout: int = 300,
+ ):
+ self.vtune_bin = vtune_bin
+ self.iterations = iterations
+ self.collect_timeout = collect_timeout
+ # Reuse the benchmark executor for compile + deterministic inputs.
+ executor_kwargs: dict = {"kernel_type": kernel_type, "verify": False}
+ if sycl_tla_dir is not None:
+ executor_kwargs["sycl_tla_dir"] = sycl_tla_dir
+ if device_target is not None:
+ executor_kwargs["device_target"] = device_target
+ self._executor = SyclExecutor(**executor_kwargs)
+
+ def available(self) -> bool:
+ """Whether the VTune binary is on PATH (or an absolute path that exists)."""
+ return shutil.which(self.vtune_bin) is not None or os.path.exists(self.vtune_bin)
+
+ def profile(
+ self,
+ kernel_path: str | Path,
+ dims: dict[str, int | float],
+ dtype: torch.dtype = torch.bfloat16,
+ input_dir: str | None = None,
+ ) -> SyclProfileResult:
+ """Compile, run under VTune, and return parsed metrics + recommendations.
+
+ Returns a result with ``error`` set (never raises) when VTune is missing
+ or any stage fails — the optimization loop treats profiling as advisory.
+ """
+ if not self.available():
+ return SyclProfileResult(
+ error=f"VTune not found ({self.vtune_bin}). Set vtune_bin or VTUNE_BIN."
+ )
+ kernel_path = Path(kernel_path)
+ if not kernel_path.exists():
+ return SyclProfileResult(error=f"Kernel file not found: {kernel_path}")
+
+ try:
+ return self._profile(kernel_path, dims, dtype, input_dir)
+ except Exception as e: # pragma: no cover - defensive
+ logger.exception("SYCL profiling failed")
+ return SyclProfileResult(error=str(e))
+
+ def _profile(
+ self,
+ kernel_path: Path,
+ dims: dict[str, int | float],
+ dtype: torch.dtype,
+ input_dir: str | None,
+ ) -> SyclProfileResult:
+ ok, binary, err = self._executor.compile(
+ source_path=str(kernel_path), output_name="kernel_vtune"
+ )
+ if not ok:
+ return SyclProfileResult(error=f"Compilation failed:\n{err[-1500:]}")
+
+ if input_dir is None:
+ input_dir = self._executor.get_or_create_inputs(dims, seed=42, dtype=dtype)
+
+ m, n, k = SyclExecutor._dims_to_mnk(dims)
+ result_dir = tempfile.mkdtemp(prefix="sycl_vtune_")
+ out_dir = tempfile.mkdtemp(prefix="sycl_vtune_out_")
+ try:
+ collect_err = self._collect(binary, m, n, k, input_dir, out_dir, result_dir)
+ if collect_err:
+ return SyclProfileResult(error=collect_err)
+
+ counters = self._extract_counters(result_dir)
+ if not counters:
+ return SyclProfileResult(error="No GPU kernel data in VTune report.")
+
+ primary = self._primary_kernel(counters)
+ if primary is None:
+ return SyclProfileResult(error="Could not identify the primary GPU kernel.")
+
+ metrics = self._build_metrics(counters[primary])
+ recs = self._recommendations(metrics)
+ return SyclProfileResult(
+ primary_kernel=primary,
+ metrics=metrics,
+ recommendations=recs,
+ raw_counters=counters[primary],
+ )
+ finally:
+ for d in (result_dir, out_dir):
+ shutil.rmtree(d, ignore_errors=True)
+
+ def _collect(
+ self,
+ binary: str,
+ m: int,
+ n: int,
+ k: int,
+ input_dir: str,
+ out_dir: str,
+ result_dir: str,
+ ) -> str | None:
+ """Run the gpu-hotspots collection. Returns an error string or None."""
+ # VTune refuses to write into an existing result-dir.
+ shutil.rmtree(result_dir, ignore_errors=True)
+ cmd = [self.vtune_bin, "-collect", "gpu-hotspots"]
+ for knob in _COLLECT_KNOBS:
+ cmd += ["-knob", knob]
+ cmd += [
+ "-result-dir",
+ result_dir,
+ "--",
+ binary,
+ f"--m={m}",
+ f"--n={n}",
+ f"--k={k}",
+ f"--input_dir={input_dir}",
+ f"--output_dir={out_dir}",
+ f"--iterations={self.iterations}",
+ "--verify=0",
+ ]
+ logger.info("VTune collect: %s", " ".join(cmd))
+ try:
+ proc = subprocess.run(cmd, capture_output=True, text=True, timeout=self.collect_timeout)
+ except subprocess.TimeoutExpired:
+ return f"VTune collection timed out after {self.collect_timeout}s"
+ except Exception as e:
+ return str(e)
+ if proc.returncode != 0:
+ return f"VTune collection failed (exit {proc.returncode}):\n{proc.stderr[-1500:]}"
+ return None
+
+ def _report_csv(self, result_dir: str, columns: list[str]) -> list[dict]:
+ """Run one hotspots report pass and return its rows as dicts."""
+ cmd = [
+ self.vtune_bin,
+ "-report",
+ "hotspots",
+ "-result-dir",
+ result_dir,
+ "-group-by",
+ "computing-task",
+ "-column",
+ ",".join(columns),
+ "-format",
+ "csv",
+ "-csv-delimiter",
+ "tab",
+ ]
+ try:
+ proc = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
+ except Exception as e: # pragma: no cover - defensive
+ logger.warning("VTune report failed: %s", e)
+ return []
+ if proc.returncode != 0:
+ logger.warning("VTune report rc=%s: %s", proc.returncode, proc.stderr[:300])
+ return []
+ # Skip any leading warning lines before the "Computing Task" header.
+ lines = proc.stdout.splitlines()
+ header_idx = next((i for i, ln in enumerate(lines) if ln.startswith("Computing Task")), 0)
+ payload = "\n".join(lines[header_idx:])
+ return list(csv.DictReader(io.StringIO(payload), delimiter="\t"))
+
+ def _extract_counters(self, result_dir: str) -> dict[str, dict]:
+ """Two report passes (core metrics + bandwidth) merged by task name.
+
+ The bandwidth pass is separate because its column names contain literal
+ commas that the -column list parser would mis-split.
+ """
+ counters: dict[str, dict] = {}
+ for columns in (_CORE_COLUMNS, _BW_COLUMNS):
+ for row in self._report_csv(result_dir, columns):
+ name = (row.get("Computing Task") or "").strip()
+ if not name:
+ continue
+ entry = counters.setdefault(name, {})
+ for key, val in row.items():
+ if val and val.strip():
+ entry.setdefault(key, val.strip())
+ return counters
+
+ def _primary_kernel(self, counters: dict[str, dict]) -> str | None:
+ """Highest-total-time non-overhead computing task."""
+
+ def total_time(cols: dict) -> float:
+ try:
+ return float(str(cols.get("Computing Task:Total Time", 0)).replace(",", ""))
+ except (ValueError, TypeError):
+ return 0.0
+
+ user = [(total_time(c), n) for n, c in counters.items() if not _is_overhead(n)]
+ if user:
+ return max(user)[1]
+ # Fall back to the hottest task overall (with a warning) rather than fail.
+ allt = [(total_time(c), n) for n, c in counters.items()]
+ if allt:
+ logger.warning("Only overhead tasks captured; using hottest task")
+ return max(allt)[1]
+ return None
+
+ @staticmethod
+ def _build_metrics(cols: dict) -> SyclProfileMetrics:
+ def num(*keys: str) -> float | None:
+ # VTune appends "(%)" to percentage headers; try both spellings.
+ for key in keys:
+ for variant in (key, f"{key}(%)"):
+ if variant in cols:
+ try:
+ return float(str(cols[variant]).rstrip("%").replace(",", "").strip())
+ except (ValueError, TypeError):
+ continue
+ return None
+
+ return SyclProfileMetrics(
+ xve_active_pct=num("XVE Array:Active"),
+ xve_stalled_pct=num("XVE Array:Stalled"),
+ xve_idle_pct=num("XVE Array:Idle"),
+ peak_occupancy_pct=num("Peak XVE Threads Occupancy"),
+ xmx_active_pct=num("XVE Pipelines:XMX active"),
+ l3_miss_pct=num("GPU L3:Miss Ratio"),
+ gpu_mem_bw_read_gbps=num("GPU Memory Bandwidth, GB/sec:Read"),
+ gpu_mem_bw_write_gbps=num("GPU Memory Bandwidth, GB/sec:Write"),
+ )
+
+ @staticmethod
+ def _recommendations(m: SyclProfileMetrics) -> list[SyclRecommendation]:
+ recs: list[SyclRecommendation] = []
+ kb = "knowledge_base/sycl/xpu/sycl_vtune.yaml"
+
+ if (
+ m.xve_stalled_pct is not None
+ and m.xve_active_pct is not None
+ and m.xve_stalled_pct > m.xve_active_pct
+ ):
+ recs.append(
+ SyclRecommendation(
+ "memory_bound",
+ "XVE Stalled > Active — mainloop is memory-bound. Increase "
+ "PipelineStages (prefetch depth), try explicit 2D-block/VNNI "
+ "copy atoms, or reduce TileK.",
+ kb,
+ )
+ )
+ if m.peak_occupancy_pct is not None and m.peak_occupancy_pct < 50:
+ recs.append(
+ SyclRecommendation(
+ "low_occupancy",
+ f"Peak occupancy {m.peak_occupancy_pct:.0f}% — grid may be too "
+ "small or register pressure too high. Try a smaller TileShape (e.g. "
+ "256->128) or confirm 256-GRF mode; for small problems expect "
+ "work-size-limited occupancy.",
+ kb,
+ )
+ )
+ if m.xve_idle_pct is not None and m.xve_idle_pct > 30:
+ recs.append(
+ SyclRecommendation(
+ "high_idle",
+ f"XVE Idle {m.xve_idle_pct:.0f}% — poor work distribution across "
+ "XVEs. Revisit TileShape vs M/N (tail effects); consider stream-K "
+ "or a persistent scheduler.",
+ kb,
+ )
+ )
+ if m.xmx_active_pct is not None and m.xmx_active_pct < 20:
+ recs.append(
+ SyclRecommendation(
+ "low_xmx",
+ f"XMX (DPAS) active {m.xmx_active_pct:.0f}% — the matrix engine is "
+ "underutilized. Raise compute intensity: larger N-per-subgroup, "
+ "check SubgroupLayout vs the DPAS atom.",
+ kb,
+ )
+ )
+ if m.l3_miss_pct is not None and m.l3_miss_pct > 50:
+ recs.append(
+ SyclRecommendation(
+ "l3_thrashing",
+ f"L3 miss ratio {m.l3_miss_pct:.0f}% — cache thrashing. Reduce tile "
+ "sizes or improve data reuse / K-blocking.",
+ kb,
+ )
+ )
+ return recs
diff --git a/src/xe_forge/core/trial_manager.py b/src/xe_forge/core/trial_manager.py
index 4009d86..dcbc4a2 100644
--- a/src/xe_forge/core/trial_manager.py
+++ b/src/xe_forge/core/trial_manager.py
@@ -119,6 +119,7 @@ def save_trial(
"speedup": None,
"baseline_us": None,
"triton_us": None,
+ "tflops": None,
"status": "saved",
}
self._save_state(kernel_name, state)
@@ -135,6 +136,7 @@ def record_result(
speedup: float | None = None,
baseline_us: float | None = None,
triton_us: float | None = None,
+ tflops: float | None = None,
) -> dict:
"""Record benchmark results for a trial. Returns the trial dict."""
state = self._load_state(kernel_name)
@@ -155,6 +157,8 @@ def record_result(
trial["baseline_us"] = baseline_us
if triton_us is not None:
trial["triton_us"] = triton_us
+ if tflops is not None:
+ trial["tflops"] = tflops
if baseline_us is not None and state.get("baseline_us") is None:
state["baseline_us"] = [baseline_us]
@@ -223,7 +227,13 @@ def _render(tid: str, prefix: str = "", is_last: bool = True) -> None:
speedup_str = f"{trial['speedup']:.2f}x" if trial["speedup"] is not None else "---"
runtime = ""
if trial.get("baseline_us") is not None and trial.get("triton_us") is not None:
- runtime = f" (bl={trial['baseline_us']:.0f}us, tr={trial['triton_us']:.0f}us)"
+ tflops_part = (
+ f", {trial['tflops']:.1f} TFLOPS" if trial.get("tflops") is not None else ""
+ )
+ runtime = (
+ f" (bl={trial['baseline_us']:.0f}us, "
+ f"tr={trial['triton_us']:.0f}us{tflops_part})"
+ )
best_marker = " <<<< BEST" if tid == state["best_trial"] else ""
strategy_short = (trial["strategy"] or "")[:60]
lines.append(
diff --git a/src/xe_forge/core/validator.py b/src/xe_forge/core/validator.py
index 3922538..057f878 100644
--- a/src/xe_forge/core/validator.py
+++ b/src/xe_forge/core/validator.py
@@ -410,6 +410,45 @@ def _validate_sycl(self, code: str) -> list[ValidationIssue]:
)
)
+ # A standalone runnable kernel must have an entry point.
+ if "int main(" not in code and "int main (" not in code:
+ issues.append(
+ ValidationIssue(
+ "missing_main",
+ "error",
+ "No 'int main(' found. SYCL kernel must be a standalone "
+ "executable with a main() that parses CLI args.",
+ )
+ )
+
+ # The Xe-Forge file-IO harness contract: read inputs from --input_dir,
+ # write D2.bin to --output_dir. Missing any of these means the kernel
+ # cannot be benchmarked against the golden reference.
+ contract_tokens = ("input_dir", "output_dir", "D2.bin")
+ missing = [t for t in contract_tokens if t not in code]
+ if missing:
+ issues.append(
+ ValidationIssue(
+ "missing_io_contract",
+ "warning",
+ "Missing file-IO contract token(s): "
+ f"{', '.join(missing)}. Kernel must read A.bin/B0.bin from "
+ "--input_dir and write D2.bin (f32) to --output_dir.",
+ suggestion="See knowledge_base/sycl/xpu/sycl_io_contract.yaml",
+ )
+ )
+
+ # Informational: CUTLASS is the recommended GEMM framework on Xe.
+ if "cutlass" not in code.lower():
+ issues.append(
+ ValidationIssue(
+ "no_cutlass_include",
+ "info",
+ "No CUTLASS include detected. CUTLASS SYCL is the recommended "
+ "framework for high-performance GEMM on Intel Xe.",
+ )
+ )
+
return issues
# ------------------------------------------------------------------
diff --git a/src/xe_forge/skills/__init__.py b/src/xe_forge/skills/__init__.py
index 0d1ac02..877d30e 100644
--- a/src/xe_forge/skills/__init__.py
+++ b/src/xe_forge/skills/__init__.py
@@ -67,6 +67,7 @@ def main():
t_result.add_argument("--speedup", type=float)
t_result.add_argument("--baseline-us", type=float)
t_result.add_argument("--triton-us", type=float)
+ t_result.add_argument("--tflops", type=float, help="Optimized kernel throughput (TFLOPS)")
t_result.add_argument("--trials-dir", default="./trials")
t_status = trial_sub.add_parser("status")
@@ -94,6 +95,7 @@ def main():
p_profile.add_argument("--warmup", type=int, default=5)
p_profile.add_argument("--iters", type=int, default=20)
p_profile.add_argument("--vtune-bin", default="vtune")
+ p_profile.add_argument("--dsl", default="triton", choices=["triton", "sycl", "gluon", "cuda"])
args = parser.parse_args()
diff --git a/src/xe_forge/skills/benchmark.py b/src/xe_forge/skills/benchmark.py
index d416d51..10497a1 100644
--- a/src/xe_forge/skills/benchmark.py
+++ b/src/xe_forge/skills/benchmark.py
@@ -1,7 +1,50 @@
-"""xe-forge-skill benchmark: Correctness + performance comparison."""
+"""xe-forge-skill benchmark: Correctness + performance comparison.
+
+Two DSL paths, dispatched on ``args.dsl``:
+ * ``_run_triton`` — PyTorch/Triton kernels via ``KernelBenchExecutor``
+ (original-vs-optimized correctness).
+ * ``_run_sycl`` — CUTLASS SYCL ``.cpp`` kernels via ``SyclExecutor``,
+ correctness checked against a **golden PyTorch/numpy reference** computed on
+ the same (bit-identical) inputs.
+
+Both print the same lines so ``tool-runner`` / ``trial result --triton-us``
+parsing stays uniform across DSLs (the ``triton_us=`` token is kept verbatim).
+"""
def run(args):
+ if getattr(args, "dsl", "triton") == "sycl":
+ return _run_sycl(args)
+ return _run_triton(args)
+
+
+def _perf_line(baseline_us, opt_us, speedup, tflops=None, peak_tflops=None):
+ """Build the uniform Performance: line, appending TFLOPS + utilization.
+
+ ``tflops`` is the optimized kernel's achieved throughput; ``peak_tflops`` is
+ the device's theoretical peak (config.device_config.peak_tflops) used to
+ report utilization as a percentage. Both are optional — the us/speedup part
+ is always emitted, so the line stays parseable when TFLOPS is unavailable.
+ """
+ line = (
+ f"Performance: baseline_us={baseline_us:.2f}, "
+ f"triton_us={opt_us:.2f}, speedup={speedup:.2f}x"
+ )
+ if tflops is not None:
+ line += f", tflops={tflops:.2f}"
+ if peak_tflops:
+ line += f", util={tflops / peak_tflops * 100:.1f}%"
+ return line
+
+
+def _peak_tflops():
+ """Theoretical peak TFLOPS from the active device config (for utilization)."""
+ from xe_forge.config import get_config
+
+ return get_config().device_config.peak_tflops
+
+
+def _run_triton(args):
from pathlib import Path
from xe_forge.core.executor import KernelBenchExecutor
@@ -19,6 +62,7 @@ def run(args):
init_args = spec.get_init_args(variant)
executor = KernelBenchExecutor(device=args.device)
+ peak = _peak_tflops()
if args.baseline_us is not None:
baseline_us = [float(v) for v in str(args.baseline_us).split(",")]
@@ -38,8 +82,13 @@ def run(args):
speedup = baseline_ms / opt_ms if opt_ms > 0 else 0
print(f"Correctness: {'PASSED' if optimized_result.success else 'FAILED'}")
print(
- f"Performance: baseline_us={baseline_ms * 1000:.2f}, "
- f"triton_us={opt_ms * 1000:.2f}, speedup={speedup:.2f}x"
+ _perf_line(
+ baseline_ms * 1000,
+ opt_ms * 1000,
+ speedup,
+ tflops=optimized_result.tflops,
+ peak_tflops=peak,
+ )
)
else:
print("Correctness: FAILED")
@@ -57,8 +106,187 @@ def run(args):
print(f"Correctness: {'PASSED' if result.optimized_correct else 'FAILED'}")
if result.original_time_us and result.optimized_time_us:
print(
- f"Performance: baseline_us={result.original_time_us:.2f}, "
- f"triton_us={result.optimized_time_us:.2f}, speedup={result.speedup:.2f}x"
+ _perf_line(
+ result.original_time_us,
+ result.optimized_time_us,
+ result.speedup,
+ tflops=result.optimized_tflops,
+ peak_tflops=peak,
+ )
)
if result.feedback_message:
print(f"Feedback: {result.feedback_message}")
+
+
+# ---------------------------------------------------------------------------
+# SYCL path
+# ---------------------------------------------------------------------------
+
+
+def _bf16_bin_to_torch(path: str, rows: int, cols: int, dtype):
+ """Read a tensor .bin written by SyclExecutor._save_tensor back into torch.
+
+ bfloat16 is stored as raw int16 bits (NumPy has no bf16); other dtypes are
+ stored as their native numpy bytes. Layout is row-major [rows, cols].
+ """
+ import numpy as np
+ import torch
+
+ if dtype == torch.bfloat16:
+ raw = np.fromfile(path, dtype=np.int16)
+ t = torch.from_numpy(raw).view(torch.bfloat16)
+ else:
+ np_dtype = {
+ torch.float16: np.float16,
+ torch.float32: np.float32,
+ }.get(dtype, np.float32)
+ raw = np.fromfile(path, dtype=np_dtype)
+ t = torch.from_numpy(raw)
+ return t.reshape(rows, cols)
+
+
+def _compute_golden(reference_path, input_dir, m, n, k, dtype):
+ """Run the PyTorch golden reference on the bit-identical .bin inputs.
+
+ Reads A.bin [M,K], B0.bin [K,N], B1.bin [K,N] (the GEMM-shaped inputs
+ SyclExecutor.generate_inputs emits) and feeds (A, B0[, B1]) by position to
+ the reference Model. Returns the result as a float32 numpy array.
+ """
+ import inspect
+ import os
+
+ from ai_bench.utils import import_from_path
+
+ module = import_from_path("sycl_golden_ref", reference_path)
+ if not hasattr(module, "Model"):
+ raise ValueError(f"Reference {reference_path} has no Model class")
+
+ # Instantiate the reference. GEMM goldens take no init args; fall back to
+ # get_init_inputs() when the reference declares it (matches executor).
+ if hasattr(module, "get_init_inputs"):
+ model = module.Model(*module.get_init_inputs())
+ else:
+ model = module.Model()
+
+ A = _bf16_bin_to_torch(os.path.join(input_dir, "A.bin"), m, k, dtype)
+ B0 = _bf16_bin_to_torch(os.path.join(input_dir, "B0.bin"), k, n, dtype)
+
+ # Pass B1 only when the reference's forward expects a third positional arg.
+ forward_params = [
+ p
+ for p in inspect.signature(model.forward).parameters.values()
+ if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
+ ]
+ if len(forward_params) >= 3:
+ B1 = _bf16_bin_to_torch(os.path.join(input_dir, "B1.bin"), k, n, dtype)
+ out = model(A, B0, B1)
+ else:
+ out = model(A, B0)
+
+ return out.float().detach().cpu().numpy()
+
+
+def _run_sycl(args):
+ import os
+
+ from xe_forge.config import get_config
+ from xe_forge.core.spec_loader import load_spec
+ from xe_forge.core.sycl_executor import KernelType, SyclExecutor
+
+ config = get_config()
+ spec = load_spec(args.spec)
+ variant = spec.resolve_variant(args.variant)
+ dims = spec.get_dims(variant)
+ dtype = spec.get_dtype(variant)
+
+ m = int(dims.get("M", dims.get("N", 1024)))
+ n = int(dims.get("N", m))
+ k = int(dims.get("K", m))
+
+ # Spec-driven tolerances, falling back to config defaults.
+ rtol = spec.get_rtol(variant)
+ atol = spec.get_atol(variant)
+ if rtol is None:
+ rtol = config.optimization.correctness_rtol
+ if atol is None:
+ atol = config.optimization.correctness_atol
+
+ executor = SyclExecutor(kernel_type=KernelType.GEMM, verify=False)
+ input_dir = executor.get_or_create_inputs(dims, seed=42, dtype=dtype)
+
+ # Golden reference: {baseline_stem}_pytorch.py next to the baseline .cpp.
+ baseline_stem = os.path.splitext(args.baseline)[0]
+ reference_path = f"{baseline_stem}_pytorch.py"
+ golden = None
+ if os.path.exists(reference_path):
+ try:
+ golden = _compute_golden(reference_path, input_dir, m, n, k, dtype)
+ except Exception as e:
+ print(f"Warning: golden reference failed ({e}); correctness will be skipped")
+ else:
+ print(f"No PyTorch golden reference at {reference_path}; correctness will be skipped")
+
+ # Correctness + optimized timing vs the golden array.
+ opt_tflops = None
+ if golden is not None:
+ result = executor.compare_with_reference(
+ golden_output=golden,
+ optimized_path=args.optimized,
+ dims=dims,
+ rtol=rtol,
+ atol=atol,
+ input_dir=input_dir,
+ )
+ opt_ms = result.optimized_time_ms
+ opt_tflops = result.optimized_tflops
+ correct = result.optimized_correct
+ feedback = result.feedback_message
+ else:
+ # No golden — just run the kernel for timing.
+ run = executor.execute(
+ kernel_path=args.optimized,
+ dims=dims,
+ output_name="optimized_sycl",
+ input_dir=input_dir,
+ )
+ opt_ms = run.execution_time_ms if run.success else None
+ opt_tflops = run.tflops if run.success else None
+ correct = run.success
+ feedback = run.error_message if not run.success else ""
+
+ if not correct or opt_ms is None:
+ print("Correctness: FAILED")
+ if feedback:
+ print(f"Feedback: {feedback}")
+ return
+
+ opt_us = opt_ms * 1000.0
+ peak = config.device_config.peak_tflops
+
+ # Baseline caching, uniform with Triton: t0 times the baseline .cpp; t1+
+ # reuses the cached baseline_us and only computes the speedup.
+ if args.baseline_us is not None:
+ baseline_us_list = [float(v) for v in str(args.baseline_us).split(",")]
+ baseline_us = sum(baseline_us_list) / len(baseline_us_list)
+ print(f"Using cached baseline: {baseline_us:.2f} us")
+ else:
+ baseline_run = executor.execute(
+ kernel_path=args.baseline,
+ dims=dims,
+ output_name="baseline_sycl",
+ input_dir=input_dir,
+ )
+ if not baseline_run.success or not baseline_run.execution_time_ms:
+ print("Correctness: PASSED")
+ print(
+ f"Warning: baseline kernel did not produce a timing: {baseline_run.error_message}"
+ )
+ print(_perf_line(0.0, opt_us, 0.0, tflops=opt_tflops, peak_tflops=peak))
+ return
+ baseline_us = baseline_run.execution_time_ms * 1000.0
+
+ speedup = baseline_us / opt_us if opt_us > 0 else 0.0
+ print("Correctness: PASSED")
+ print(_perf_line(baseline_us, opt_us, speedup, tflops=opt_tflops, peak_tflops=peak))
+ if feedback:
+ print(f"Feedback: {feedback}")
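Review note on the bf16 `.bin` layout used by `_bf16_bin_to_torch`: the docstring says bfloat16 tensors are written as raw int16 words. The sketch below illustrates that idea with the standard library only. It is not the project's `_save_tensor`; the helper names are hypothetical, and it truncates float32 to bf16 rather than rounding the way a real conversion would, so it is exact only for values already representable in bfloat16.

```python
import struct
from array import array


def f32_to_bf16_bits(x: float) -> int:
    """Return a bfloat16 bit pattern by truncating the float32 encoding of x."""
    (u32,) = struct.unpack("<I", struct.pack("<f", x))
    return u32 >> 16  # keep sign, 8 exponent bits, top 7 mantissa bits


def bf16_bits_to_f32(bits: int) -> float:
    """Widen a bfloat16 bit pattern to float32 by zero-filling the low 16 bits."""
    (val,) = struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))
    return val


def encode_bf16_as_int16(values) -> bytes:
    """Serialize values as signed 16-bit words (hypothetical stand-in for the .bin file)."""
    signed = [b - 0x10000 if b >= 0x8000 else b for b in map(f32_to_bf16_bits, values)]
    return array("h", signed).tobytes()


def decode_int16_as_bf16(blob: bytes) -> list:
    """Read signed 16-bit words back and reinterpret each one as bfloat16."""
    words = array("h")
    words.frombytes(blob)
    return [bf16_bits_to_f32(w) for w in words]


# Values chosen to be exactly representable in bfloat16.
values = [1.0, -2.5, 0.15625]
blob = encode_bf16_as_int16(values)
roundtrip = decode_int16_as_bf16(blob)
```

Because the int16 words are only a container for the bits, the reader has to reinterpret them (as `.view(torch.bfloat16)` does in the patch) rather than convert them numerically.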
diff --git a/src/xe_forge/skills/profile.py b/src/xe_forge/skills/profile.py
index ab2b012..5215c58 100644
--- a/src/xe_forge/skills/profile.py
+++ b/src/xe_forge/skills/profile.py
@@ -1,7 +1,20 @@
-"""xe-forge-skill profile: VTune GPU profiling."""
+"""xe-forge-skill profile: VTune GPU profiling.
+
+Dispatched on ``args.dsl``:
+ * ``_profile_triton`` — PyTorch/Triton kernels via ``XPUProfiler`` (gpu-offload
+ around a generated Python runner).
+ * ``_profile_sycl`` — compiled CUTLASS SYCL ``.cpp`` via ``SyclProfiler``
+ (gpu-hotspots characterization on the binary, no Python runner).
+"""
def run(args):
+ if getattr(args, "dsl", "triton") == "sycl":
+ return _profile_sycl(args)
+ return _profile_triton(args)
+
+
+def _profile_triton(args):
from xe_forge.core.profiler import XPUProfiler
profiler = XPUProfiler(vtune_bin=args.vtune_bin)
@@ -13,3 +26,21 @@ def run(args):
iters=args.iters,
)
print(result.format_for_llm())
+
+
+def _profile_sycl(args):
+ from xe_forge.core.spec_loader import load_spec
+ from xe_forge.core.sycl_profiler import KernelType, SyclProfiler
+
+ spec = load_spec(args.spec)
+ variant = spec.resolve_variant(args.variant)
+ dims = spec.get_dims(variant)
+ dtype = spec.get_dtype(variant)
+
+ profiler = SyclProfiler(
+ vtune_bin=args.vtune_bin,
+ kernel_type=KernelType.GEMM,
+ iterations=args.iters,
+ )
+ result = profiler.profile(kernel_path=args.kernel_file, dims=dims, dtype=dtype)
+ print(result.format_for_llm())
diff --git a/src/xe_forge/skills/trial.py b/src/xe_forge/skills/trial.py
index 2792770..d47c5ed 100644
--- a/src/xe_forge/skills/trial.py
+++ b/src/xe_forge/skills/trial.py
@@ -31,12 +31,14 @@ def run(args):
speedup=args.speedup,
baseline_us=args.baseline_us,
triton_us=args.triton_us,
+ tflops=args.tflops,
)
status_icon = {"completed": "+", "failed": "X", "partial": "~", "saved": "?"}
icon = status_icon.get(trial["status"], "?")
+ tflops_str = f", tflops={trial['tflops']}" if trial.get("tflops") is not None else ""
print(
f"[{icon}] {args.trial_id}: correctness={trial['correctness']}, "
- f"speedup={trial['speedup']}"
+ f"speedup={trial['speedup']}{tflops_str}"
)
case "status":
diff --git a/tests/test_benchmark_skill_routing.py b/tests/test_benchmark_skill_routing.py
new file mode 100644
index 0000000..0256354
--- /dev/null
+++ b/tests/test_benchmark_skill_routing.py
@@ -0,0 +1,205 @@
+"""Tests for the benchmark skill DSL dispatch (Triton vs SYCL).
+
+Platform-independent: the SyclExecutor / KernelBenchExecutor GPU seams are
+mocked, so no icpx or GPU is needed.
+"""
+
+import types
+from argparse import Namespace
+
+import numpy as np
+import pytest
+
+from xe_forge.skills import benchmark
+
+
+class _FakeExecResult:
+ def __init__(self, success=True, time_ms=0.5, tflops=10.0, error=""):
+ self.success = success
+ self.execution_time_ms = time_ms
+ self.tflops = tflops
+ self.error_message = error
+
+
+class _FakeSyclComparison:
+ def __init__(self, correct=True, opt_ms=0.5, tflops=10.0, msg="ok"):
+ self.optimized_correct = correct
+ self.optimized_time_ms = opt_ms
+ self.optimized_tflops = tflops
+ self.feedback_message = msg
+
+
+def _write_gemm_spec(tmp_path):
+ spec = tmp_path / "gemm.yaml"
+ spec.write_text(
+ "inputs:\n"
+ " A:\n"
+ " shape: [M, K]\n"
+ " dtype: bfloat16\n"
+ "bench-xpu:\n"
+ " - params: [A]\n"
+ " dtype: bfloat16\n"
+ " dims: { M: 256, N: 256, K: 256 }\n"
+ " flop: '2*M*N*K'\n"
+ " rtol: 0.02\n"
+ " atol: 0.01\n"
+ )
+ return spec
+
+
+def test_sycl_routes_to_sycl_executor(tmp_path, monkeypatch):
+ """args.dsl == 'sycl' must use SyclExecutor, never KernelBenchExecutor."""
+ spec = _write_gemm_spec(tmp_path)
+ baseline = tmp_path / "gemm.cpp"
+ baseline.write_text("#include \nint main(){}\n")
+ optimized = tmp_path / "t1.cpp"
+ optimized.write_text("#include \nint main(){}\n")
+
+ captured = {}
+
+ class FakeSyclExecutor:
+ def __init__(self, *a, **k):
+ captured["constructed"] = True
+
+ def get_or_create_inputs(self, dims, seed=42, dtype=None):
+ captured["dims"] = dims
+ return str(tmp_path / "inputs")
+
+ def compare_with_reference(self, **kwargs):
+ captured["compare_kwargs"] = kwargs
+ return _FakeSyclComparison(correct=True, opt_ms=0.4)
+
+ def execute(self, **kwargs):
+ captured.setdefault("execute_calls", []).append(kwargs)
+ return _FakeExecResult(success=True, time_ms=0.8)
+
+ # Fail loudly if the SYCL path ever touches KernelBenchExecutor.
+ def _boom(*a, **k):
+ raise AssertionError("SYCL path must not construct KernelBenchExecutor")
+
+ monkeypatch.setattr("xe_forge.core.sycl_executor.SyclExecutor", FakeSyclExecutor, raising=True)
+ monkeypatch.setattr("xe_forge.core.executor.KernelBenchExecutor", _boom, raising=True)
+ # No golden reference file -> correctness skipped, exercises the execute() path.
+ args = Namespace(
+ dsl="sycl",
+ baseline=str(baseline),
+ optimized=str(optimized),
+ spec=str(spec),
+ variant="bench-xpu",
+ baseline_us=None,
+ device="xpu",
+ )
+ benchmark.run(args)
+ assert captured.get("constructed")
+ assert captured["dims"] == {"M": 256, "N": 256, "K": 256}
+
+
+def test_sycl_baseline_us_skips_baseline_rerun(tmp_path, monkeypatch, capsys):
+ """When --baseline-us is set, the baseline .cpp is NOT re-run."""
+ spec = _write_gemm_spec(tmp_path)
+ baseline = tmp_path / "gemm.cpp"
+ baseline.write_text("#include \nint main(){}\n")
+ optimized = tmp_path / "t1.cpp"
+ optimized.write_text("#include \nint main(){}\n")
+ # Provide a golden reference so compare_with_reference is taken.
+ (tmp_path / "gemm_pytorch.py").write_text(
+ "import torch, torch.nn as nn\n"
+ "class Model(nn.Module):\n"
+ " def forward(self, A, B0):\n"
+ " return A.float() @ B0.float()\n"
+ )
+
+ execute_calls = []
+
+ class FakeSyclExecutor:
+ def __init__(self, *a, **k):
+ pass
+
+ def get_or_create_inputs(self, dims, seed=42, dtype=None):
+ return str(tmp_path / "inputs")
+
+ def compare_with_reference(self, **kwargs):
+ return _FakeSyclComparison(correct=True, opt_ms=0.4)
+
+ def execute(self, **kwargs):
+ execute_calls.append(kwargs)
+ return _FakeExecResult(success=True, time_ms=0.8)
+
+ monkeypatch.setattr("xe_forge.core.sycl_executor.SyclExecutor", FakeSyclExecutor, raising=True)
+ # Stub the golden computation so we don't need torch-on-bin roundtrip here.
+ monkeypatch.setattr(
+ benchmark,
+ "_compute_golden",
+ lambda *a, **k: np.zeros((256, 256), dtype=np.float32),
+ )
+
+ args = Namespace(
+ dsl="sycl",
+ baseline=str(baseline),
+ optimized=str(optimized),
+ spec=str(spec),
+ variant="bench-xpu",
+ baseline_us=123.0,
+ device="xpu",
+ )
+ benchmark.run(args)
+ out = capsys.readouterr().out
+ # No execute() call at all (compare_with_reference handles the optimized run,
+ # baseline rerun is skipped because baseline_us was supplied).
+ assert execute_calls == []
+ assert "Using cached baseline" in out
+ assert "triton_us=" in out
+ assert "Correctness: PASSED" in out
+ # SYCL perf line carries optimized TFLOPS + utilization (10/160 = 6.25%, printed as 6.2%).
+ assert "tflops=10.00" in out
+ assert "util=6.2%" in out
+
+
+def test_triton_path_unchanged(tmp_path, monkeypatch, capsys):
+ """args.dsl == 'triton' still uses KernelBenchExecutor.compare_kernels."""
+ spec = _write_gemm_spec(tmp_path)
+ baseline = tmp_path / "b.py"
+ baseline.write_text("class Model: pass\n")
+ optimized = tmp_path / "o.py"
+ optimized.write_text("class Model: pass\n")
+
+ called = {}
+
+ class FakeKB:
+ def __init__(self, *a, **k):
+ called["constructed"] = True
+
+ def compare_kernels(self, **kwargs):
+ called["compared"] = True
+ r = types.SimpleNamespace(
+ optimized_correct=True,
+ original_time_us=100.0,
+ optimized_time_us=50.0,
+ speedup=2.0,
+ optimized_tflops=80.0,
+ feedback_message="good",
+ )
+ return r
+
+ monkeypatch.setattr("xe_forge.core.executor.KernelBenchExecutor", FakeKB, raising=True)
+ args = Namespace(
+ dsl="triton",
+ baseline=str(baseline),
+ optimized=str(optimized),
+ spec=str(spec),
+ variant="bench-xpu",
+ baseline_us=None,
+ device="xpu",
+ )
+ benchmark.run(args)
+ out = capsys.readouterr().out
+ assert called.get("compared")
+ assert "Correctness: PASSED" in out
+ assert "triton_us=50.00" in out
+ # TFLOPS + utilization appended; util = 80 / peak(160) * 100 = 50.0%.
+ assert "tflops=80.00" in out
+ assert "util=50.0%" in out
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/test_generator_sycl.py b/tests/test_generator_sycl.py
new file mode 100644
index 0000000..fcf4948
--- /dev/null
+++ b/tests/test_generator_sycl.py
@@ -0,0 +1,149 @@
+"""Tests for DSL-aware Claude workspace generation (SYCL path).
+
+Platform-independent: no icpx/GPU involved — only template rendering and file
+layout are exercised.
+"""
+
+from pathlib import Path
+
+import pytest
+
+from xe_forge.claude.generator import generate_workspace
+from xe_forge.config import Config, XPUConfig
+
+REFERENCE_PY = """\
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+ def forward(self, A, B0):
+ return A.float() @ B0.float()
+"""
+
+SYCL_CPP = """\
+#include "cutlass/gemm/device/gemm_universal.h"
+int main(int argc, const char** argv) {
+ // reads input_dir A.bin/B0.bin, writes output_dir D2.bin
+ return 0;
+}
+"""
+
+TRITON_PY = """\
+import triton
+import triton.language as tl
+
+
+class Model:
+ pass
+"""
+
+
+def _sycl_config(workspace: Path) -> Config:
+ cfg = Config()
+ cfg.device_config = XPUConfig(device="xpu", dsl="sycl")
+ cfg.engine.git_init = False
+ cfg.engine.workspace = str(workspace)
+ cfg.trial.max_trials = 5
+ return cfg
+
+
+def _triton_config(workspace: Path) -> Config:
+ cfg = Config()
+ cfg.device_config = XPUConfig(device="xpu", dsl="triton")
+ cfg.engine.git_init = False
+ cfg.trial.max_trials = 5
+ return cfg
+
+
+def test_sycl_writes_cpp_kernel(tmp_path):
+ ws = tmp_path / "ws"
+ generate_workspace(
+ workspace=ws,
+ config=_sycl_config(ws),
+ kernel_name="gemm",
+ kernel_code=SYCL_CPP,
+ reference_code=REFERENCE_PY,
+ )
+ assert (ws / "test_kernels" / "gemm.cpp").exists()
+ assert not (ws / "test_kernels" / "gemm.py").exists()
+ # PyTorch golden reference written alongside the .cpp.
+ assert (ws / "test_kernels" / "gemm_pytorch.py").read_text() == REFERENCE_PY
+ # Existing .cpp with #include is kept verbatim (not replaced by the stub).
+ assert (ws / "test_kernels" / "gemm.cpp").read_text() == SYCL_CPP
+
+
+def test_sycl_starter_stub_when_no_include(tmp_path):
+ ws = tmp_path / "ws"
+ # PyTorch-only input (no #include) -> starter stub is substituted.
+ generate_workspace(
+ workspace=ws,
+ config=_sycl_config(ws),
+ kernel_name="gemm",
+ kernel_code=REFERENCE_PY,
+ reference_code=REFERENCE_PY,
+ )
+ cpp = (ws / "test_kernels" / "gemm.cpp").read_text()
+ assert "#include" in cpp
+ assert "int main(" in cpp
+ assert "input_dir" in cpp and "output_dir" in cpp and "D2.bin" in cpp
+ assert "cutlass" in cpp.lower()
+
+
+def test_sycl_claude_md_content(tmp_path):
+ ws = tmp_path / "ws"
+ generate_workspace(
+ workspace=ws,
+ config=_sycl_config(ws),
+ kernel_name="gemm",
+ kernel_code=SYCL_CPP,
+ reference_code=REFERENCE_PY,
+ )
+ claude = (ws / "CLAUDE.md").read_text()
+ # SYCL-specific contract markers present.
+ for token in [
+ "input_dir",
+ "output_dir",
+ "D2.bin",
+ "--dsl sycl",
+ "bench-xpu",
+ "knowledge_base/sycl/xpu",
+ ]:
+ assert token in claude, f"missing {token!r} in SYCL CLAUDE.md"
+ # Triton-isms absent.
+ for token in ["@triton.autotune", "GROUP_SIZE_M", "xe-forge-skill analyze"]:
+ assert token not in claude, f"unexpected Triton token {token!r} in SYCL CLAUDE.md"
+
+
+def test_sycl_optimize_command_uses_dsl_flag(tmp_path):
+ ws = tmp_path / "ws"
+ generate_workspace(
+ workspace=ws,
+ config=_sycl_config(ws),
+ kernel_name="gemm",
+ kernel_code=SYCL_CPP,
+ reference_code=REFERENCE_PY,
+ )
+ cmd = (ws / ".claude" / "commands" / "optimize-kernel.md").read_text()
+ assert "--dsl sycl" in cmd
+ assert ".cpp" in cmd
+
+
+def test_triton_regression_writes_py(tmp_path):
+ ws = tmp_path / "ws"
+ generate_workspace(
+ workspace=ws,
+ config=_triton_config(ws),
+ kernel_name="kern",
+ kernel_code=TRITON_PY,
+ reference_code=REFERENCE_PY,
+ )
+ assert (ws / "test_kernels" / "kern.py").exists()
+ assert not (ws / "test_kernels" / "kern.cpp").exists()
+ claude = (ws / "CLAUDE.md").read_text()
+ # Triton CLAUDE.md keeps the analyze step.
+ assert "analyze" in claude
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/test_profile_skill_routing.py b/tests/test_profile_skill_routing.py
new file mode 100644
index 0000000..ebad2cc
--- /dev/null
+++ b/tests/test_profile_skill_routing.py
@@ -0,0 +1,123 @@
+"""Tests for the profile skill DSL dispatch (Triton vs SYCL).
+
+Platform-independent: the XPUProfiler / SyclProfiler seams are mocked.
+"""
+
+from argparse import Namespace
+
+import pytest
+
+from xe_forge.skills import profile
+
+
+class _FakeResult:
+ def __init__(self, text):
+ self._text = text
+
+ def format_for_llm(self):
+ return self._text
+
+
+def test_sycl_routes_to_sycl_profiler(tmp_path, monkeypatch, capsys):
+ """args.dsl == 'sycl' must use SyclProfiler, never XPUProfiler."""
+ spec = tmp_path / "gemm.yaml"
+ spec.write_text(
+ "inputs:\n A:\n shape: [M, K]\n dtype: bfloat16\n"
+ "bench-xpu:\n - params: [A]\n dtype: bfloat16\n"
+ " dims: { M: 1024, N: 1024, K: 1024 }\n flop: '2*M*N*K'\n"
+ )
+ captured = {}
+
+ class FakeSyclProfiler:
+ def __init__(self, *a, **k):
+ captured["constructed"] = True
+
+ def profile(self, kernel_path, dims, dtype, **k):
+ captured["dims"] = dims
+ return _FakeResult("SYCL PROFILE OK")
+
+ def _boom(*a, **k):
+ raise AssertionError("SYCL path must not construct XPUProfiler")
+
+ monkeypatch.setattr("xe_forge.core.sycl_profiler.SyclProfiler", FakeSyclProfiler, raising=True)
+ monkeypatch.setattr("xe_forge.core.profiler.XPUProfiler", _boom, raising=True)
+
+ args = Namespace(
+ dsl="sycl",
+ kernel_file=str(tmp_path / "gemm.cpp"),
+ spec=str(spec),
+ variant="bench-xpu",
+ warmup=5,
+ iters=200,
+ vtune_bin="vtune",
+ )
+ profile.run(args)
+ out = capsys.readouterr().out
+ assert captured.get("constructed")
+ assert captured["dims"] == {"M": 1024, "N": 1024, "K": 1024}
+ assert "SYCL PROFILE OK" in out
+
+
+def test_triton_routes_to_xpu_profiler(tmp_path, monkeypatch, capsys):
+ """args.dsl == 'triton' still uses XPUProfiler."""
+ called = {}
+
+ class FakeXPUProfiler:
+ def __init__(self, *a, **k):
+ called["constructed"] = True
+
+ def profile(self, kernel_file, **k):
+ return _FakeResult("TRITON PROFILE OK")
+
+ monkeypatch.setattr("xe_forge.core.profiler.XPUProfiler", FakeXPUProfiler, raising=True)
+
+ args = Namespace(
+ dsl="triton",
+ kernel_file=str(tmp_path / "k.py"),
+ spec=str(tmp_path / "s.yaml"),
+ variant="bench-gpu",
+ warmup=5,
+ iters=20,
+ vtune_bin="vtune",
+ )
+ profile.run(args)
+ out = capsys.readouterr().out
+ assert called.get("constructed")
+ assert "TRITON PROFILE OK" in out
+
+
+def test_profile_subparser_has_dsl(monkeypatch):
+ """The REAL profile subparser must accept --dsl (regression for the missing flag).
+
+ Drives xe_forge.skills.main() with a stubbed profile.run so only argument
+ parsing is exercised — proving the flag exists on the actual parser, not a
+ reconstruction.
+ """
+ import sys
+
+ from xe_forge import skills
+
+ seen = {}
+ monkeypatch.setattr(profile, "run", lambda args: seen.update(dsl=args.dsl, iters=args.iters))
+ monkeypatch.setattr(
+ sys,
+ "argv",
+ [
+ "xe-forge-skill",
+ "profile",
+ "k.cpp",
+ "--spec",
+ "s.yaml",
+ "--dsl",
+ "sycl",
+ "--iters",
+ "200",
+ ],
+ )
+ skills.main()
+ assert seen["dsl"] == "sycl"
+ assert seen["iters"] == 200
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/test_sycl_golden_reference.py b/tests/test_sycl_golden_reference.py
new file mode 100644
index 0000000..913f641
--- /dev/null
+++ b/tests/test_sycl_golden_reference.py
@@ -0,0 +1,201 @@
+"""Tests for SyclExecutor.compare_with_reference (golden-reference path).
+
+The icpx/GPU seam is mocked at SyclExecutor.execute: instead of compiling and
+running a kernel, the fake writes a known D2.bin so the Python-side comparison,
+reshape, and tolerance logic can be exercised on any platform.
+"""
+
+import os
+import re
+
+import numpy as np
+import pytest
+import torch
+
+from xe_forge.core.sycl_executor import SyclExecutor, _save_tensor
+from xe_forge.models import ExecutionResult
+
+
+def _make_executor():
+ # device_target="" avoids torch.xpu auto-detection during construction.
+ return SyclExecutor(device_target="", verify=False)
+
+
+def test_bf16_bin_roundtrip_bit_exact(tmp_path):
+ """_save_tensor (bf16-as-int16) round-trips bit-exactly via torch view."""
+ t = torch.randn(8, 16, dtype=torch.bfloat16)
+ path = str(tmp_path / "A.bin")
+ _save_tensor(t, path)
+ raw = np.fromfile(path, dtype=np.int16)
+ back = torch.from_numpy(raw).view(torch.bfloat16).reshape(8, 16)
+ assert torch.equal(t, back)
+
+
+def test_compare_with_reference_passed(tmp_path, monkeypatch):
+ ex = _make_executor()
+ golden = np.arange(256, dtype=np.float32).reshape(16, 16)
+
+ def fake_execute(**kwargs):
+ out_dir = kwargs["output_dir"]
+ os.makedirs(out_dir, exist_ok=True)
+ # Kernel produces exactly the golden values (flat f32).
+ golden.astype(np.float32).tofile(os.path.join(out_dir, "D2.bin"))
+ return ExecutionResult(success=True, execution_time_ms=0.3, tflops=12.0)
+
+ monkeypatch.setattr(ex, "execute", fake_execute)
+ monkeypatch.setattr(ex, "get_or_create_inputs", lambda *a, **k: str(tmp_path / "in"))
+
+ res = ex.compare_with_reference(
+ golden_output=golden,
+ optimized_path="dummy.cpp",
+ dims={"M": 16, "N": 16, "K": 16},
+ rtol=1e-2,
+ atol=1e-2,
+ input_dir=str(tmp_path / "in"),
+ )
+ assert res.optimized_correct is True
+ assert res.optimized_time_ms == 0.3
+ assert res.optimized_tflops == 12.0
+ assert "PASSED" in res.feedback_message
+
+
+def test_compare_with_reference_failed(tmp_path, monkeypatch):
+ ex = _make_executor()
+ golden = np.zeros((16, 16), dtype=np.float32)
+
+ def fake_execute(**kwargs):
+ out_dir = kwargs["output_dir"]
+ os.makedirs(out_dir, exist_ok=True)
+ # Wrong output (all ones) -> must fail.
+ np.ones((16, 16), dtype=np.float32).tofile(os.path.join(out_dir, "D2.bin"))
+ return ExecutionResult(success=True, execution_time_ms=0.3, tflops=12.0)
+
+ monkeypatch.setattr(ex, "execute", fake_execute)
+
+ res = ex.compare_with_reference(
+ golden_output=golden,
+ optimized_path="dummy.cpp",
+ dims={"M": 16, "N": 16, "K": 16},
+ rtol=1e-2,
+ atol=1e-2,
+ input_dir=str(tmp_path / "in"),
+ )
+ assert res.optimized_correct is False
+ assert "CORRECTNESS FAILURE" in res.feedback_message
+
+
+def test_compare_with_reference_compile_failure_surfaces(tmp_path, monkeypatch):
+ ex = _make_executor()
+ golden = np.zeros((16, 16), dtype=np.float32)
+
+ def fake_execute(**kwargs):
+ return ExecutionResult(success=False, error_message="Compilation failed: boom")
+
+ monkeypatch.setattr(ex, "execute", fake_execute)
+
+ res = ex.compare_with_reference(
+ golden_output=golden,
+ optimized_path="dummy.cpp",
+ dims={"M": 16, "N": 16, "K": 16},
+ input_dir=str(tmp_path / "in"),
+ )
+ assert res.optimized_correct is False
+ assert "boom" in res.feedback_message
+
+
+def test_compare_with_reference_missing_d2(tmp_path, monkeypatch):
+ ex = _make_executor()
+ golden = np.zeros((16, 16), dtype=np.float32)
+
+ def fake_execute(**kwargs):
+ # Success but no D2.bin written — kernel ignored the IO contract.
+ os.makedirs(kwargs["output_dir"], exist_ok=True)
+ return ExecutionResult(success=True, execution_time_ms=0.3, tflops=12.0)
+
+ monkeypatch.setattr(ex, "execute", fake_execute)
+
+ res = ex.compare_with_reference(
+ golden_output=golden,
+ optimized_path="dummy.cpp",
+ dims={"M": 16, "N": 16, "K": 16},
+ input_dir=str(tmp_path / "in"),
+ )
+ assert res.optimized_correct is False
+ assert "D2.bin" in res.feedback_message
+
+
+def test_printed_format_matches_trial_parser(tmp_path, monkeypatch, capsys):
+ """The benchmark skill's SYCL output must match the trial --triton-us regex."""
+ from argparse import Namespace
+
+ from xe_forge.skills import benchmark
+
+ spec = tmp_path / "gemm.yaml"
+ spec.write_text(
+ "inputs:\n A:\n shape: [M, K]\n dtype: bfloat16\n"
+ "bench-xpu:\n - params: [A]\n dtype: bfloat16\n"
+ " dims: { M: 16, N: 16, K: 16 }\n flop: '2*M*N*K'\n"
+ )
+ baseline = tmp_path / "gemm.cpp"
+ baseline.write_text("#include \nint main(){}\n")
+ optimized = tmp_path / "t1.cpp"
+ optimized.write_text("#include \nint main(){}\n")
+
+ class FakeSyclExecutor:
+ def __init__(self, *a, **k):
+ pass
+
+ def get_or_create_inputs(self, dims, seed=42, dtype=None):
+ return str(tmp_path / "in")
+
+ def compare_with_reference(self, **kwargs):
+ from xe_forge.core.sycl_executor import SyclComparisonResult
+
+ return SyclComparisonResult(
+ original_time_ms=float("inf"),
+ optimized_time_ms=0.5,
+ speedup=0.0,
+ optimized_tflops=10.0,
+ optimized_correct=True,
+ feedback_message="ok",
+ )
+
+ def execute(self, **kwargs):
+ return ExecutionResult(success=True, execution_time_ms=1.0, tflops=5.0)
+
+ monkeypatch.setattr("xe_forge.core.sycl_executor.SyclExecutor", FakeSyclExecutor, raising=True)
+ monkeypatch.setattr(
+ benchmark,
+ "_compute_golden",
+ lambda *a, **k: np.zeros((16, 16), dtype=np.float32),
+ )
+
+ args = Namespace(
+ dsl="sycl",
+ baseline=str(baseline),
+ optimized=str(optimized),
+ spec=str(spec),
+ variant="bench-xpu",
+ baseline_us=None,
+ device="xpu",
+ )
+ benchmark.run(args)
+ out = capsys.readouterr().out
+ # Same regex tool-runner / trial result parsing relies on.
+ m = re.search(r"baseline_us=([0-9.]+), triton_us=([0-9.]+), speedup=([0-9.]+)x", out)
+ assert m, f"perf line did not match expected format:\n{out}"
+ assert "Correctness: PASSED" in out
+ # No gemm_pytorch.py here -> the no-golden execute() path, which carries the
+ # kernel's parsed tflops (5.0; util = 5/160 = 3.1%). Confirms TFLOPS + util
+ # are appended after speedup on that path too.
+ full = re.search(
+ r"baseline_us=[0-9.]+, triton_us=[0-9.]+, speedup=[0-9.]+x, "
+ r"tflops=([0-9.]+), util=([0-9.]+)%",
+ out,
+ )
+ assert full, f"perf line missing tflops/util:\n{out}"
+ assert full.group(1) == "5.00"
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
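Review note on the perf-line contract checked by `test_printed_format_matches_trial_parser`: the regexes in that test imply a line of the form `baseline_us=..., triton_us=..., speedup=...x, tflops=..., util=...%`. The sketch below is one way a formatter and parser could agree on that shape. `format_perf_line` and `parse_perf_line` are hypothetical names, not the project's `_perf_line`, and the two-decimal speedup formatting is an assumption.

```python
import re

# Base fields are required; the tflops/util suffix is optional.
PERF_RE = re.compile(
    r"baseline_us=([0-9.]+), triton_us=([0-9.]+), speedup=([0-9.]+)x"
    r"(?:, tflops=([0-9.]+), util=([0-9.]+)%)?"
)


def format_perf_line(baseline_us, opt_us, speedup, tflops=None, peak_tflops=None):
    """Build the perf line; utilization is tflops as a percentage of peak."""
    line = f"baseline_us={baseline_us:.2f}, triton_us={opt_us:.2f}, speedup={speedup:.2f}x"
    if tflops is not None and peak_tflops:
        line += f", tflops={tflops:.2f}, util={tflops / peak_tflops * 100:.1f}%"
    return line


def parse_perf_line(text):
    """Return the parsed fields as floats, or None when no perf line is present."""
    m = PERF_RE.search(text)
    if m is None:
        return None
    keys = ("baseline_us", "triton_us", "speedup", "tflops", "util")
    return {k: float(v) for k, v in zip(keys, m.groups()) if v is not None}


line = format_perf_line(100.0, 50.0, 2.0, tflops=80.0, peak_tflops=160.0)
fields = parse_perf_line("Correctness: PASSED\n" + line)
```

Keeping the TFLOPS fields as an optional suffix means a consumer that matches only the first three fields keeps working.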
diff --git a/tests/test_sycl_profiler.py b/tests/test_sycl_profiler.py
new file mode 100644
index 0000000..b32e7bb
--- /dev/null
+++ b/tests/test_sycl_profiler.py
@@ -0,0 +1,200 @@
+"""Tests for SyclProfiler (VTune gpu-hotspots path).
+
+Platform-independent: the vtune subprocess and SyclExecutor.compile/inputs seams
+are mocked, so no VTune, icpx, or GPU is needed. Column names and CSV shape match
+what VTune 2026.0 emits for a CUTLASS GEMM on the B70.
+"""
+
+import subprocess
+
+import pytest
+
+from xe_forge.core import sycl_profiler as sp
+from xe_forge.core.sycl_profiler import SyclProfileMetrics, SyclProfiler
+
+# A realistic two-row hotspots CSV: the CUTLASS kernel + an overhead copy task.
+# Tab-delimited, with the "(%)" suffix VTune appends to percentage columns.
+_CORE_CSV = (
+ "Computing Task\tComputing Task:Total Time\tXVE Array:Active(%)\t"
+ "XVE Array:Stalled(%)\tXVE Array:Idle(%)\tPeak XVE Threads Occupancy(%)\t"
+ "XVE Pipelines:XMX active(%)\tGPU L3:Miss Ratio(%)\n"
+ "GemmUniversal>\t0.0598\t7.3\t39.4\t53.3\t25.0\t5.0\t0.7\n"
+ "zeCommandListAppendMemoryCopy\t0.0006\t0.0\t0.0\t100.0\t0.0\t0.0\t100.0\n"
+)
+_BW_CSV = (
+ "Computing Task\tComputing Task:Total Time\t"
+ "GPU Memory Bandwidth, GB/sec:Read\tGPU Memory Bandwidth, GB/sec:Write\n"
+ "GemmUniversal>\t0.0598\t120.5\t60.2\n"
+ "zeCommandListAppendMemoryCopy\t0.0006\t0.0\t0.0\n"
+)
+
+
+def _profiler(monkeypatch, vtune_ok=True):
+ """Build a SyclProfiler without touching torch.xpu / a real executor."""
+
+ # Stub SyclExecutor construction inside the module to a lightweight fake.
+ # _dims_to_mnk is a pure static helper the module calls on the class, so the
+ # fake provides it with the same semantics as the real one.
+ class FakeExecutor:
+ def __init__(self, *a, **k):
+ pass
+
+ @staticmethod
+ def _dims_to_mnk(dims, m=1024, n=1024, k=1024):
+ if not dims:
+ return m, n, k
+ em = int(dims.get("M", dims.get("N", m)))
+ return em, int(dims.get("N", em)), int(dims.get("K", em))
+
+ monkeypatch.setattr(sp, "SyclExecutor", FakeExecutor, raising=True)
+ prof = SyclProfiler(vtune_bin="vtune")
+ monkeypatch.setattr(prof, "available", lambda: vtune_ok)
+ return prof
+
+
+def test_unavailable_vtune_returns_error_without_compiling(tmp_path, monkeypatch):
+ prof = _profiler(monkeypatch, vtune_ok=False)
+ kernel = tmp_path / "gemm.cpp"
+ kernel.write_text("#include \nint main(){}\n")
+ compiled = {"called": False}
+ monkeypatch.setattr(
+ prof._executor, "compile", lambda **k: compiled.__setitem__("called", True), raising=False
+ )
+ res = prof.profile(kernel, {"M": 1024, "N": 1024, "K": 1024})
+ assert res.error and "VTune not found" in res.error
+ assert compiled["called"] is False
+
+
+def test_missing_kernel_file(tmp_path, monkeypatch):
+ prof = _profiler(monkeypatch)
+ res = prof.profile(tmp_path / "nope.cpp", {"M": 256, "N": 256, "K": 256})
+ assert res.error and "not found" in res.error
+
+
+def test_compile_failure_surfaces(tmp_path, monkeypatch):
+ prof = _profiler(monkeypatch)
+ kernel = tmp_path / "gemm.cpp"
+ kernel.write_text("#include \nint main(){}\n")
+ monkeypatch.setattr(
+ prof._executor, "compile", lambda **k: (False, "", "boom error"), raising=False
+ )
+ res = prof.profile(kernel, {"M": 256, "N": 256, "K": 256})
+ assert res.error and "boom error" in res.error
+
+
+def _wire_success(prof, monkeypatch, tmp_path):
+ """Wire compile/inputs + a fake vtune subprocess that returns the CSVs."""
+ kernel = tmp_path / "gemm.cpp"
+ kernel.write_text("#include \nint main(){}\n")
+ monkeypatch.setattr(
+ prof._executor, "compile", lambda **k: (True, str(tmp_path / "bin"), ""), raising=False
+ )
+ monkeypatch.setattr(
+ prof._executor, "get_or_create_inputs", lambda *a, **k: str(tmp_path / "in"), raising=False
+ )
+
+ def fake_run(cmd, **kwargs):
+ joined = " ".join(cmd)
+ if "-collect" in cmd:
+ return subprocess.CompletedProcess(cmd, 0, "collected", "")
+ # Report pass: the BW pass requests the comma-free "GB/sec" substring.
+ out = _BW_CSV if "GB/sec" in joined else _CORE_CSV
+ return subprocess.CompletedProcess(cmd, 0, out, "")
+
+ monkeypatch.setattr(sp.subprocess, "run", fake_run)
+ return kernel
+
+
+def test_profile_success_parses_metrics_and_picks_kernel(tmp_path, monkeypatch):
+ prof = _profiler(monkeypatch)
+ kernel = _wire_success(prof, monkeypatch, tmp_path)
+ res = prof.profile(kernel, {"M": 1024, "N": 1024, "K": 1024})
+ assert res.error is None
+ # Overhead task must NOT be chosen as primary.
+ assert res.primary_kernel.startswith("GemmUniversal")
+ m = res.metrics
+ assert m.xve_active_pct == 7.3
+ assert m.xve_stalled_pct == 39.4
+ assert m.peak_occupancy_pct == 25.0
+ assert m.xmx_active_pct == 5.0
+ assert m.l3_miss_pct == 0.7
+ # Bandwidth came from the second (separate) report pass.
+ assert m.gpu_mem_bw_read_gbps == 120.5
+ assert m.gpu_mem_bw_write_gbps == 60.2
+
+
+def test_recommendations_memory_bound_and_low_occupancy(tmp_path, monkeypatch):
+ prof = _profiler(monkeypatch)
+ kernel = _wire_success(prof, monkeypatch, tmp_path)
+ res = prof.profile(kernel, {"M": 1024, "N": 1024, "K": 1024})
+ cats = {r.category for r in res.recommendations}
+ # stalled(39.4) > active(7.3) -> memory_bound; occ 25 < 50 -> low_occupancy;
+ # idle 53.3 > 30 -> high_idle; xmx 5 < 20 -> low_xmx.
+ assert {"memory_bound", "low_occupancy", "high_idle", "low_xmx"} <= cats
+ txt = res.format_for_llm()
+ assert "XVE Stalled" in txt and "Recommendations:" in txt
+
+
+def test_empty_report_is_error(tmp_path, monkeypatch):
+ prof = _profiler(monkeypatch)
+ kernel = tmp_path / "gemm.cpp"
+ kernel.write_text("#include \nint main(){}\n")
+ monkeypatch.setattr(
+ prof._executor, "compile", lambda **k: (True, str(tmp_path / "bin"), ""), raising=False
+ )
+ monkeypatch.setattr(
+ prof._executor, "get_or_create_inputs", lambda *a, **k: str(tmp_path / "in"), raising=False
+ )
+
+ def fake_run(cmd, **kwargs):
+ if "-collect" in cmd:
+ return subprocess.CompletedProcess(cmd, 0, "ok", "")
+ return subprocess.CompletedProcess(cmd, 0, "no data here\n", "")
+
+ monkeypatch.setattr(sp.subprocess, "run", fake_run)
+ res = prof.profile(kernel, {"M": 256, "N": 256, "K": 256})
+ assert res.error and "No GPU kernel data" in res.error
+
+
+def test_collection_failure_surfaces(tmp_path, monkeypatch):
+    prof = _profiler(monkeypatch)
+    kernel = tmp_path / "gemm.cpp"
+    kernel.write_text("#include \nint main(){}\n")
+    monkeypatch.setattr(
+        prof._executor, "compile", lambda **k: (True, str(tmp_path / "bin"), ""), raising=False
+    )
+    monkeypatch.setattr(
+        prof._executor, "get_or_create_inputs", lambda *a, **k: str(tmp_path / "in"), raising=False
+    )
+
+    def fake_run(cmd, **kwargs):
+        return subprocess.CompletedProcess(cmd, 1, "", "driver error")
+
+    monkeypatch.setattr(sp.subprocess, "run", fake_run)
+    res = prof.profile(kernel, {"M": 256, "N": 256, "K": 256})
+    assert res.error and "collection failed" in res.error.lower()
+
+
+def test_metrics_pct_suffix_fallback():
+    # _build_metrics must read both "X" and "X(%)" header spellings.
+    cols = {"XVE Array:Active(%)": "7.3", "GPU L3:Miss Ratio": "0.7"}
+    m = SyclProfiler._build_metrics(cols)
+    assert m.xve_active_pct == 7.3
+    assert m.l3_miss_pct == 0.7
+
+
+def test_no_recommendations_when_healthy():
+    # Active > stalled, high occupancy, high xmx, low idle, low l3 -> no recs.
+    m = SyclProfileMetrics(
+        xve_active_pct=85.0,
+        xve_stalled_pct=10.0,
+        xve_idle_pct=5.0,
+        peak_occupancy_pct=90.0,
+        xmx_active_pct=80.0,
+        l3_miss_pct=1.0,
+    )
+    assert SyclProfiler._recommendations(m) == []
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/tests/test_validator.py b/tests/test_validator.py
index 9d4efa5..6febb6c 100644
--- a/tests/test_validator.py
+++ b/tests/test_validator.py
@@ -80,3 +80,48 @@ def test_2d_grid_with_swizzle_is_rejected(self):
     def test_2d_tuple_grid_with_swizzle_is_rejected(self):
         issues = KernelValidator().validate(INVALID_2D_TUPLE_SWIZZLED_GRID, dsl="triton")
         assert any(issue.check_name == "grid_swizzle_conflict" for issue in issues)
+
+
+# A valid, contract-honouring SYCL stub.
+VALID_SYCL_STUB = """\
+#include "cutlass/gemm/device/gemm_universal.h"
+
+int main(int argc, const char** argv) {
+    std::string input_dir, output_dir; // read A.bin/B0.bin, write D2.bin
+    return 0;
+}
+"""
+
+# Missing main(), missing IO contract, no cutlass.
+BAD_SYCL = """\
+#include
+
+void helper() {}
+"""
+
+
+class TestSyclValidation:
+    def test_valid_sycl_stub_is_clean_of_errors(self):
+        issues = KernelValidator().validate(VALID_SYCL_STUB, dsl="sycl")
+        errors = [i for i in issues if i.severity == "error"]
+        assert errors == [], f"unexpected errors: {[e.check_name for e in errors]}"
+        # Contract satisfied -> no missing_io_contract warning.
+        assert all(i.check_name != "missing_io_contract" for i in issues)
+
+    def test_missing_main_is_error(self):
+        issues = KernelValidator().validate(BAD_SYCL, dsl="sycl")
+        main_errs = [i for i in issues if i.check_name == "missing_main"]
+        assert len(main_errs) == 1
+        assert main_errs[0].severity == "error"
+
+    def test_missing_io_contract_is_warning(self):
+        issues = KernelValidator().validate(BAD_SYCL, dsl="sycl")
+        io_warns = [i for i in issues if i.check_name == "missing_io_contract"]
+        assert len(io_warns) == 1
+        assert io_warns[0].severity == "warning"
+
+    def test_no_cutlass_include_is_info(self):
+        issues = KernelValidator().validate(BAD_SYCL, dsl="sycl")
+        infos = [i for i in issues if i.check_name == "no_cutlass_include"]
+        assert len(infos) == 1
+        assert infos[0].severity == "info"